Skip to content

Commit 2e780f3

Browse files
logsumexp reuse implementation
1 parent 00adcf2 commit 2e780f3

4 files changed

Lines changed: 177 additions & 77 deletions

File tree

phylo_grad/src/backward.rs

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -141,69 +141,63 @@ pub fn d_param<F: FloatTrait, const DIM: usize>(
141141
(grad_s, grad_sqrt_pi)
142142
}
143143

144-
fn child_input_forward_data<F: FloatTrait, const DIM: usize>(
145-
log_p: na::SVectorView<F, DIM>,
146-
log_transition_T: &na::SMatrix<F, DIM, DIM>,
147-
output: &mut na::SMatrix<F, DIM, DIM>,
148-
) {
149-
/* result = log_p[None, :] + log_transition */
150-
for i in 0..DIM {
151-
for j in 0..DIM {
152-
output[(i, j)] = log_p[j] + log_transition_T[(j, i)];
153-
}
154-
}
155-
}
156-
157144
fn d_broadcast_vjp<F: FloatTrait, const DIM: usize>(
158145
cotangent_vector: na::SMatrixView<F, DIM, DIM>,
159146
) -> na::SVector<F, DIM> {
160147
/* sum(cotangent_vector, dim=1) */
161148
na::SVector::<F, DIM>::from_iterator(cotangent_vector.column_iter().map(|col| col.sum()))
162149
}
163150

