Commit 2c0c34f

Fix precision type, Union instead of |
1 parent cd8cbde commit 2c0c34f
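
The "Union instead of |" part of the message refers to the update_fn type hint changed further down, where `params: Updates | None` becomes `params: Optional[Updates]`. Below is a minimal sketch of the two spellings; the motivation (compatibility with Python versions before 3.10, where the `|` union syntax cannot be evaluated at runtime) is an assumption, since the commit message does not state it.

from typing import Optional, Union

# Sketch only, not code from this repository. PEP 604 unions such as
# `dict | None` need Python 3.10+ when the annotation is evaluated;
# the typing-module spellings below also work on older interpreters.
def example_fn(params: Optional[dict] = None) -> Union[dict, None]:
    return params if params is not None else {}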

File tree

1 file changed (+54, −18 lines)

src/soap_jax/soap.py

Lines changed: 54 additions & 18 deletions
@@ -1,5 +1,5 @@
 from itertools import chain
-from typing import List, NamedTuple, Union
+from typing import List, NamedTuple, Optional, Union
 
 import jax
 import jax.numpy as jnp
@@ -44,6 +44,7 @@ def soap(
         precondition_frequency (int, optional): How often to update the preconditioner. Defaults to 10.
         max_precond_dim (int, optional): Maximum dimension of the preconditioner.
             Set to 10000 to exclude most common vocab sizes while including layers. Defaults to 10000.
+        precision (jax.lax.PrecisionLike, optional): Precision to use. Defaults to jax.lax.Precision.HIGHEST.
 
     Returns:
         optax.GradientTransformationExtraArgs: The SOAP optimizer.
@@ -72,6 +73,23 @@ def scale_by_soap(
     max_precond_dim: int = 10000,
     precision: jax.lax.PrecisionLike = jax.lax.Precision.HIGHEST,
 ) -> GradientTransformation:
+    """
+    Implements SOAP algorithm (https://arxiv.org/abs/2409.11321). Based on the original implementation at https://github.com/nikhilvyas/SOAP.
+
+    Args:
+        b1 (float, optional): Adam's beta1 parameter. Defaults to 0.95.
+        b2 (float, optional): Adam's beta2 parameter. Defaults to 0.95.
+        shampoo_beta (float, optional): If >= 0, use this beta for the preconditioner (`L` and `R` in paper, `GG` below)
+            moving average instead of b2. Defaults to -1.
+        eps (float, optional): Adam's epsilon for numerical stability. Defaults to 1e-8.
+        precondition_frequency (int, optional): How often to update the preconditioner. Defaults to 10.
+        max_precond_dim (int, optional): Maximum dimension of the preconditioner.
+            Set to 10000 to exclude most common vocab sizes while including layers. Defaults to 10000.
+        precision (jax.lax.PrecisionLike, optional): Precision to use. Defaults to jax.lax.Precision.HIGHEST.
+
+    Returns:
+        optax.GradientTransformationExtraArgs: The SOAP optimizer.
+    """
     shampoo_beta = shampoo_beta if shampoo_beta >= 0 else b2
 
     def init_fn(params: Updates) -> SOAPState:
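
The docstring added in this hunk documents scale_by_soap's keyword arguments. Below is a hedged usage sketch: the import path (guessed from src/soap_jax/soap.py) and the learning-rate wiring via optax.chain are assumptions, while the keyword names and defaults come from the docstring above.

import jax
import optax

from soap_jax.soap import scale_by_soap  # import path assumed from the file location

# Sketch only: combine the SOAP transform with a learning-rate scale.
optimizer = optax.chain(
    scale_by_soap(
        b1=0.95,
        b2=0.95,
        precondition_frequency=10,
        precision=jax.lax.Precision.HIGHEST,
    ),
    optax.scale_by_learning_rate(1e-3),
)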
@@ -119,7 +137,7 @@ def update_step(
     ) -> tuple[Updates, SOAPState]:
         # Project gradients
         grad_projected = jtu.tree_map(
-            lambda grad, q: project(grad, q),
+            lambda grad, q: project(grad, q, precision),
             updates,
             state.Q,
         )
@@ -129,14 +147,14 @@ def update_step(
         exp_avg_sq = otu.tree_update_moment_per_elem_norm(grad_projected, state.exp_avg_sq, b2, 2)
 
         exp_avg_projected = jtu.tree_map(
-            lambda e, q: project(e, q),
+            lambda e, q: project(e, q, precision),
             exp_avg,
             state.Q,
         )
 
         # Project back
         norm_updates = jtu.tree_map(
-            lambda e_avg, e_avg_sq, q: project_back(e_avg / (jnp.sqrt(e_avg_sq) + eps), q),
+            lambda e_avg, e_avg_sq, q: project_back(e_avg / (jnp.sqrt(e_avg_sq) + eps), q, precision),
             exp_avg_projected,
             exp_avg_sq,
             state.Q,
@@ -154,7 +172,7 @@ def update_step(
 
         # Update the preconditioner
         new_GG = jtu.tree_map(
-            lambda grad, gg: update_preconditioner(grad, gg, shampoo_beta),
+            lambda grad, gg: update_preconditioner(grad, gg, shampoo_beta, precision),
             updates,
             state.GG,
         )
@@ -163,7 +181,7 @@ def update_step(
         new_Q_and_exp_avg_sq = jax.lax.cond(
             state.count % precondition_frequency == 0,
             lambda: jtu.tree_map(
-                lambda e, gg, q: get_orthogonal_matrix_QR(gg, q, e),
+                lambda e, gg, q: get_orthogonal_matrix_QR(gg, q, e, precision),
                 exp_avg_sq,
                 new_GG,
                 state.Q,
@@ -196,17 +214,16 @@ def update_step(
 
         return norm_updates, new_state
 
-    def update_fn(updates: Updates, state: SOAPState, params: Updates | None = None) -> tuple[Updates, SOAPState]:
+    def update_fn(updates: Updates, state: SOAPState, params: Optional[Updates] = None) -> tuple[Updates, SOAPState]:
         del params
         count_inc = jnp.asarray(optax.safe_int32_increment(state.count))
         state = state._replace(count=count_inc)
 
-        with jax.default_matmul_precision(precision):
-            updates, new_state = jax.lax.cond(
-                count_inc == 1,
-                lambda: init_step(updates, state),
-                lambda: update_step(updates, state),
-            )
+        updates, new_state = jax.lax.cond(
+            count_inc == 1,
+            lambda: init_step(updates, state),
+            lambda: update_step(updates, state),
+        )
 
         return updates, new_state
 
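The hunk above drops the jax.default_matmul_precision context manager (presumably the "precision type" fix named in the commit message, since that context manager takes a string setting such as "highest" rather than a jax.lax.Precision value) and instead threads precision explicitly into each projection and preconditioner call. A standalone sketch of the two styles, not taken from the repository:

import jax
import jax.numpy as jnp

# Sketch only, not code from this repository.
a = jnp.ones((4, 8))
b = jnp.ones((8, 3))

# Scoped default, as in the removed lines:
with jax.default_matmul_precision("highest"):
    out_scoped = a @ b

# Explicit per-call argument, as in the added lines:
out_explicit = jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)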

@@ -217,9 +234,10 @@ def update_preconditioner(
     grad: Array,
     GG: List[Union[Array, None]],
     beta: float,
+    precision: jax.lax.PrecisionLike = jax.lax.Precision.HIGHEST,
 ) -> List[Union[Array, None]]:
     if grad.ndim == 1:
-        return [lerp(GG[0], jnp.outer(grad, grad), 1 - beta)]  # type: ignore
+        return [lerp(GG[0], jnp.matmul(grad[:, None], grad[None, :], precision=precision), 1 - beta)]  # type: ignore
 
     new_GG = []
     for idx, gg in enumerate(GG):
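
In the grad.ndim == 1 branch above, jnp.outer(grad, grad) is rewritten as an explicit matmul of column and row views so that a precision argument can be passed (jnp.outer has no precision parameter). A small equivalence sketch, not from the repository:

import jax
import jax.numpy as jnp

# Sketch only: both forms produce the same rank-1 outer product.
g = jnp.arange(3.0)
via_outer = jnp.outer(g, g)
via_matmul = jnp.matmul(g[:, None], g[None, :], precision=jax.lax.Precision.HIGHEST)
assert bool(jnp.allclose(via_outer, via_matmul))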
@@ -231,19 +249,25 @@ def update_preconditioner(
             grad,
             grad,
             axes=[[*chain(range(idx), range(idx + 1, len(grad.shape)))]] * 2,
+            precision=precision,
         )
         new_GG.append(lerp(gg, outer_product, 1 - beta))
 
     return new_GG
 
 
-def project(grad: Array, Q: List[Union[Array, None]]) -> Array:
+def project(
+    grad: Array,
+    Q: List[Union[Array, None]],
+    precision: jax.lax.PrecisionLike = jax.lax.Precision.HIGHEST,
+) -> Array:
     for mat in Q:
         if mat is not None:  # noqa: SIM108
             grad = jnp.tensordot(
                 grad,
                 mat,
                 axes=((0,), (0,)),
+                precision=precision,
             )
         else:
             permute_order = list(range(1, len(grad.shape))) + [0]
@@ -252,13 +276,18 @@ def project(grad: Array, Q: List[Union[Array, None]]) -> Array:
     return grad
 
 
-def project_back(grad: Array, Q: List[Union[Array, None]]) -> Array:
+def project_back(
+    grad: Array,
+    Q: List[Union[Array, None]],
+    precision: jax.lax.PrecisionLike = jax.lax.Precision.HIGHEST,
+) -> Array:
     for mat in Q:
         if mat is not None:  # noqa: SIM108
             grad = jnp.tensordot(
                 grad,
                 mat,
                 axes=((0,), (1,)),
+                precision=precision,
             )
         else:
             grad = jnp.moveaxis(grad, 0, -1)
@@ -278,18 +307,25 @@ def get_orthogonal_matrix_QR(
     GG: List[Union[Array, None]],
     Q: List[Union[Array, None]],
     exp_avg_sq: Array,
+    precision: jax.lax.PrecisionLike = jax.lax.Precision.HIGHEST,
 ) -> tuple[List[Union[Array, None]], Array]:
     final_Q = []
     for ind, (m, o) in enumerate(zip(GG, Q)):
         if m is None or o is None:
             final_Q.append(None)
             continue
 
-        est_eig = jnp.diag(o.T @ m @ o)
+        est_eig = jnp.diag(
+            jnp.matmul(
+                jnp.matmul(o.T, m, precision=precision),
+                o,
+                precision=precision,
+            )
+        )
         sort_idx = jnp.argsort(est_eig, descending=True)
         exp_avg_sq = jnp.take(exp_avg_sq, sort_idx, axis=ind)
         o = o[:, sort_idx]
-        power_iter = m @ o
+        power_iter = jnp.matmul(m, o, precision=precision)
         Q_new, _ = jnp.linalg.qr(power_iter)
 
         final_Q.append(Q_new)
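
The final hunk replaces the @ operators in the eigenvalue estimate diag(o.T @ m @ o) and the power-iteration step m @ o with jnp.matmul calls so precision can be passed through. A toy sketch of that estimate-sort-QR pattern on a small symmetric matrix, not taken from the repository:

import jax
import jax.numpy as jnp

# Sketch only: estimate eigenvalues from diag(Q^T M Q), sort the basis by
# the estimates, take one power-iteration step, and re-orthogonalize via QR.
highest = jax.lax.Precision.HIGHEST
key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (5, 5))
m = a @ a.T                                            # symmetric stand-in for a GG statistic
q, _ = jnp.linalg.qr(jax.random.normal(key, (5, 5)))   # current orthogonal basis

est_eig = jnp.diag(jnp.matmul(jnp.matmul(q.T, m, precision=highest), q, precision=highest))
sort_idx = jnp.argsort(est_eig, descending=True)
q = q[:, sort_idx]
q_new, _ = jnp.linalg.qr(jnp.matmul(m, q, precision=highest))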
