Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions phylo_grad/src/backward.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,15 @@ fn d_broadcast_vjp<F: FloatTrait, const DIM: usize>(

/// Main part of the backward where we go back through one Felsenstein step, it takes the cotangent of the parent log_p and calculates the cotangent of the child log_p and the parameters
/// forward_exp_save will be the output cotangent for the log_transition matrix
/// This is the backward function (vjp) for forward_node
pub fn d_log_transition_child_input_vjp<F: FloatTrait, const DIM: usize>(
cotangent_vector: & na::SVector<F, DIM>,
forward_exp_save: &mut na::SMatrix<F, DIM, DIM>,
forward_sum_save: &mut na::SVector<F, DIM>,
compute_grad_log_p: bool,
) -> Option<na::SVector<F, DIM>> {

// This function will change the most

let forward_exp_save_data = &mut forward_exp_save.data.0;

Expand Down Expand Up @@ -195,6 +198,8 @@ pub fn d_child_input_param<F: FloatTrait, const DIM: usize>(
forward_sum_save,
compute_grad_log_p
);

// likley to be removed
d_ln_vjp(forward_exp_save, &forward.matrix_exp_recip);

d_expm_vjp(forward_exp_save, distance, param, &forward.exp_t_lambda);
Expand Down
17 changes: 16 additions & 1 deletion phylo_grad/src/forward.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,31 @@ impl<F, const DIM: usize> ForwardData<F, DIM> {
}
}

/// Data precomputed for each edge. Depends only on the Q matrix and the edge length
#[derive(Debug)]
pub struct LogTransitionForwardData<F, const DIM: usize> {
/// 1 / e^(t Q) (matrix exponential) and then element wise reciprical
pub matrix_exp_recip: na::SMatrix<F, DIM, DIM>,
/// log(matrix_exp) transposed
pub log_transition_T: na::SMatrix<F, DIM, DIM>,
/// exp(t * lambda_i) for the DIM many eigenvalues of Q
pub exp_t_lambda: na::SVector<F, DIM>,
}

/// Precomputed values from the model (S and sqrt_pi)
#[derive(Debug)]
pub struct ParamPrecomp<F, const DIM: usize> {
/// S
pub symmetric_matrix: na::SMatrix<F, DIM, DIM>,
/// sqrt_pi
pub sqrt_pi: na::SVector<F, DIM>,
/// 1/sqrt_pi
pub sqrt_pi_recip: na::SVector<F, DIM>,
/// Eigenvalues of S
pub eigenvalues: na::SVector<F, DIM>,
/// A in the paper
pub V_pi: na::SMatrix<F, DIM, DIM>,
/// A^-1 in the paper
pub V_pi_inv: na::SMatrix<F, DIM, DIM>,
}

Expand Down Expand Up @@ -171,6 +182,9 @@ pub fn forward_data_precompute_param<F: FloatTrait, const DIM: usize>(
}

/// adds the log_p of the children to the log_p of the parent
/// Main part of the Felsenstein in Forward
/// log_p are the partial log likelihoods, they start with the leave nodes initialized. This function takes 2 computed log_p vectors
/// and writes the compbined result in the parent log_p vector
pub fn forward_node<F: FloatTrait, const DIM: usize>(
child: usize,
parent: usize,
Expand All @@ -181,6 +195,7 @@ pub fn forward_node<F: FloatTrait, const DIM: usize>(
let logsumexp_exp_save = &mut forward_data_save.logsumexp_exp_save[child].data.0;
let logsumexp_sum_save = forward_data_save.logsumexp_sum_save[child].as_mut_slice();
/* log_p[parent]_a = logsumexp_b(log_p[child](b) + log_transition(rate_matrix, distance)(a, b) ) */
// In linspace log_p[parent]_a = sum_b (log_p[child](b) * transiton(rate_matrix, distance)(a,b) )
for a in 0..DIM {
let row_a = forward_data.log_transition[child].log_transition_T.column(a);
let tmp = log_p[child] + row_a;
Expand All @@ -189,4 +204,4 @@ pub fn forward_node<F: FloatTrait, const DIM: usize>(
&tmp.data.0), &mut logsumexp_exp_save[a], &mut logsumexp_sum_save[a]);
}
}
}
}
Loading