from itertools import chain
from typing import List, NamedTuple, Union

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import optax
import optax.tree_utils as otu
from chex import Numeric
from jaxtyping import Array
from optax import GradientTransformation, Updates


class SOAPState(NamedTuple):
    count: jnp.ndarray  # type: ignore
    exp_avg: Updates  # first moment, kept in the original parameter basis
    exp_avg_sq: Updates  # second moment, kept in the rotated (eigen)basis
    GG: Updates  # per-dimension Kronecker factors (`L`/`R` in the paper)
    Q: Updates  # per-dimension eigenvector bases of `GG`


def soap(
    learning_rate: optax.ScalarOrSchedule = 3e-3,
    b1: float = 0.95,
    b2: float = 0.95,
    shampoo_beta: float = -1,
    eps: float = 1e-8,
    weight_decay: float = 0.0,
    precondition_frequency: int = 10,
    max_precond_dim: int = 10000,
    precision: jax.lax.PrecisionLike = jax.lax.Precision.HIGHEST,
) -> optax.GradientTransformationExtraArgs:
    """
    Implements the SOAP algorithm (https://arxiv.org/abs/2409.11321), based on the original
    implementation at https://github.com/nikhilvyas/SOAP.

    Args:
        learning_rate (optax.ScalarOrSchedule): The learning rate to use.
        b1 (float, optional): Adam's beta1 parameter. Defaults to 0.95.
        b2 (float, optional): Adam's beta2 parameter. Defaults to 0.95.
        shampoo_beta (float, optional): If >= 0, use this beta for the preconditioner (`L` and `R` in the paper,
            `GG` below) moving average instead of b2. Defaults to -1.
        eps (float, optional): Adam's epsilon for numerical stability. Defaults to 1e-8.
        weight_decay (float, optional): Weight decay coefficient. Defaults to 0.0.
        precondition_frequency (int, optional): How often to update the preconditioner. Defaults to 10.
        max_precond_dim (int, optional): Maximum dimension of the preconditioner.
            Set to 10000 to exclude most common vocab sizes while including layers. Defaults to 10000.
        precision (jax.lax.PrecisionLike, optional): Matmul precision used for the preconditioner
            computations. Defaults to jax.lax.Precision.HIGHEST.

    Returns:
        optax.GradientTransformationExtraArgs: The SOAP optimizer.
    """
    return optax.chain(
        scale_by_soap(
            b1=b1,
            b2=b2,
            shampoo_beta=shampoo_beta,
            eps=eps,
            precondition_frequency=precondition_frequency,
            max_precond_dim=max_precond_dim,
            precision=precision,
        ),
        optax.add_decayed_weights(weight_decay),
        optax.scale_by_learning_rate(learning_rate),
    )
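
# Minimal usage sketch (illustrative only; `params` and `loss_fn` are assumed to
# be defined by the caller and are not part of this module):
#
#   optimizer = soap(learning_rate=3e-3, weight_decay=0.0)
#   opt_state = optimizer.init(params)
#   grads = jax.grad(loss_fn)(params)
#   updates, opt_state = optimizer.update(grads, opt_state, params)
#   params = optax.apply_updates(params, updates)
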

def scale_by_soap(
    b1: float = 0.95,
    b2: float = 0.95,
    shampoo_beta: float = -1,
    eps: float = 1e-8,
    precondition_frequency: int = 10,
    max_precond_dim: int = 10000,
    precision: jax.lax.PrecisionLike = jax.lax.Precision.HIGHEST,
) -> GradientTransformation:
    shampoo_beta = shampoo_beta if shampoo_beta >= 0 else b2

    def init_fn(params: Updates) -> SOAPState:
        exp_avg = otu.tree_zeros_like(params)
        exp_avg_sq = otu.tree_zeros_like(params)
        # GG starts at zero; Q holds placeholder zeros until the first update step.
        GG = jtu.tree_map(
            lambda p: init_conditioner(p, max_precond_dim),
            params,
        )
        Q = jtu.tree_map(
            lambda p: init_conditioner(p, max_precond_dim),
            params,
        )
        return SOAPState(
            count=jnp.zeros([], jnp.int32),
            exp_avg=exp_avg,
            exp_avg_sq=exp_avg_sq,
            GG=GG,
            Q=Q,
        )

    def init_step(
        updates: Updates,
        state: SOAPState,
    ) -> tuple[Updates, SOAPState]:
        # First step: accumulate GG from the incoming gradients and compute the
        # initial eigenbases Q via an eigendecomposition.
        new_GG = jtu.tree_map(
            lambda grad, gg: update_preconditioner(grad, gg, shampoo_beta),
            updates,
            state.GG,
        )

        new_Q = jtu.tree_map(
            lambda gg: get_orthogonal_matrix(gg),
            new_GG,
        )

        # Skip the parameter update on the first step by emitting zero updates.
        new_updates = otu.tree_zeros_like(updates)

        return new_updates, state._replace(GG=new_GG, Q=new_Q)

    def update_step(
        updates: Updates,
        state: SOAPState,
    ) -> tuple[Updates, SOAPState]:
        # Rotate the gradients into the eigenbasis Q
        grad_projected = jtu.tree_map(
            lambda grad, q: project(grad, q),
            updates,
            state.Q,
        )

        # Adam-style moment updates: first moment in the original basis,
        # second moment in the rotated basis
        exp_avg = otu.tree_update_moment(updates, state.exp_avg, b1, 1)
        exp_avg_sq = otu.tree_update_moment_per_elem_norm(grad_projected, state.exp_avg_sq, b2, 2)

        exp_avg_projected = jtu.tree_map(
            lambda e, q: project(e, q),
            exp_avg,
            state.Q,
        )

        # Normalize in the rotated basis, then rotate the update back
        norm_updates = jtu.tree_map(
            lambda e_avg, e_avg_sq, q: project_back(e_avg / (jnp.sqrt(e_avg_sq) + eps), q),
            exp_avg_projected,
            exp_avg_sq,
            state.Q,
        )

        # Adam bias correction (state.count already holds the incremented step)
        bc1 = 1 - b1**state.count
        bc2 = 1 - b2**state.count
        corr = jnp.sqrt(bc2) / bc1

        norm_updates = jtu.tree_map(
            lambda p: p * corr,
            norm_updates,
        )

        # Accumulate the preconditioner factors with the raw gradients
        new_GG = jtu.tree_map(
            lambda grad, gg: update_preconditioner(grad, gg, shampoo_beta),
            updates,
            state.GG,
        )

        # Every `precondition_frequency` steps, refresh the eigenbases Q (which
        # also permutes exp_avg_sq); otherwise keep Q and the freshly updated
        # exp_avg_sq unchanged.
        new_Q_and_exp_avg_sq = jax.lax.cond(
            state.count % precondition_frequency == 0,
            lambda: jtu.tree_map(
                lambda e, gg, q: get_orthogonal_matrix_QR(gg, q, e),
                exp_avg_sq,
                new_GG,
                state.Q,
            ),
            lambda: jtu.tree_map(
                lambda e, q: (q, e),
                exp_avg_sq,
                state.Q,
            ),
        )
        # Unpack the (Q, exp_avg_sq) pairs
        new_Q = jtu.tree_map(
            lambda _, x: x[0],
            updates,
            new_Q_and_exp_avg_sq,
        )
        exp_avg_sq = jtu.tree_map(
            lambda _, x: x[1],
            updates,
            new_Q_and_exp_avg_sq,
        )

        new_state = SOAPState(
            count=state.count,
            exp_avg=exp_avg,
            exp_avg_sq=exp_avg_sq,
            GG=new_GG,
            Q=new_Q,
        )

        return norm_updates, new_state

    def update_fn(updates: Updates, state: SOAPState, params: Updates | None = None) -> tuple[Updates, SOAPState]:
        del params
        count_inc = jnp.asarray(optax.safe_int32_increment(state.count))
        state = state._replace(count=count_inc)

        with jax.default_matmul_precision(precision):
            # The very first step only initializes the preconditioner; all later
            # steps perform the actual SOAP update.
            updates, new_state = jax.lax.cond(
                count_inc == 1,
                lambda: init_step(updates, state),
                lambda: update_step(updates, state),
            )

        return updates, new_state

    return optax.GradientTransformation(init_fn, update_fn)  # type: ignore
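
# `scale_by_soap` can also be combined directly with other optax transforms, e.g.
# (an illustrative sketch; gradient clipping is not part of the reference recipe):
#
#   tx = optax.chain(
#       optax.clip_by_global_norm(1.0),
#       scale_by_soap(precondition_frequency=10),
#       optax.scale_by_learning_rate(3e-3),
#   )
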

def update_preconditioner(
    grad: Array,
    GG: List[Union[Array, None]],
    beta: float,
) -> List[Union[Array, None]]:
    # 1-D parameters use a single full outer-product preconditioner.
    if grad.ndim == 1:
        return [lerp(GG[0], jnp.outer(grad, grad), 1 - beta)]  # type: ignore

    new_GG = []
    for idx, gg in enumerate(GG):
        if gg is None:
            new_GG.append(None)
            continue

        # Contract the gradient with itself over every axis except `idx`.
        outer_product = jnp.tensordot(
            grad,
            grad,
            axes=[[*chain(range(idx), range(idx + 1, len(grad.shape)))]] * 2,
        )
        new_GG.append(lerp(gg, outer_product, 1 - beta))

    return new_GG
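
# For a 2-D gradient G the loop above reduces to the two Kronecker factors from
# the paper: idx=0 contracts over axis 1 and accumulates G @ G.T (the `L` factor),
# while idx=1 contracts over axis 0 and accumulates G.T @ G (the `R` factor).
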

def project(grad: Array, Q: List[Union[Array, None]]) -> Array:
    """Rotate `grad` into the eigenbasis, one dimension at a time."""
    for mat in Q:
        if mat is not None:  # noqa: SIM108
            # Contract the leading axis against Q; tensordot appends the new axis
            # at the end, so the remaining axes cycle forward by one.
            grad = jnp.tensordot(
                grad,
                mat,
                axes=((0,), (0,)),
            )
        else:
            # No preconditioner for this dimension: just cycle the axis to the back.
            permute_order = list(range(1, len(grad.shape))) + [0]
            grad = jnp.transpose(grad, permute_order)

    return grad


def project_back(grad: Array, Q: List[Union[Array, None]]) -> Array:
    """Rotate `grad` from the eigenbasis back to the original basis."""
    for mat in Q:
        if mat is not None:  # noqa: SIM108
            grad = jnp.tensordot(
                grad,
                mat,
                axes=((0,), (1,)),
            )
        else:
            grad = jnp.moveaxis(grad, 0, -1)

    return grad
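
# For a fully preconditioned 2-D gradient G, project(G, Q) == Q[0].T @ G @ Q[1]
# and project_back(G, Q) == Q[0] @ G @ Q[1].T, so with orthogonal Q the two maps
# invert each other.
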

def get_orthogonal_matrix(gg: Union[Array, None]) -> Union[Array, None]:
    if gg is None:
        return None

    # Eigenvectors of GG (tiny jitter for numerical stability), sorted by
    # descending eigenvalue.
    _, eigh = jnp.linalg.eigh(gg + 1e-30 * jnp.eye(gg.shape[0]))
    return jnp.flip(eigh, axis=1)


def get_orthogonal_matrix_QR(
    GG: List[Union[Array, None]],
    Q: List[Union[Array, None]],
    exp_avg_sq: Array,
) -> tuple[List[Union[Array, None]], Array]:
    final_Q = []
    for ind, (m, o) in enumerate(zip(GG, Q)):
        if m is None or o is None:
            final_Q.append(None)
            continue

        # Estimate the eigenvalues in the current basis and sort them descending.
        est_eig = jnp.diag(o.T @ m @ o)
        sort_idx = jnp.argsort(est_eig, descending=True)
        # Keep exp_avg_sq aligned with the re-ordered eigenvectors.
        exp_avg_sq = jnp.take(exp_avg_sq, sort_idx, axis=ind)
        o = o[:, sort_idx]
        # One power-iteration step followed by QR re-orthogonalization.
        power_iter = m @ o
        Q_new, _ = jnp.linalg.qr(power_iter)

        final_Q.append(Q_new)

    return final_Q, exp_avg_sq


def lerp(
    start: Array,
    end: Array,
    weight: Numeric,
) -> Array:
    return start + weight * (end - start)


def init_conditioner(p: Array, max_precond_dim: int) -> List[Union[Array, None]]:
    if p.ndim == 1:
        return [jnp.zeros((p.shape[0], p.shape[0]))]

    # One factor per dimension; dimensions larger than `max_precond_dim` get None
    # and are left unpreconditioned.
    return [jnp.zeros((s, s)) if s <= max_precond_dim else None for s in p.shape]
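
# Illustrative shapes: for a weight of shape (1024, 768) this returns
# [zeros((1024, 1024)), zeros((768, 768))]; for an embedding of shape
# (50304, 768) with max_precond_dim=10000 it returns [None, zeros((768, 768))],
# so the vocabulary dimension is left unpreconditioned.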