@@ -89,7 +89,7 @@ impl<F: FloatTrait, const DIM: usize> FelsensteinTree<F, DIM> {
8989 /// If the length of `s` and `sqrt_pi` is 1, it will use a different code path that is optimized for this case and assumes that they are the same for all columns.
9090 ///
9191 /// Only the upper diagonal part of `s` is used. The gradients will only be populated in the upper diagonal and the lower diagonal will be filled with zeros.
92- ///
92+ ///
9393 /// This functions assumes you have already called `bind_leaf_log_p` to bind the log probabilities of the leaves.
9494 pub fn calculate_gradients (
9595 & mut self ,
@@ -117,9 +117,42 @@ impl<F: FloatTrait, const DIM: usize> FelsensteinTree<F, DIM> {
117117 & sqrt_pi[ 0 ] ,
118118 tree,
119119 d_trans_matrix,
120+ false ,
120121 ) ;
121122 }
122- calculate_column_parallel ( & mut self . log_p , s, sqrt_pi, tree)
123+ calculate_column_parallel ( & mut self . log_p , s, sqrt_pi, tree, false )
124+ }
125+
126+ /// Same as `calculate_gradients`, but only calculates the log likelihoods for each side in the alignment.
127+ pub fn calculate_likelihoods (
128+ & mut self ,
129+ s : & [ na:: SMatrix < F , DIM , DIM > ] ,
130+ sqrt_pi : & [ na:: SVector < F , DIM > ] ,
131+ ) -> Vec < F > {
132+ let tree = tree:: Tree :: new ( & self . parents , & self . distances , self . num_leaves ) ;
133+ // Zero out internal nodes in log_p
134+ for log_p in & mut self . log_p {
135+ log_p. iter_mut ( ) . skip ( self . num_leaves ) . for_each ( |p| {
136+ * p = na:: SVector :: < F , DIM > :: zeros ( ) ;
137+ } ) ;
138+ }
139+
140+ let result = if s. len ( ) == 1 && sqrt_pi. len ( ) == 1 {
141+ let mut d_trans_matrix = Vec :: new ( ) ; // not used in this case
142+
143+ calculate_column_parallel_single_S (
144+ & mut self . log_p ,
145+ & s[ 0 ] ,
146+ & sqrt_pi[ 0 ] ,
147+ tree,
148+ & mut d_trans_matrix,
149+ true ,
150+ )
151+ } else {
152+ calculate_column_parallel ( & mut self . log_p , s, sqrt_pi, tree, true )
153+ } ;
154+
155+ return result. log_likelihood ;
123156 }
124157
125158 /// Same as `calculate_gradients`, but it takes also an array of the log_probabilities of the leaves.
@@ -131,56 +164,39 @@ impl<F: FloatTrait, const DIM: usize> FelsensteinTree<F, DIM> {
131164 log_p : & mut [ & mut [ na:: SVector < F , DIM > ] ] ,
132165 ) -> FelsensteinResult < F , DIM > {
133166 let tree = tree:: Tree :: new ( & self . parents , & self . distances , self . num_leaves ) ;
134- calculate_column_parallel (
135- log_p,
136- s,
137- sqrt_pi,
138- tree,
139- )
167+ calculate_column_parallel ( log_p, s, sqrt_pi, tree, false )
140168 }
141169
142170 /// This function calculates the gradients for a single side in the alignment.
143171 /// This can be useful if you want to control the parallelization yourself or if you want to calculate the gradients for a single side.
144- ///
172+ ///
145173 /// log_p is expected to have enough space to hold the log probabilities for all nodes
146174 pub fn calculate_gradients_single_side (
147175 & self ,
148176 s : na:: SMatrixView < F , DIM , DIM > ,
149177 sqrt_pi : na:: SVectorView < F , DIM > ,
150- log_p : & mut [ na:: SVector < F , DIM > ]
178+ log_p : & mut [ na:: SVector < F , DIM > ] ,
151179 ) -> SingleSideResult < F , DIM > {
152180 let tree = tree:: Tree :: new ( & self . parents , & self . distances , self . num_leaves ) ;
153181 // zero out internal nodes in log_p
154182 log_p[ self . num_leaves ..] . iter_mut ( ) . for_each ( |p| {
155183 * p = na:: SVector :: < F , DIM > :: zeros ( ) ;
156184 } ) ;
157- calculate_column (
158- log_p,
159- s. as_view ( ) ,
160- sqrt_pi. as_view ( ) ,
161- tree,
162- false ,
163- )
185+ calculate_column ( log_p, s. as_view ( ) , sqrt_pi. as_view ( ) , tree, false )
164186 }
165187
166188 pub fn calculate_likelihood_single_side (
167189 & self ,
168190 s : na:: SMatrixView < F , DIM , DIM > ,
169191 sqrt_pi : na:: SVectorView < F , DIM > ,
170- log_p : & mut [ na:: SVector < F , DIM > ]
192+ log_p : & mut [ na:: SVector < F , DIM > ] ,
171193 ) -> F {
172194 let tree = tree:: Tree :: new ( & self . parents , & self . distances , self . num_leaves ) ;
173195 // zero out internal nodes in log_p
174196 log_p[ self . num_leaves ..] . iter_mut ( ) . for_each ( |p| {
175197 * p = na:: SVector :: < F , DIM > :: zeros ( ) ;
176198 } ) ;
177- let result = calculate_column (
178- log_p,
179- s. as_view ( ) ,
180- sqrt_pi. as_view ( ) ,
181- tree,
182- true ,
183- ) ;
199+ let result = calculate_column ( log_p, s. as_view ( ) , sqrt_pi. as_view ( ) , tree, true ) ;
184200 result. log_likelihood
185201 }
186202}
0 commit comments