Commit cd8cbde

Initial commit
1 parent 1d1cce9 commit cd8cbde

5 files changed: +375 −1 lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -160,3 +160,6 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

uv.lock
.python-version

README.md

Lines changed: 29 additions & 1 deletion
@@ -1 +1,29 @@
# SOAP_JAX

This is an *unofficial* JAX implementation of the SOAP optimizer from [SOAP: Improving and Stabilizing Shampoo using Adam](https://arxiv.org/abs/2409.11321), based on the official PyTorch implementation at https://github.com/nikhilvyas/SOAP.

You can install it with

```
pip install git+https://github.com/haydn-jones/SOAP_JAX
```

and use it as follows:

```python
from soap_jax import soap

opt = soap(
    learning_rate=3e-3,
    b1=0.95,
    b2=0.95,
    weight_decay=0.01,
    precondition_frequency=5,
)
```
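
If it helps, here is a minimal end-to-end training-step sketch; the parameter shapes and the quadratic loss are made up purely for illustration:

```python
import jax
import jax.numpy as jnp
import optax

from soap_jax import soap

# Toy parameters and a toy quadratic loss, just to show the optax-style API.
params = {"w": jnp.ones((4, 2)), "b": jnp.zeros((2,))}

def loss_fn(params, x):
    return jnp.mean((x @ params["w"] + params["b"]) ** 2)

opt = soap(learning_rate=3e-3, precondition_frequency=5)
opt_state = opt.init(params)

@jax.jit
def train_step(params, opt_state, x):
    loss, grads = jax.value_and_grad(loss_fn)(params, x)
    updates, opt_state = opt.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

params, opt_state, loss = train_step(params, opt_state, jnp.ones((8, 4)))
```

Note that `opt.update` is passed `params` so that the weight decay term can be applied.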

I've written it in the same style as optax's built-in optimizers, so you can also import `scale_by_soap` if you only want the gradient transformation, as sketched below.
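
For example, here is a rough sketch of composing `scale_by_soap` with other optax transforms yourself; the gradient-clipping stage and the specific values are illustrative, not recommendations:

```python
import optax

from soap_jax import scale_by_soap

# soap() chains scale_by_soap with weight decay and the learning rate;
# building the chain yourself lets you add stages such as gradient clipping.
opt = optax.chain(
    optax.clip_by_global_norm(1.0),
    scale_by_soap(b1=0.95, b2=0.95, precondition_frequency=10),
    optax.add_decayed_weights(1e-2),
    optax.scale_by_learning_rate(3e-3),
)
```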

## JAX-Specific Information

I did not implement dimension merging or the option to toggle preconditioning of parameters with fewer than two dimensions. I'll gladly take PRs implementing these features; they just weren't necessary for me. This is also the first optimizer I've implemented in JAX, so I'd be happy to take PRs improving the implementation as well.

The runs I've done with this implementation have produced good results, so I expect what I've done here is correct, but as always with unofficial implementations, review the code before relying on it for anything important.

pyproject.toml

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
[project]
name = "soap-jax"
version = "0.1.0"
description = "SOAP Optimizer implemented in JAX"
readme = "README.md"
requires-python = ">=3.9"
dependencies = [
    "jax",
    "jaxtyping",
    "optax",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.ruff]
line-length = 120
indent-width = 4
target-version = "py39"

[tool.ruff.lint]
select = ["E", "F", "B", "SIM", "I", "FURB"]
ignore = ["B905"]
fixable = ["ALL"]

[tool.ruff.format]
quote-style = "double"
indent-style = "space"

src/soap_jax/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
from soap_jax.soap import scale_by_soap as scale_by_soap
from soap_jax.soap import soap as soap

src/soap_jax/soap.py

Lines changed: 312 additions & 0 deletions
@@ -0,0 +1,312 @@
from __future__ import annotations

from itertools import chain
from typing import List, NamedTuple, Union

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import optax
import optax.tree_utils as otu
from chex import Numeric
from jaxtyping import Array
from optax import GradientTransformation, Updates


class SOAPState(NamedTuple):
    count: jnp.ndarray  # type: ignore
    exp_avg: Updates
    exp_avg_sq: Updates
    GG: Updates
    Q: Updates


def soap(
    learning_rate: optax.ScalarOrSchedule = 3e-3,
    b1: float = 0.95,
    b2: float = 0.95,
    shampoo_beta: float = -1,
    eps: float = 1e-8,
    weight_decay: float = 0.0,
    precondition_frequency: int = 10,
    max_precond_dim: int = 10000,
    precision: jax.lax.PrecisionLike = jax.lax.Precision.HIGHEST,
) -> optax.GradientTransformationExtraArgs:
    """
    Implements the SOAP algorithm (https://arxiv.org/abs/2409.11321). Based on the original
    implementation at https://github.com/nikhilvyas/SOAP.

    Args:
        learning_rate (optax.ScalarOrSchedule): The learning rate to use.
        b1 (float, optional): Adam's beta1 parameter. Defaults to 0.95.
        b2 (float, optional): Adam's beta2 parameter. Defaults to 0.95.
        shampoo_beta (float, optional): If >= 0, use this beta for the preconditioner (`L` and `R` in the paper,
            `GG` below) moving average instead of b2. Defaults to -1.
        eps (float, optional): Adam's epsilon for numerical stability. Defaults to 1e-8.
        weight_decay (float, optional): Weight decay coefficient. Defaults to 0.0.
        precondition_frequency (int, optional): How often to update the preconditioner. Defaults to 10.
        max_precond_dim (int, optional): Maximum dimension of the preconditioner.
            Set to 10000 to exclude most common vocab sizes while including layers. Defaults to 10000.
        precision (jax.lax.PrecisionLike, optional): Matmul precision used inside the optimizer.
            Defaults to jax.lax.Precision.HIGHEST.

    Returns:
        optax.GradientTransformationExtraArgs: The SOAP optimizer.
    """
    return optax.chain(
        scale_by_soap(
            b1=b1,
            b2=b2,
            shampoo_beta=shampoo_beta,
            eps=eps,
            precondition_frequency=precondition_frequency,
            max_precond_dim=max_precond_dim,
            precision=precision,
        ),
        optax.add_decayed_weights(weight_decay),
        optax.scale_by_learning_rate(learning_rate),
    )


def scale_by_soap(
    b1: float = 0.95,
    b2: float = 0.95,
    shampoo_beta: float = -1,
    eps: float = 1e-8,
    precondition_frequency: int = 10,
    max_precond_dim: int = 10000,
    precision: jax.lax.PrecisionLike = jax.lax.Precision.HIGHEST,
) -> GradientTransformation:
    shampoo_beta = shampoo_beta if shampoo_beta >= 0 else b2

    def init_fn(params: Updates) -> SOAPState:
        exp_avg = otu.tree_zeros_like(params)
        exp_avg_sq = otu.tree_zeros_like(params)
        GG = jtu.tree_map(
            lambda p: init_conditioner(p, max_precond_dim),
            params,
        )
        Q = jtu.tree_map(
            lambda p: init_conditioner(p, max_precond_dim),
            params,
        )
        return SOAPState(
            count=jnp.zeros([], jnp.int32),
            exp_avg=exp_avg,
            exp_avg_sq=exp_avg_sq,
            GG=GG,
            Q=Q,
        )

    def init_step(
        updates: Updates,
        state: SOAPState,
    ) -> tuple[Updates, SOAPState]:
        # First step: accumulate GG and compute the initial eigenbasis
        new_GG = jtu.tree_map(
            lambda grad, gg: update_preconditioner(grad, gg, shampoo_beta),
            updates,
            state.GG,
        )

        new_Q = jtu.tree_map(
            lambda gg: get_orthogonal_matrix(gg),
            new_GG,
        )

        # Replace updates with zeros
        new_updates = otu.tree_zeros_like(updates)

        return new_updates, state._replace(GG=new_GG, Q=new_Q)

    def update_step(
        updates: Updates,
        state: SOAPState,
    ) -> tuple[Updates, SOAPState]:
        # Project gradients into the preconditioner eigenbasis
        grad_projected = jtu.tree_map(
            lambda grad, q: project(grad, q),
            updates,
            state.Q,
        )

        # Update moments
        exp_avg = otu.tree_update_moment(updates, state.exp_avg, b1, 1)
        exp_avg_sq = otu.tree_update_moment_per_elem_norm(grad_projected, state.exp_avg_sq, b2, 2)

        exp_avg_projected = jtu.tree_map(
            lambda e, q: project(e, q),
            exp_avg,
            state.Q,
        )

        # Adam-style update in the eigenbasis, then project back
        norm_updates = jtu.tree_map(
            lambda e_avg, e_avg_sq, q: project_back(e_avg / (jnp.sqrt(e_avg_sq) + eps), q),
            exp_avg_projected,
            exp_avg_sq,
            state.Q,
        )

        bc1 = 1 - b1**state.count
        bc2 = 1 - b2**state.count
        corr = jnp.sqrt(bc2) / bc1

        # Bias correction on the updates
        norm_updates = jtu.tree_map(
            lambda p: p * corr,
            norm_updates,
        )

        # Update the preconditioner
        new_GG = jtu.tree_map(
            lambda grad, gg: update_preconditioner(grad, gg, shampoo_beta),
            updates,
            state.GG,
        )

        # Refresh the orthogonal matrix (and permute exp_avg_sq to match) every `precondition_frequency` steps
        new_Q_and_exp_avg_sq = jax.lax.cond(
            state.count % precondition_frequency == 0,
            lambda: jtu.tree_map(
                lambda e, gg, q: get_orthogonal_matrix_QR(gg, q, e),
                exp_avg_sq,
                new_GG,
                state.Q,
            ),
            lambda: jtu.tree_map(
                lambda e, q: (q, e),
                exp_avg_sq,
                state.Q,
            ),
        )
        # Unpack the results
        new_Q = jtu.tree_map(
            lambda _, x: x[0],
            updates,
            new_Q_and_exp_avg_sq,
        )
        exp_avg_sq = jtu.tree_map(
            lambda _, x: x[1],
            updates,
            new_Q_and_exp_avg_sq,
        )

        new_state = SOAPState(
            count=state.count,
            exp_avg=exp_avg,
            exp_avg_sq=exp_avg_sq,
            GG=new_GG,
            Q=new_Q,
        )

        return norm_updates, new_state

    def update_fn(updates: Updates, state: SOAPState, params: Updates | None = None) -> tuple[Updates, SOAPState]:
        del params
        count_inc = jnp.asarray(optax.safe_int32_increment(state.count))
        state = state._replace(count=count_inc)

        with jax.default_matmul_precision(precision):
            updates, new_state = jax.lax.cond(
                count_inc == 1,
                lambda: init_step(updates, state),
                lambda: update_step(updates, state),
            )

        return updates, new_state

    return optax.GradientTransformation(init_fn, update_fn)  # type: ignore


def update_preconditioner(
    grad: Array,
    GG: List[Union[Array, None]],
    beta: float,
) -> List[Union[Array, None]]:
    # Exponential moving average of the gradient's outer product along each dimension.
    if grad.ndim == 1:
        return [lerp(GG[0], jnp.outer(grad, grad), 1 - beta)]  # type: ignore

    new_GG = []
    for idx, gg in enumerate(GG):
        if gg is None:
            new_GG.append(None)
            continue

        outer_product = jnp.tensordot(
            grad,
            grad,
            axes=[[*chain(range(idx), range(idx + 1, len(grad.shape)))]] * 2,
        )
        new_GG.append(lerp(gg, outer_product, 1 - beta))

    return new_GG


def project(grad: Array, Q: List[Union[Array, None]]) -> Array:
    # Rotate into the preconditioner eigenbasis, one dimension at a time.
    for mat in Q:
        if mat is not None:  # noqa: SIM108
            grad = jnp.tensordot(
                grad,
                mat,
                axes=((0,), (0,)),
            )
        else:
            permute_order = list(range(1, len(grad.shape))) + [0]
            grad = jnp.transpose(grad, permute_order)

    return grad


def project_back(grad: Array, Q: List[Union[Array, None]]) -> Array:
    # Rotate back out of the preconditioner eigenbasis.
    for mat in Q:
        if mat is not None:  # noqa: SIM108
            grad = jnp.tensordot(
                grad,
                mat,
                axes=((0,), (1,)),
            )
        else:
            grad = jnp.moveaxis(grad, 0, -1)

    return grad


def get_orthogonal_matrix(gg: Array) -> Union[Array, None]:
    if gg is None:
        return None

    # Eigenvectors of GG, ordered by descending eigenvalue.
    _, eigh = jnp.linalg.eigh(gg + 1e-30 * jnp.eye(gg.shape[0]))
    return jnp.flip(eigh, axis=1)


def get_orthogonal_matrix_QR(
    GG: List[Union[Array, None]],
    Q: List[Union[Array, None]],
    exp_avg_sq: Array,
) -> tuple[List[Union[Array, None]], Array]:
    # One power-iteration step followed by QR to refresh the eigenbasis,
    # permuting exp_avg_sq to track the re-sorted columns.
    final_Q = []
    for ind, (m, o) in enumerate(zip(GG, Q)):
        if m is None or o is None:
            final_Q.append(None)
            continue

        est_eig = jnp.diag(o.T @ m @ o)
        sort_idx = jnp.argsort(est_eig, descending=True)
        exp_avg_sq = jnp.take(exp_avg_sq, sort_idx, axis=ind)
        o = o[:, sort_idx]
        power_iter = m @ o
        Q_new, _ = jnp.linalg.qr(power_iter)

        final_Q.append(Q_new)

    return final_Q, exp_avg_sq


def lerp(
    start: Array,
    end: Array,
    weight: Numeric,
) -> Array:
    return start + weight * (end - start)


def init_conditioner(p: Array, max_precond_dim: int) -> List[Union[Array, None]]:
    # One (d, d) accumulator per dimension; dimensions larger than max_precond_dim are not preconditioned.
    if p.ndim == 1:
        return [jnp.zeros((p.shape[0], p.shape[0]))]

    return [jnp.zeros((s, s)) if s <= max_precond_dim else None for s in p.shape]
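
For completeness, a quick smoke test of the raw transformation (hypothetical, not part of this commit); it checks that initialization, the zero-update first step, and a preconditioner refresh all run and preserve parameter shapes:

```python
import jax.numpy as jnp

from soap_jax import scale_by_soap

# Toy parameters: one matrix and one bias vector.
params = {"w": jnp.ones((6, 3)), "b": jnp.zeros((3,))}
grads = {"w": jnp.full((6, 3), 0.1), "b": jnp.full((3,), 0.1)}

tx = scale_by_soap(precondition_frequency=2)
state = tx.init(params)

# Step 1 emits zeros (eigenbasis initialization); step 2 hits the QR refresh path.
for _ in range(3):
    updates, state = tx.update(grads, state, params)

assert updates["w"].shape == (6, 3)
assert updates["b"].shape == (3,)
```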
