Skip to content

Commit 0c10edb

Browse files
calculate likelihoods also for the batched case
1 parent ceec351 commit 0c10edb

2 files changed

Lines changed: 59 additions & 27 deletions

File tree

phylo_grad/src/lib.rs

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ impl<F: FloatTrait, const DIM: usize> FelsensteinTree<F, DIM> {
8989
/// If the lengths of `s` and `sqrt_pi` are both 1, a different code path is used that is optimized for this case and assumes that they are the same for all columns.
9090
///
9191
/// Only the upper diagonal part of `s` is used. The gradients will only be populated in the upper diagonal and the lower diagonal will be filled with zeros.
92-
///
92+
///
9393
/// This function assumes you have already called `bind_leaf_log_p` to bind the log probabilities of the leaves.
9494
pub fn calculate_gradients(
9595
&mut self,
@@ -117,9 +117,42 @@ impl<F: FloatTrait, const DIM: usize> FelsensteinTree<F, DIM> {
117117
&sqrt_pi[0],
118118
tree,
119119
d_trans_matrix,
120+
false,
120121
);
121122
}
122-
calculate_column_parallel(&mut self.log_p, s, sqrt_pi, tree)
123+
calculate_column_parallel(&mut self.log_p, s, sqrt_pi, tree, false)
124+
}
125+
126+
/// Same as `calculate_gradients`, but only calculates the log likelihoods for each side in the alignment.
127+
pub fn calculate_likelihoods(
128+
&mut self,
129+
s: &[na::SMatrix<F, DIM, DIM>],
130+
sqrt_pi: &[na::SVector<F, DIM>],
131+
) -> Vec<F> {
132+
let tree = tree::Tree::new(&self.parents, &self.distances, self.num_leaves);
133+
// Zero out internal nodes in log_p
134+
for log_p in &mut self.log_p {
135+
log_p.iter_mut().skip(self.num_leaves).for_each(|p| {
136+
*p = na::SVector::<F, DIM>::zeros();
137+
});
138+
}
139+
140+
let result = if s.len() == 1 && sqrt_pi.len() == 1 {
141+
let mut d_trans_matrix = Vec::new(); // not used in this case
142+
143+
calculate_column_parallel_single_S(
144+
&mut self.log_p,
145+
&s[0],
146+
&sqrt_pi[0],
147+
tree,
148+
&mut d_trans_matrix,
149+
true,
150+
)
151+
} else {
152+
calculate_column_parallel(&mut self.log_p, s, sqrt_pi, tree, true)
153+
};
154+
155+
return result.log_likelihood;
123156
}
124157

125158
/// Same as `calculate_gradients`, but it also takes an array of the log probabilities of the leaves.
@@ -131,56 +164,39 @@ impl<F: FloatTrait, const DIM: usize> FelsensteinTree<F, DIM> {
131164
log_p: &mut [&mut [na::SVector<F, DIM>]],
132165
) -> FelsensteinResult<F, DIM> {
133166
let tree = tree::Tree::new(&self.parents, &self.distances, self.num_leaves);
134-
calculate_column_parallel(
135-
log_p,
136-
s,
137-
sqrt_pi,
138-
tree,
139-
)
167+
calculate_column_parallel(log_p, s, sqrt_pi, tree, false)
140168
}
141169

142170
/// This function calculates the gradients for a single side in the alignment.
143171
/// This can be useful if you want to control the parallelization yourself or if you want to calculate the gradients for a single side.
144-
///
172+
///
145173
/// log_p is expected to have enough space to hold the log probabilities for all nodes
146174
pub fn calculate_gradients_single_side(
147175
&self,
148176
s: na::SMatrixView<F, DIM, DIM>,
149177
sqrt_pi: na::SVectorView<F, DIM>,
150-
log_p: &mut [na::SVector<F, DIM>]
178+
log_p: &mut [na::SVector<F, DIM>],
151179
) -> SingleSideResult<F, DIM> {
152180
let tree = tree::Tree::new(&self.parents, &self.distances, self.num_leaves);
153181
// zero out internal nodes in log_p
154182
log_p[self.num_leaves..].iter_mut().for_each(|p| {
155183
*p = na::SVector::<F, DIM>::zeros();
156184
});
157-
calculate_column(
158-
log_p,
159-
s.as_view(),
160-
sqrt_pi.as_view(),
161-
tree,
162-
false,
163-
)
185+
calculate_column(log_p, s.as_view(), sqrt_pi.as_view(), tree, false)
164186
}
165187

