from itertools import chain
-from typing import List, NamedTuple, Union
+from typing import List, NamedTuple, Optional, Union

import jax
import jax.numpy as jnp
@@ -44,6 +44,7 @@ def soap(
        precondition_frequency (int, optional): How often to update the preconditioner. Defaults to 10.
        max_precond_dim (int, optional): Maximum dimension of the preconditioner.
            Set to 10000 to exclude most common vocab sizes while including layers. Defaults to 10000.
+        precision (jax.lax.PrecisionLike, optional): Precision to use. Defaults to jax.lax.Precision.HIGHEST.

    Returns:
        optax.GradientTransformationExtraArgs: The SOAP optimizer.
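For context on how the new `precision` argument surfaces to callers, here is a minimal usage sketch. It assumes `soap` is importable from this module and accepts a `learning_rate` argument alongside the keyword arguments documented above; the learning-rate value and the parameter tree are illustrative only.

```python
import jax
import jax.numpy as jnp
import optax

# Hypothetical usage: `learning_rate` and the parameter tree are stand-ins.
optimizer = soap(
    learning_rate=3e-4,
    precondition_frequency=10,
    max_precond_dim=10000,
    precision=jax.lax.Precision.HIGHEST,  # argument surfaced in this diff
)

params = {"w": jnp.ones((4, 4))}
opt_state = optimizer.init(params)

grads = jax.tree_util.tree_map(jnp.ones_like, params)  # stand-in gradients
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```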
@@ -72,6 +73,23 @@ def scale_by_soap(
    max_precond_dim: int = 10000,
    precision: jax.lax.PrecisionLike = jax.lax.Precision.HIGHEST,
) -> GradientTransformation:
+    """
+    Implements the SOAP algorithm (https://arxiv.org/abs/2409.11321). Based on the original implementation at https://github.com/nikhilvyas/SOAP.
+
+    Args:
+        b1 (float, optional): Adam's beta1 parameter. Defaults to 0.95.
+        b2 (float, optional): Adam's beta2 parameter. Defaults to 0.95.
+        shampoo_beta (float, optional): If >= 0, use this beta for the preconditioner (`L` and `R` in the paper, `GG` below)
+            moving average instead of b2. Defaults to -1.
+        eps (float, optional): Adam's epsilon for numerical stability. Defaults to 1e-8.
+        precondition_frequency (int, optional): How often to update the preconditioner. Defaults to 10.
+        max_precond_dim (int, optional): Maximum dimension of the preconditioner.
+            Set to 10000 to exclude most common vocab sizes while including layers. Defaults to 10000.
+        precision (jax.lax.PrecisionLike, optional): Precision to use. Defaults to jax.lax.Precision.HIGHEST.
+
+    Returns:
+        optax.GradientTransformationExtraArgs: The SOAP optimizer.
+    """
    shampoo_beta = shampoo_beta if shampoo_beta >= 0 else b2

    def init_fn(params: Updates) -> SOAPState:
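Since the whole change revolves around `jax.lax.Precision`, a minimal sketch of what the setting controls may help: `HIGHEST` asks the backend for its most accurate matmul path, while `DEFAULT` lets it pick a faster, possibly lower-precision one. The arrays and shapes below are illustrative.

```python
import jax
import jax.numpy as jnp

a = jnp.ones((128, 128), dtype=jnp.bfloat16)
b = jnp.ones((128, 128), dtype=jnp.bfloat16)

# DEFAULT lets the backend pick a fast (possibly lower-precision) path;
# HIGHEST requests the most accurate matmul available on the device.
fast = jnp.matmul(a, b, precision=jax.lax.Precision.DEFAULT)
accurate = jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)

# On CPU the two usually agree exactly; on TPU/GPU they can differ slightly.
print(jnp.max(jnp.abs(accurate.astype(jnp.float32) - fast.astype(jnp.float32))))
```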
@@ -119,7 +137,7 @@ def update_step(
    ) -> tuple[Updates, SOAPState]:
        # Project gradients
        grad_projected = jtu.tree_map(
-            lambda grad, q: project(grad, q),
+            lambda grad, q: project(grad, q, precision),
            updates,
            state.Q,
        )
@@ -129,14 +147,14 @@ def update_step(
        exp_avg_sq = otu.tree_update_moment_per_elem_norm(grad_projected, state.exp_avg_sq, b2, 2)

        exp_avg_projected = jtu.tree_map(
-            lambda e, q: project(e, q),
+            lambda e, q: project(e, q, precision),
            exp_avg,
            state.Q,
        )

        # Project back
        norm_updates = jtu.tree_map(
-            lambda e_avg, e_avg_sq, q: project_back(e_avg / (jnp.sqrt(e_avg_sq) + eps), q),
+            lambda e_avg, e_avg_sq, q: project_back(e_avg / (jnp.sqrt(e_avg_sq) + eps), q, precision),
            exp_avg_projected,
            exp_avg_sq,
            state.Q,
@@ -154,7 +172,7 @@ def update_step(

        # Update the preconditioner
        new_GG = jtu.tree_map(
-            lambda grad, gg: update_preconditioner(grad, gg, shampoo_beta),
+            lambda grad, gg: update_preconditioner(grad, gg, shampoo_beta, precision),
            updates,
            state.GG,
        )
@@ -163,7 +181,7 @@ def update_step(
        new_Q_and_exp_avg_sq = jax.lax.cond(
            state.count % precondition_frequency == 0,
            lambda: jtu.tree_map(
-                lambda e, gg, q: get_orthogonal_matrix_QR(gg, q, e),
+                lambda e, gg, q: get_orthogonal_matrix_QR(gg, q, e, precision),
                exp_avg_sq,
                new_GG,
                state.Q,
@@ -196,17 +214,16 @@ def update_step(

        return norm_updates, new_state

-    def update_fn(updates: Updates, state: SOAPState, params: Updates | None = None) -> tuple[Updates, SOAPState]:
+    def update_fn(updates: Updates, state: SOAPState, params: Optional[Updates] = None) -> tuple[Updates, SOAPState]:
        del params
        count_inc = jnp.asarray(optax.safe_int32_increment(state.count))
        state = state._replace(count=count_inc)

-        with jax.default_matmul_precision(precision):
-            updates, new_state = jax.lax.cond(
-                count_inc == 1,
-                lambda: init_step(updates, state),
-                lambda: update_step(updates, state),
-            )
+        updates, new_state = jax.lax.cond(
+            count_inc == 1,
+            lambda: init_step(updates, state),
+            lambda: update_step(updates, state),
+        )

        return updates, new_state

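The hunk above is the behavioural core of the change: instead of wrapping the `lax.cond` in the `jax.default_matmul_precision` context manager, the requested precision is now attached to every matmul/tensordot call site. A hedged side-by-side sketch of the two styles, with illustrative values:

```python
import jax
import jax.numpy as jnp

a = jnp.ones((64, 64))
b = jnp.ones((64, 64))

# Old style (removed above): an ambient context manager that sets the default
# precision for every matmul traced inside the block.
with jax.default_matmul_precision("highest"):
    c_ctx = a @ b

# New style (this diff): the precision is passed at each call site explicitly,
# so it no longer depends on the surrounding context.
c_arg = jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)
```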
@@ -217,9 +234,10 @@ def update_preconditioner(
    grad: Array,
    GG: List[Union[Array, None]],
    beta: float,
+    precision: jax.lax.PrecisionLike = jax.lax.Precision.HIGHEST,
) -> List[Union[Array, None]]:
    if grad.ndim == 1:
-        return [lerp(GG[0], jnp.outer(grad, grad), 1 - beta)]  # type: ignore
+        return [lerp(GG[0], jnp.matmul(grad[:, None], grad[None, :], precision=precision), 1 - beta)]  # type: ignore

    new_GG = []
    for idx, gg in enumerate(GG):
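The 1-D branch above replaces `jnp.outer`, which takes no `precision` argument, with an equivalent rank-1 matmul that does. A small sketch checking the equivalence (the vector is illustrative):

```python
import jax
import jax.numpy as jnp

g = jnp.arange(4.0)

# jnp.outer has no `precision` parameter, so the rank-1 outer product is
# rewritten as an (n, 1) @ (1, n) matmul, which does accept one.
outer = jnp.outer(g, g)
outer_via_matmul = jnp.matmul(g[:, None], g[None, :], precision=jax.lax.Precision.HIGHEST)

assert jnp.allclose(outer, outer_via_matmul)  # same (4, 4) factor
```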
@@ -231,19 +249,25 @@
            grad,
            grad,
            axes=[[*chain(range(idx), range(idx + 1, len(grad.shape)))]] * 2,
+            precision=precision,
        )
        new_GG.append(lerp(gg, outer_product, 1 - beta))

    return new_GG


-def project(grad: Array, Q: List[Union[Array, None]]) -> Array:
+def project(
+    grad: Array,
+    Q: List[Union[Array, None]],
+    precision: jax.lax.PrecisionLike = jax.lax.Precision.HIGHEST,
+) -> Array:
    for mat in Q:
        if mat is not None:  # noqa: SIM108
            grad = jnp.tensordot(
                grad,
                mat,
                axes=((0,), (0,)),
+                precision=precision,
            )
        else:
            permute_order = list(range(1, len(grad.shape))) + [0]
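For gradients with more than one axis, the `tensordot` in the preconditioner hunk contracts the gradient with itself over every axis except `idx`. A hedged sketch showing this matches the usual "reshape to a matrix and take G @ G.T" view of the Kronecker factor (shapes illustrative):

```python
from itertools import chain

import jax.numpy as jnp

g = jnp.arange(24.0).reshape(2, 3, 4)
idx = 1  # which Kronecker factor to form

# Contract g with itself over every axis except `idx`, as in the hunk above.
axes = [[*chain(range(idx), range(idx + 1, g.ndim))]] * 2
factor = jnp.tensordot(g, g, axes=axes)

# Equivalent view: move axis `idx` to the front, flatten the rest, take G @ G.T.
g_mat = jnp.moveaxis(g, idx, 0).reshape(g.shape[idx], -1)
assert jnp.allclose(factor, g_mat @ g_mat.T)  # (3, 3) factor for axis 1
```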
@@ -252,13 +276,18 @@ def project(grad: Array, Q: List[Union[Array, None]]) -> Array:
    return grad


-def project_back(grad: Array, Q: List[Union[Array, None]]) -> Array:
+def project_back(
+    grad: Array,
+    Q: List[Union[Array, None]],
+    precision: jax.lax.PrecisionLike = jax.lax.Precision.HIGHEST,
+) -> Array:
    for mat in Q:
        if mat is not None:  # noqa: SIM108
            grad = jnp.tensordot(
                grad,
                mat,
                axes=((0,), (1,)),
+                precision=precision,
            )
        else:
            grad = jnp.moveaxis(grad, 0, -1)
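`project` rotates a gradient into the preconditioner eigenbasis one axis at a time, and `project_back` applies the transposed rotation, so with orthonormal `Q` entries the two are inverses. A minimal round-trip sketch, assuming `project` and `project_back` are importable from this module and that no `Q` entry is `None`:

```python
import jax
import jax.numpy as jnp

g = jax.random.normal(jax.random.PRNGKey(0), (4, 6))

# Orthonormal bases for each axis (stand-ins for the entries of state.Q).
q0, _ = jnp.linalg.qr(jax.random.normal(jax.random.PRNGKey(1), (4, 4)))
q1, _ = jnp.linalg.qr(jax.random.normal(jax.random.PRNGKey(2), (6, 6)))
Q = [q0, q1]

projected = project(g, Q, precision=jax.lax.Precision.HIGHEST)               # into eigenbasis
restored = project_back(projected, Q, precision=jax.lax.Precision.HIGHEST)   # back again

assert jnp.allclose(restored, g, atol=1e-5)  # round-trip with orthonormal Q
```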
@@ -278,18 +307,25 @@ def get_orthogonal_matrix_QR(
    GG: List[Union[Array, None]],
    Q: List[Union[Array, None]],
    exp_avg_sq: Array,
+    precision: jax.lax.PrecisionLike = jax.lax.Precision.HIGHEST,
) -> tuple[List[Union[Array, None]], Array]:
    final_Q = []
    for ind, (m, o) in enumerate(zip(GG, Q)):
        if m is None or o is None:
            final_Q.append(None)
            continue

-        est_eig = jnp.diag(o.T @ m @ o)
+        est_eig = jnp.diag(
+            jnp.matmul(
+                jnp.matmul(o.T, m, precision=precision),
+                o,
+                precision=precision,
+            )
+        )
        sort_idx = jnp.argsort(est_eig, descending=True)
        exp_avg_sq = jnp.take(exp_avg_sq, sort_idx, axis=ind)
        o = o[:, sort_idx]
-        power_iter = m @ o
+        power_iter = jnp.matmul(m, o, precision=precision)
        Q_new, _ = jnp.linalg.qr(power_iter)

        final_Q.append(Q_new)
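The refreshed factor above comes from one power-iteration step (`m @ o`) followed by a QR re-orthogonalization, after sorting the current basis columns by the estimated eigenvalues `diag(o.T @ m @ o)`. A standalone, hedged sketch of that per-factor refresh; it omits the reordering of `exp_avg_sq` that the real function also performs, and the helper name is illustrative:

```python
import jax
import jax.numpy as jnp

def refresh_basis(m, o, precision=jax.lax.Precision.HIGHEST):
    # Estimated eigenvalues of m in the current basis o.
    est_eig = jnp.diag(jnp.matmul(jnp.matmul(o.T, m, precision=precision), o, precision=precision))
    # Keep columns ordered by decreasing estimated eigenvalue.
    o = o[:, jnp.argsort(est_eig, descending=True)]
    # One power-iteration step, then QR to re-orthogonalize.
    q_new, _ = jnp.linalg.qr(jnp.matmul(m, o, precision=precision))
    return q_new

a = jax.random.normal(jax.random.PRNGKey(0), (8, 8))
m = a @ a.T        # symmetric PSD accumulator, like an entry of GG
o = jnp.eye(8)     # current orthonormal basis
o = refresh_basis(m, o)
assert jnp.allclose(o.T @ o, jnp.eye(8), atol=1e-5)  # columns remain orthonormal
```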