151+
/// Main part of the backward pass, where we go back through one Felsenstein step: it takes the cotangent of the parent log_p and calculates the cotangents of the child log_p and the parameters
152+
/// forward_exp_save will be the output cotangent for the log_transition matrix
164153
pub fn d_log_transition_child_input_vjp<F: FloatTrait, const DIM: usize>(
165154
cotangent_vector: na::SVectorView<F, DIM>,
166-
log_p: na::SVectorView<F, DIM>,
167-
forward: &LogTransitionForwardData<F, DIM>,
155+
forward_exp_save: &mut na::SMatrix<F, DIM, DIM>,
156+
forward_sum_save: &mut na::SVector<F, DIM>,
168157
compute_grad_log_p: bool,
169-
output: &mut na::SMatrix<F, DIM, DIM>,
170158
) -> Option<na::SVector<F, DIM>> {
171-
child_input_forward_data(log_p, &forward.log_transition_T, output);
172-
173-
/* d_lse */
174-
for mut row in output.row_iter_mut() {
175-
row.copy_from(&softmax(&row.transpose()).transpose());
159+
160+
let forward_exp_save_data = &mut forward_exp_save.data.0;
161+
162+
// Does the softmax, which is the gradient of the logsumexp
163+
for a in 0..DIM {
164+
for b in 0..DIM {
165+
forward_exp_save_data[a][b] /= forward_sum_save[b];
166+
}
176167
}
177-
diag_times_assign(output.as_view_mut(), cotangent_vector.iter().copied());
168+
169+
forward_exp_save.transpose_mut();
170+
171+
diag_times_assign(forward_exp_save.as_view_mut(), cotangent_vector.iter().copied());
178172

179173
let grad_log_p = if compute_grad_log_p {
180-
Some(d_broadcast_vjp(output.as_view()))
174+
Some(d_broadcast_vjp(forward_exp_save.as_view()))
181175
} else {
182176
None
183177
};
184178

185179
grad_log_p
186180
}
187181

182+
/// forward_exp_save will be the output cotangent for Q
188183
pub fn d_child_input_param<F: FloatTrait, const DIM: usize>(
189184
cotangent_vector: na::SVectorView<F, DIM>,
190185
distance: F,
191186
param: &ParamPrecomp<F, DIM>,
192-
log_p: na::SVectorView<F, DIM>,
193187
forward: &LogTransitionForwardData<F, DIM>,
188+
forward_exp_save: &mut na::SMatrix<F, DIM, DIM>,
189+
forward_sum_save: &mut na::SVector<F, DIM>,
194190
compute_grad_log_p: bool,
195-
output: &mut na::SMatrix<F, DIM, DIM>,
196191
) -> Option<na::SVector<F, DIM>> {
197192
let grad_log_p = d_log_transition_child_input_vjp(
198193
cotangent_vector,
199-
log_p,
200-
forward,
201-
compute_grad_log_p,
202-
output,
194+
forward_exp_save,
195+
forward_sum_save,
196+
compute_grad_log_p
203197
);
204-
d_ln_vjp(output, &forward.matrix_exp);
198+
d_ln_vjp(forward_exp_save, &forward.matrix_exp);
205199

206-
d_expm_vjp(output, distance, param, &forward.exp_t_lambda);
200+
d_expm_vjp(forward_exp_save, distance, param, &forward.exp_t_lambda);
207201

208202
grad_log_p
209203
}

phylo_grad/src/data_types.rs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ where
2626
fn scalar_exp(self) -> Self;
2727
fn vec_exp<const N: usize>(x: &mut [Self; N]);
2828
fn vec_logsumexp<const N: usize>(x: &[Self; N]) -> Self;
29+
// Saves exp(x - max) into exp_save to avoid recomputation for the softmax in the backward pass
30+
fn vec_logsumexp_save<const N: usize>(x: &[Self; N], exp_save: &mut [Self; N], exp_sum: &mut Self) -> Self;
2931
fn symmetric_eigen<const N: usize>(
3032
matrix: na::SMatrix<Self, N, N>,
3133
) -> Option<(SVector<Self, N>, SMatrix<Self, N, N>)>;
@@ -57,6 +59,41 @@ impl FloatTrait for f32 {
5759
x[i] = x[i].scalar_exp();
5860
}
5961
}
62+
fn vec_logsumexp_save<const N: usize>(x: &[Self; N], exp_save: &mut [Self; N], exp_sum: &mut Self) -> Self {
63+
let blocks = N / 8;
64+
65+
let mut max = simd::f32x8::splat(f32::NEG_INFINITY);
66+
for i in 0..blocks {
67+
let a = simd::f32x8::from_slice(&x[i * 8..]);
68+
max = max.simd_max(a);
69+
}
70+
71+
if N % 8 != 0 {
72+
let last_elements =
73+
simd::f32x8::load_or(&x[blocks * 8..], simd::f32x8::splat(f32::NEG_INFINITY));
74+
max = max.simd_max(last_elements);
75+
}
76+
let max = max.reduce_max();
77+
78+
let mut sum = simd::f32x8::splat(0.0);
79+
for i in 0..blocks {
80+
let a = simd::f32x8::from_slice(&x[i * 8..]);
81+
let b = a - simd::f32x8::splat(max);
82+
let c = sleef::f32x::exp_u10(b);
83+
simd::f32x8::copy_to_slice(c, &mut exp_save[i * 8..]);
84+
sum += c;
85+
}
86+
if N % 8 != 0 {
87+
let last_elements =
88+
simd::f32x8::load_or(&x[blocks * 8..], simd::f32x8::splat(f32::NEG_INFINITY));
89+
let c = sleef::f32x::exp_u10(last_elements - simd::f32x8::splat(max));
90+
simd::f32x8::store_select(c, &mut exp_save[blocks * 8..], std::simd::Mask::splat(true));
91+
sum += c;
92+
}
93+
let sum = sum.reduce_sum();
94+
*exp_sum = sum;
95+
max + (sum).ln()
96+
}
6097
fn vec_logsumexp<const N: usize>(x: &[Self; N]) -> Self {
6198
let blocks = N / 8;
6299

@@ -124,6 +161,41 @@ impl FloatTrait for f64 {
124161
x[i] = x[i].scalar_exp();
125162
}
126163
}
164+
fn vec_logsumexp_save<const N: usize>(x: &[Self; N], exp_save: &mut [Self; N], exp_sum: &mut Self) -> Self {
165+
let blocks = N / 4;
166+
167+
let mut max = simd::f64x4::splat(f64::NEG_INFINITY);
168+
for i in 0..blocks {
169+
let a = simd::f64x4::from_slice(&x[i * 4..]);
170+
max = max.simd_max(a);
171+
}
172+
173+
if N % 4 != 0 {
174+
let last_elements =
175+
simd::f64x4::load_or(&x[blocks * 4..], simd::f64x4::splat(f64::NEG_INFINITY));
176+
max = max.simd_max(last_elements);
177+
}
178+
let max = max.reduce_max();
179+
180+
let mut sum = simd::f64x4::splat(0.0);
181+
for i in 0..blocks {
182+
let a = simd::f64x4::from_slice(&x[i * 4..]);
183+
let b = a - simd::f64x4::splat(max);
184+
let c = sleef::f64x::exp_u10(b);
185+
simd::f64x4::copy_to_slice(c, &mut exp_save[i * 4..]);
186+
sum += c;
187+
}
188+
if N % 4 != 0 {
189+
let last_elements =
190+
simd::f64x4::load_or(&x[blocks * 4..], simd::f64x4::splat(f64::NEG_INFINITY));
191+
let c = sleef::f64x::exp_u10(last_elements - simd::f64x4::splat(max));
192+
simd::f64x4::store_select(c, &mut exp_save[blocks * 4..], std::simd::Mask::splat(true));
193+
sum += c;
194+
}
195+
let sum = sum.reduce_sum();
196+
*exp_sum = sum;
197+
max + (sum).ln()
198+
}
127199
fn vec_logsumexp<const N: usize>(x: &[Self; N]) -> Self {
128200
let blocks = N / 4;
129201

phylo_grad/src/forward.rs

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,32 @@ use crate::data_types::*;
22

33
use nalgebra as na;
44

5+
/// Forward data precomputed before the forward pass
56
pub struct ForwardData<F, const DIM: usize> {
67
pub log_transition: Vec<LogTransitionForwardData<F, DIM>>,
78
}
89

10+
/// Forward data which is saved during the forward pass
11+
pub struct ForwardDataSave<F, const DIM: usize> {
12+
pub logsumexp_exp_save: Vec<na::SMatrix<F, DIM, DIM>>,
13+
pub logsumexp_sum_save: Vec<na::SVector<F, DIM>>,
14+
}
15+
16+
impl<F : FloatTrait, const DIM: usize> ForwardDataSave<F, DIM> {
17+
pub fn new(capacity: usize) -> Self {
18+
Self {
19+
logsumexp_exp_save: vec![
20+
na::SMatrix::<F, DIM, DIM>::zeros();
21+
capacity
22+
],
23+
logsumexp_sum_save: vec![
24+
na::SVector::<F, DIM>::zeros();
25+
capacity
26+
],
27+
}
28+
}
29+
}
30+
931
impl<F, const DIM: usize> ForwardData<F, DIM> {
1032
pub fn with_capacity(capacity: usize) -> Self {
1133
Self {
@@ -152,15 +174,17 @@ pub fn forward_node<F: FloatTrait, const DIM: usize>(
152174
parent: usize,
153175
log_p: &mut [na::SVector<F, DIM>],
154176
forward_data: &ForwardData<F, DIM>,
177+
forward_data_save: &mut ForwardDataSave<F, DIM>,
155178
) {
179+
let logsumexp_exp_save = &mut forward_data_save.logsumexp_exp_save[child].data.0;
180+
let logsumexp_sum_save = forward_data_save.logsumexp_sum_save[child].as_mut_slice();
156181
/* log_p[parent]_a = logsumexp_b(log_p[child](b) + log_transition(rate_matrix, distance)(a, b) ) */
157182
for a in 0..DIM {
158183
let row_a = forward_data.log_transition[child].log_transition_T.column(a);
159184
let tmp = log_p[child] + row_a;
160185
unsafe {
161-
log_p[parent][a] += F::vec_logsumexp(std::mem::transmute::<&[[F; DIM]; 1], &[F; DIM]>(
162-
&tmp.data.0,
163-
))
186+
log_p[parent][a] += F::vec_logsumexp_save(std::mem::transmute::<&[[F; DIM]; 1], &[F; DIM]>(
187+
&tmp.data.0), &mut logsumexp_exp_save[a], &mut logsumexp_sum_save[a]);
164188
}
165189
}
166190
}

0 commit comments

Comments
 (0)