166188
pub fn calculate_likelihood_single_side(
167189
&self,
168190
s: na::SMatrixView<F, DIM, DIM>,
169191
sqrt_pi: na::SVectorView<F, DIM>,
170-
log_p: &mut [na::SVector<F, DIM>]
192+
log_p: &mut [na::SVector<F, DIM>],
171193
) -> F {
172194
let tree = tree::Tree::new(&self.parents, &self.distances, self.num_leaves);
173195
// zero out internal nodes in log_p
174196
log_p[self.num_leaves..].iter_mut().for_each(|p| {
175197
*p = na::SVector::<F, DIM>::zeros();
176198
});
177-
let result = calculate_column(
178-
log_p,
179-
s.as_view(),
180-
sqrt_pi.as_view(),
181-
tree,
182-
true,
183-
);
199+
let result = calculate_column(log_p, s.as_view(), sqrt_pi.as_view(), tree, true);
184200
result.log_likelihood
185201
}
186202
}

phylo_grad/src/run.rs

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ pub fn calculate_column_parallel<
178178
S: &[na::SMatrix<F, DIM, DIM>],
179179
sqrt_pi: &[na::SVector<F, DIM>],
180180
tree: Tree<F>,
181+
only_likelihood: bool,
181182
) -> FelsensteinResult<F, DIM> {
182183
let col_results = (leaf_log_p, S, sqrt_pi)
183184
.into_par_iter()
@@ -187,7 +188,7 @@ pub fn calculate_column_parallel<
187188
S.as_view(),
188189
sqrt_pi.as_view(),
189190
tree.clone(),
190-
false,
191+
only_likelihood,
191192
) // The clone is shallow, Tree is cheap to clone
192193
})
193194
.collect::<Vec<_>>();
@@ -217,6 +218,7 @@ pub fn calculate_column_parallel_single_S<F: FloatTrait, const DIM: usize>(
217218
sqrt_pi: &na::SVector<F, DIM>,
218219
tree: Tree<F>,
219220
d_trans_matrix: &mut [Vec<na::SMatrix<F, DIM, DIM>>],
221+
only_likelihood: bool
220222
) -> FelsensteinResult<F, DIM> {
221223
let L = leaf_log_p.len();
222224

@@ -240,12 +242,20 @@ pub fn calculate_column_parallel_single_S<F: FloatTrait, const DIM: usize>(
240242
.into_par_iter()
241243
.zip(d_trans_matrix.par_iter_mut())
242244
.map(|(leaf_log_p, d_trans)| {
243-
cacluate_column_single_S(leaf_log_p, &param, &forward_data, tree.clone(), d_trans)
245+
cacluate_column_single_S(leaf_log_p, &param, &forward_data, tree.clone(), d_trans, only_likelihood)
244246
})
245247
.collect::<Vec<_>>();
246248

247249
let log_likelihood = result.iter().map(|r| r.0).collect::<Vec<_>>();
248250

251+
if only_likelihood {
252+
return FelsensteinResult::<F, DIM> {
253+
log_likelihood,
254+
grad_s: vec![na::SMatrix::<F, DIM, DIM>::zeros()],
255+
grad_sqrt_pi: vec![na::SVector::<F, DIM>::zeros()],
256+
};
257+
}
258+
249259
let sum_d_log_prior = result.iter().map(|r| r.1).sum::<na::SVector<F, DIM>>();
250260

251261
// We need to skip the root edge, as it does not exist and it will always be the last edge
@@ -289,12 +299,14 @@ fn d_rate_matrix_per_edge<F: FloatTrait, const DIM: usize>(
289299
sum_d_log_trans
290300
}
291301

302+
/// If `only_likelihood` is true, `d_trans_matrix` is not used.
292303
fn cacluate_column_single_S<F: FloatTrait, const DIM: usize>(
293304
leaf_log_p: &mut [na::SVector<F, DIM>],
294305
param: &ParamPrecomp<F, DIM>,
295306
forward_data: &ForwardData<F, DIM>,
296307
tree: Tree<F>,
297308
d_trans_matrix: &mut [na::SMatrix<F, DIM, DIM>],
309+
only_likelihood: bool
298310
) -> (F, na::SVector<F, DIM>) {
299311
forward_column(leaf_log_p, tree.parents, forward_data);
300312
let log_p = leaf_log_p;
@@ -304,6 +316,10 @@ fn cacluate_column_single_S<F: FloatTrait, const DIM: usize>(
304316

305317
let (log_likelihood, grad_log_p_likelihood) =
306318
final_likelihood(log_p_root.as_view(), log_p_prior.as_view());
319+
320+
if only_likelihood {
321+
return (log_likelihood, na::SVector::<F, DIM>::zeros());
322+
}
307323
let d_log_prior = grad_log_p_likelihood;
308324
let d_log_p_root = grad_log_p_likelihood;
309325

0 commit comments

Comments
 (0)