Commit e264671

move freeze, mutable, and pure to graph.py
1 parent 866fe95 commit e264671

28 files changed: +1478 -596 lines
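
This commit moves the freeze, mutable, and pure helpers that the toy examples previously defined inline into graph.py, exposing them as nnx.freeze, nnx.mutable, and nnx.pure. A minimal sketch of the three, assuming they keep the semantics of the example-local helpers they replace (shown in the diffs below):

    import jax.numpy as jnp
    from flax import nnx

    state = {'w': nnx.mutable_array(jnp.ones((2,))), 'b': nnx.Param(jnp.zeros(()))}
    frozen = nnx.freeze(state)    # MutableArray leaves -> immutable jax.Arrays
    thawed = nnx.mutable(frozen)  # and back: immutable leaves -> fresh MutableArrays
    raw = nnx.pure(state)         # strip Variable wrappers down to raw values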

examples/nnx_toy_examples/01_functional_api.py

Lines changed: 4 additions & 7 deletions
@@ -20,8 +20,8 @@
 
 from flax import nnx
 
-X = np.linspace(0, 1, 100)[:, None]
-Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape)
+X = np.linspace(-jnp.pi, jnp.pi, 100)[:, None]
+Y = 0.8 * jnp.sin(X) + 0.1 + np.random.normal(0, 0.1, size=X.shape)
 
 
 def dataset(batch_size):
@@ -50,11 +50,8 @@ def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
     self.linear2 = Linear(dhidden, dout, rngs=rngs)
 
   def __call__(self, x):
-    self.count.value += 1
-    x = self.linear1(x)
-    x = jax.nn.relu(x)
-    x = self.linear2(x)
-    return x
+    self.count[...] += 1
+    return self.linear2(jax.nn.relu(self.linear1(x) * 0.5))
 
 
 graphdef, params, counts = nnx.split(
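
The __call__ hunk above swaps the `.value` accessor for `[...]` indexing on the count state. In isolation (a quick sketch, assuming the mutable-array mode used by the examples below):

    import jax.numpy as jnp
    from flax import nnx

    count = nnx.mutable_array(jnp.array(0))
    count[...] += 1           # in-place update, no new array is created
    assert count[...] == 1    # [...] also reads the current value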

examples/nnx_toy_examples/mutable_array_basic.py

Lines changed: 5 additions & 40 deletions
@@ -23,42 +23,7 @@
 import numpy as np
 
 from flax import nnx
-from flax.nnx.variablelib import is_mutable_array
-
-
-def mutable_like(path, x):
-  return (
-    isinstance(x, nnx.Variable) and x.mutable
-  ) or nnx.variablelib.is_mutable_array(x)
-
-
-def freeze(x, only: nnx.filterlib.Filter = mutable_like):
-  freeze_filter = nnx.filterlib.to_predicate(only)
-  mutable_arrays: set[int] = set()
-
-  def check_mutable_array(path, x):
-    m_array_id = id(x)
-    if m_array_id in mutable_arrays:
-      path_str = jax.tree_util.keystr(path)
-      raise ValueError(
-        f'Found duplicate MutableArray found at path {path_str}: {x}'
-      )
-    mutable_arrays.add(m_array_id)
-
-  def _freeze_fn(jax_path, x):
-    path = tuple(nnx.graph._key_path_to_key(part) for part in jax_path)
-    if freeze_filter(path, x):
-      if isinstance(x, nnx.Variable):
-        check_mutable_array(jax_path, x.raw_value)
-        return x.from_metadata(x[...], x.get_metadata().copy())
-      elif nnx.variablelib.is_mutable_array(x):
-        check_mutable_array(jax_path, x)
-        return x[...]
-    return x
-
-  return jax.tree.map_with_path(
-    _freeze_fn, x, is_leaf=lambda x: isinstance(x, nnx.Variable)
-  )
+
 
 X = np.linspace(-jnp.pi, jnp.pi, 100)[:, None]
 Y = 0.8 * jnp.sin(X) + 0.1 + np.random.normal(0, 0.1, size=X.shape)
@@ -94,21 +59,21 @@ def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
 
   def __call__(self, x):
     self.count[...] += 1
-    return self.linear2(jax.nn.gelu(self.linear1(x)) * 0.5)
+    return self.linear2(jax.nn.relu(self.linear1(x)) * 0.5)
 
 
-model = MLP(din=1, dhidden=64, dout=1, rngs=nnx.Rngs(0))
+model = MLP(din=1, dhidden=32, dout=1, rngs=nnx.Rngs(0))
 
 
 @jax.jit
 def train_step(model, x, y):
-  graphdef, params, counts = nnx.split(model, nnx.Param, Count)
+  graphdef, params, counts = nnx.pure(nnx.split(model, nnx.Param, Count))
 
   def loss_fn(params):
     model = nnx.merge(graphdef, params, counts)
     return jnp.mean((y - model(x)) ** 2)
 
-  grads = jax.grad(loss_fn)(freeze(params))
+  grads = jax.grad(loss_fn)(nnx.freeze(params))
 
   def sgd(w, g):
     w[...] -= 0.1 * g[...]
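
Restating the resulting train_step pattern as a self-contained sketch: the toy module here is hypothetical, but the nnx.pure / nnx.freeze / split / merge usage mirrors the diff, and mutable-array mode is assumed to be enabled as in the example header.

    import os
    os.environ['FLAX_MUTABLE_ARRAY'] = 'true'

    import jax
    import jax.numpy as jnp
    from flax import nnx

    class Toy(nnx.Module):
      def __init__(self, *, rngs: nnx.Rngs):
        self.w = nnx.Param(jax.random.normal(rngs.params(), (1, 1)))

      def __call__(self, x):
        return x @ self.w[...]

    model = Toy(rngs=nnx.Rngs(params=0))

    @jax.jit
    def train_step(model, x, y):
      # nnx.pure strips the Variable wrappers, leaving raw (mutable) arrays
      graphdef, params = nnx.pure(nnx.split(model, nnx.Param))

      def loss_fn(params):
        return jnp.mean((y - nnx.merge(graphdef, params)(x)) ** 2)

      # nnx.freeze yields an immutable copy that jax.grad accepts
      grads = jax.grad(loss_fn)(nnx.freeze(params))

      def sgd(w, g):
        w[...] -= 0.1 * g[...]  # mutates the original params in place

      jax.tree.map(sgd, params, grads)

    train_step(model, jnp.ones((4, 1)), jnp.zeros((4, 1)))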

examples/nnx_toy_examples/mutable_array_demo.py

Lines changed: 31 additions & 150 deletions
@@ -17,132 +17,14 @@
 
 os.environ['FLAX_MUTABLE_ARRAY'] = 'true'
 
-from typing import Any, TypeVar
-from collections.abc import Mapping
 import jax
 import jax.numpy as jnp
 import matplotlib.pyplot as plt
 import numpy as np
-from jax._src.core import MutableArray
 
 from flax import nnx
 
 
-# # Utils
-A = TypeVar('A')
-
-def mutable_like(path, x):
-  return (isinstance(x, nnx.Variable) and x.mutable) or nnx.is_mutable_array(x)
-
-
-def freeze(x: A, only: nnx.filterlib.Filter = mutable_like) -> A:
-  freeze_filter = nnx.filterlib.to_predicate(only)
-  mutable_arrays: set[int] = set()
-
-  def check_mutable_array(path, x):
-    m_array_id = id(x)
-    if m_array_id in mutable_arrays:
-      path_str = jax.tree_util.keystr(path)
-      raise ValueError(
-        f'Found duplicate MutableArray found at path {path_str}: {x}'
-      )
-    mutable_arrays.add(m_array_id)
-
-  def _freeze_fn(jax_path, x):
-    path = tuple(nnx.graph._key_path_to_key(part) for part in jax_path)
-    if freeze_filter(path, x):
-      if isinstance(x, nnx.Variable):
-        check_mutable_array(jax_path, x.raw_value)
-        return x.from_metadata(x[...], x.get_metadata().copy())
-      elif nnx.is_mutable_array(x):
-        check_mutable_array(jax_path, x)
-        return x[...]
-    return x
-
-  return jax.tree.map_with_path(
-    _freeze_fn, x, is_leaf=lambda x: isinstance(x, nnx.Variable)
-  )
-
-
-def array_like(path, x):
-  return (
-    isinstance(x, nnx.Variable) and not x.mutable
-  ) or nnx.is_mutable_array(x)
-
-
-def mutable(x: A, only: nnx.filterlib.Filter = array_like) -> A:
-  freeze_filter = nnx.filterlib.to_predicate(only)
-  mutable_arrays: dict[int, Any] = {}
-
-  def get_mutable(x):
-    m_array_id = id(x)
-    if m_array_id in mutable_arrays:
-      return mutable_arrays[m_array_id]
-
-    if isinstance(x, nnx.Variable):
-      assert not x.mutable
-      _mutable = x.from_metadata(
-        nnx.mutable_array(x.raw_value),
-        x.get_metadata().copy(),
-      )
-      mutable_arrays[m_array_id] = _mutable
-      return _mutable
-    elif isinstance(x, jax.Array):
-      _mutable = nnx.mutable_array(x)
-      mutable_arrays[m_array_id] = _mutable
-      return _mutable
-    return x
-
-  def _mutable_fn(jax_path, x):
-    path = tuple(nnx.graph._key_path_to_key(part) for part in jax_path)
-    if freeze_filter(path, x):
-      return get_mutable(x)
-    return x
-
-  return jax.tree.map_with_path(
-    _mutable_fn, x, is_leaf=lambda x: isinstance(x, nnx.Variable)
-  )
-
-def pure(tree: A) -> A:
-  def _pure_fn(x):
-    if isinstance(x, nnx.Variable | nnx.VariableState):
-      return x.raw_value
-    return x
-
-  return jax.tree.map(
-    _pure_fn, tree, is_leaf=lambda x: isinstance(x, nnx.Variable | nnx.VariableState)
-  )
-
-def fork_rngs(
-  rngs: nnx.Rngs,
-  /,
-  *,
-  split: Mapping[nnx.filterlib.Filter, int | tuple[int, ...]] | int | None = None,
-):
-  if split is None:
-    split = {}
-  elif isinstance(split, int):
-    split = {...: split}
-
-  split_predicates = {
-    nnx.filterlib.to_predicate(k): v for k, v in split.items()
-  }
-  keys: dict[str, jax.Array] = {}
-  for name, stream in rngs.items():
-    for predicate, num_splits in split_predicates.items():
-      if predicate((), stream):
-        keys[name] = jax.random.split(stream(), num_splits)
-        break
-    else:
-      keys[name] = stream()
-
-  return nnx.Rngs(**keys)
-
-
-def fork_stream(stream: nnx.RngStream):
-  key = stream()
-  return type(stream)(stream.tag, key)
-
 # ## Data
 # We create a simple dataset of points sampled from a parabola with some noise.
 X = np.linspace(-jnp.pi, jnp.pi, 100)[:, None]
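
One behavior worth noting from the deleted freeze implementation above: it raises on trees that reference the same MutableArray twice, since aliased MutableArrays can't be passed into JAX transforms. A sketch of that failure mode, assuming nnx.freeze keeps the check:

    import jax.numpy as jnp
    from flax import nnx

    m = nnx.mutable_array(jnp.zeros(()))
    nnx.freeze({'a': m, 'b': nnx.mutable_array(jnp.zeros(()))})  # fine: distinct arrays
    nnx.freeze({'a': m, 'b': m})  # expected to raise ValueError: duplicate MutableArray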
@@ -221,11 +103,10 @@ def __init__(
     # ----------- dropout ------------------
     self.dropout_rate = dropout_rate
     self.deterministic = deterministic
-    # 'fork' is used to get a derived frozen stream, this is done
-    # to avoid aliasing MutableArray as it's not supported by JAX
-    self.rng = fork_stream(rngs.dropout)
 
-  def __call__(self, x: jax.Array) -> jax.Array:
+  def __call__(
+    self, x: jax.Array, *, rngs: nnx.Rngs | None = None
+  ) -> jax.Array:
     # ----------- linear --------------------
     x = x @ self.w[...] + self.b[None]
     # ----------- batch norm ----------------
@@ -244,21 +125,15 @@ def __call__(self, x: jax.Array) -> jax.Array:
     x = x * self.scale[...] + self.bias[...]
     # ----------- dropout -------------------
     if not self.deterministic and self.dropout_rate > 0.0:
+      assert rngs is not None
       keep_prob = 1.0 - self.dropout_rate
-      mask = jax.random.bernoulli(self.rng(), keep_prob, x.shape)
+      mask = jax.random.bernoulli(rngs.dropout(), keep_prob, x.shape)
       x = jnp.where(mask, x / keep_prob, jnp.zeros_like(x))
     # ----------- activation ---------------
     x = jax.nn.gelu(x)
     return x
 
 
-# Trivial Variable subclasses are used to easily
-# query state groups. In this case Count will be used to hold
-# non-differentiable state containing the number of times the model
-# is called.
-class Count(nnx.Variable): ...
-
-
 class Model(nnx.Module):
   __data__ = ('block_in', 'blocks', 'linear_out', 'count')
 
@@ -272,7 +147,7 @@ def __init__(
     use_scan: bool = True,
     rngs: nnx.Rngs,
   ):
-    self.count = Count(jnp.array(0))
+    self.count: nnx.MutableArray = nnx.mutable_array(jnp.array(0))
     self.block_in = Block(din, dhidden, rngs=rngs)
     self.linear_out = Linear(dhidden, dout, rngs=rngs)
 
@@ -283,29 +158,29 @@ def __init__(
 
       @jax.vmap
      def create_block(rngs, /):
-        return freeze(Block(dhidden, dhidden, rngs=rngs))
+        return nnx.freeze(Block(dhidden, dhidden, rngs=rngs))
 
-      self.blocks = mutable(create_block(fork_rngs(rngs, split=num_blocks)))
+      self.blocks = nnx.mutable(create_block(rngs.fork(split=num_blocks)))
     else:
       self.blocks = [
        Block(dhidden, dhidden, rngs=rngs) for i in range(num_blocks)
      ]
 
-  def __call__(self, x: jax.Array):
+  def __call__(self, x: jax.Array, *, rngs: nnx.Rngs | None = None):
     self.count[...] += 1
-    x = self.block_in(x)
+    x = self.block_in(x, rngs=rngs)
 
     # on the forward pass we either iterate over the block
     # list or use jax.lax.scan to apply the blocks, if we
     # had shared state we would use split and merge to
     # pass the shared state as a capture
     if isinstance(self.blocks, list):
       for block in self.blocks:
-        x = block(x)
+        x = block(x, rngs=rngs)
     else:
 
       def block_fw(x, block: Block):
-        x = block(x)
+        x = block(x, rngs=rngs)
         return x, None
 
       x, _ = jax.lax.scan(block_fw, x, self.blocks)
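
The deleted fork_rngs / fork_stream helpers are replaced by rngs.fork(...) in the hunk above. A sketch of the vmap-stacking pattern, under the assumption that fork(split=n) returns an Rngs whose streams carry n split keys, as fork_rngs did:

    import jax
    from flax import nnx

    rngs = nnx.Rngs(params=0)

    @jax.vmap
    def create_param(rngs: nnx.Rngs, /):
      # freeze so each instance is an immutable pytree that vmap can stack
      return nnx.freeze(nnx.Param(jax.random.normal(rngs.params(), (8, 8))))

    # one stacked Param with leading axis 4, thawed back to MutableArrays
    stacked = nnx.mutable(create_param(rngs.fork(split=4)))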
@@ -314,8 +189,6 @@ def block_fw(x, block: Block):
 
 
 # ## Optimizer
-
-
 class OptState(nnx.Variable): ...
 
 
@@ -338,18 +211,22 @@ def make_opt_state(x):
       return OptState(jnp.zeros_like(x))
 
     self.momentum = jax.tree.map(
-      make_opt_state, params, is_leaf=lambda x: isinstance(x, nnx.Variable | nnx.VariableState)
+      make_opt_state,
+      params,
+      is_leaf=lambda x: isinstance(x, nnx.Variable | nnx.VariableState),
     )
 
   # during the update we simply map over (params, momentum, grads),
   # for each triplet we implement the SGD update rule which updates
   # both the optimizer's state (momentum) and the params in place.
   def update(self, params, grads):
-    params = pure(params)
-    grads = pure(grads)
-    momentum = pure(self.momentum)
+    params = nnx.pure(params)
+    grads = nnx.pure(grads)
+    momentum = nnx.pure(self.momentum)
 
-    def update_fn(param: MutableArray, momentum: MutableArray, grad: jax.Array):
+    def update_fn(
+      param: nnx.MutableArray, momentum: nnx.MutableArray, grad: jax.Array
+    ):
       momentum[...] = self.decay * momentum[...] + (1 - self.decay) * grad[...]
       param[...] -= self.lr * momentum[...]
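
The momentum update above, stated in isolation with scalar state (a sketch, not library code): both leaves are MutableArrays and are written in place.

    import jax.numpy as jnp
    from flax import nnx

    lr, decay = 3e-3, 0.99
    param = nnx.mutable_array(jnp.array(1.0))
    momentum = nnx.mutable_array(jnp.array(0.0))
    grad = jnp.array(2.0)

    # exponential moving average of the gradient, then an in-place step
    momentum[...] = decay * momentum[...] + (1 - decay) * grad
    param[...] -= lr * momentum[...]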

@@ -362,32 +239,36 @@ def update_fn(param: MutableArray, momentum: MutableArray, grad: jax.Array):
 # initialization easier, however this means we have to use 'mutable' to
 # create the MutableArrays that will be updated during training.
 
-
+rngs = nnx.Rngs(params=0, dropout=1)
 model = Model(
-  num_blocks=3, din=1, dhidden=256, dout=1, use_scan=False, rngs=nnx.Rngs(0)
+  num_blocks=3, din=1, dhidden=256, dout=1, use_scan=False, rngs=rngs
 )
 optimizer = SGD(params=nnx.state(model, nnx.Param), lr=3e-3, decay=0.99)
+# Create a copy of the model structure and set its attributes to eval mode.
+# This works because they share the underlying MutableArrays so both models
+# will always be in sync.
 eval_model = nnx.merge(*nnx.split(model))
 eval_model.set_attributes(use_stats=True, deterministic=True)
 
+
 # The training step uses 'jax.jit' and receives the model and optimizer as arguments,
 # this is supported as they are now pytrees. The first thing we do is group the model
 # state into the params and the non-differentiable state using 'split'. We differentiate
 # the loss function using 'jax.grad' with respect to the params only. Inside the loss
 # function we merge the params and non-diff state back into a single model and then
 # compute the loss by calling the model with the inputs.
 @jax.jit
-def train_step(model: Model, optimizer: SGD, x, y):
+def train_step(model: Model, optimizer: SGD, rngs: nnx.Rngs, x, y):
   treedef, params, nondiff = nnx.split(model, nnx.Param, ...)
 
   def loss_fn(params):
     model = nnx.merge(treedef, params, nondiff)
-    loss = jnp.mean((model(x) - y) ** 2)
+    loss = jnp.mean((model(x, rngs=rngs) - y) ** 2)
     return loss
 
   # For the time being we have to use 'freeze' to make the Variables immutable
   # as 'jax.grad' doesn't support MutableArrays yet.
-  grads = jax.grad(loss_fn)(freeze(params))
+  grads = jax.grad(loss_fn)(nnx.freeze(params))
   # 'update' mutates the optimizer's state and the params in place
   # so we don't need to return anything 🚀
   optimizer.update(params, grads)
402283
# minimalistic training loop
403284
total_steps = 10_000
404285
for step, (x, y) in enumerate(dataset(32)):
405-
train_step(model, optimizer, x, y)
286+
train_step(model, optimizer, rngs, x, y)
406287

407288
if step % 1000 == 0:
408289
logs = test_step(eval_model, X, Y)
