
os.environ['FLAX_MUTABLE_ARRAY'] = 'true'

- from typing import Any, TypeVar
- from collections.abc import Mapping
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
- from jax._src.core import MutableArray

from flax import nnx


- # # Utils
- A = TypeVar('A')
-
- def mutable_like(path, x):
-   return (isinstance(x, nnx.Variable) and x.mutable) or nnx.is_mutable_array(x)
-
-
- def freeze(x: A, only: nnx.filterlib.Filter = mutable_like) -> A:
-   freeze_filter = nnx.filterlib.to_predicate(only)
-   mutable_arrays: set[int] = set()
-
-   def check_mutable_array(path, x):
-     m_array_id = id(x)
-     if m_array_id in mutable_arrays:
-       path_str = jax.tree_util.keystr(path)
-       raise ValueError(
-         f'Found duplicate MutableArray found at path {path_str}: {x}'
-       )
-     mutable_arrays.add(m_array_id)
-
-   def _freeze_fn(jax_path, x):
-     path = tuple(nnx.graph._key_path_to_key(part) for part in jax_path)
-     if freeze_filter(path, x):
-       if isinstance(x, nnx.Variable):
-         check_mutable_array(jax_path, x.raw_value)
-         return x.from_metadata(x[...], x.get_metadata().copy())
-       elif nnx.is_mutable_array(x):
-         check_mutable_array(jax_path, x)
-         return x[...]
-     return x
-
-   return jax.tree.map_with_path(
-     _freeze_fn, x, is_leaf=lambda x: isinstance(x, nnx.Variable)
-   )
-
-
- def array_like(path, x):
-   return (
-     isinstance(x, nnx.Variable) and not x.mutable
-   ) or nnx.is_mutable_array(x)
-
-
- def mutable(x: A, only: nnx.filterlib.Filter = array_like) -> A:
-   freeze_filter = nnx.filterlib.to_predicate(only)
-   mutable_arrays: dict[int, Any] = {}
-
-   def get_mutable(x):
-     m_array_id = id(x)
-     if m_array_id in mutable_arrays:
-       return mutable_arrays[m_array_id]
-
-     if isinstance(x, nnx.Variable):
-       assert not x.mutable
-       _mutable = x.from_metadata(
-         nnx.mutable_array(x.raw_value),
-         x.get_metadata().copy(),
-       )
-       mutable_arrays[m_array_id] = _mutable
-       return _mutable
-     elif isinstance(x, jax.Array):
-       _mutable = nnx.mutable_array(x)
-       mutable_arrays[m_array_id] = _mutable
-       return _mutable
-     return x
-
-   def _mutable_fn(jax_path, x):
-     path = tuple(nnx.graph._key_path_to_key(part) for part in jax_path)
-     if freeze_filter(path, x):
-       return get_mutable(x)
-     return x
-
-   return jax.tree.map_with_path(
-     _mutable_fn, x, is_leaf=lambda x: isinstance(x, nnx.Variable)
-   )
-
- def pure(tree: A) -> A:
-   def _pure_fn(x):
-     if isinstance(x, nnx.Variable | nnx.VariableState):
-       return x.raw_value
-     return x
-
-   return jax.tree.map(
-     _pure_fn, tree, is_leaf=lambda x: isinstance(x, nnx.Variable | nnx.VariableState)
-   )
-
- def fork_rngs(
-   rngs: nnx.Rngs,
-   /,
-   *,
-   split: Mapping[nnx.filterlib.Filter, int | tuple[int, ...]] | int | None = None,
- ):
-   if split is None:
-     split = {}
-   elif isinstance(split, int):
-     split = {...: split}
-
-   split_predicates = {
-     nnx.filterlib.to_predicate(k): v for k, v in split.items()
-   }
-   keys: dict[str, jax.Array] = {}
-   for name, stream in rngs.items():
-     for predicate, num_splits in split_predicates.items():
-       if predicate((), stream):
-         keys[name] = jax.random.split(stream(), num_splits)
-         break
-     else:
-       keys[name] = stream()
-
-   return nnx.Rngs(**keys)
-
-
- def fork_stream(stream: nnx.RngStream):
-   key = stream()
-   return type(stream)(stream.tag, key)
-
# ## Data
# We create a simple dataset of points sampled from a parabola with some noise.
X = np.linspace(-jnp.pi, jnp.pi, 100)[:, None]
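
The targets and the batching generator used by the training loop further down are not shown in this hunk; a rough sketch of what they plausibly look like, assuming a squared target with Gaussian noise and uniform random minibatches (the coefficients, noise scale, and sampling strategy are assumptions, not taken from the file):

Y = 0.8 * X**2 + 0.1 + np.random.normal(0.0, 0.1, size=X.shape)  # noisy parabola (assumed constants)

def dataset(batch_size):
  # endlessly yield random (x, y) minibatches from the full dataset
  while True:
    idx = np.random.choice(len(X), size=batch_size)
    yield X[idx], Y[idx]
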
@@ -221,11 +103,10 @@ def __init__(
    # ----------- dropout ------------------
    self.dropout_rate = dropout_rate
    self.deterministic = deterministic
-     # 'fork' is used to get a derived frozen stream, this is done
-     # to avoid aliasing MutableArray as as its not supported by JAX
-     self.rng = fork_stream(rngs.dropout)

-   def __call__(self, x: jax.Array) -> jax.Array:
+   def __call__(
+     self, x: jax.Array, *, rngs: nnx.Rngs | None = None
+   ) -> jax.Array:
    # ----------- linear --------------------
    x = x @ self.w[...] + self.b[None]
    # ----------- batch norm ----------------
@@ -244,21 +125,15 @@ def __call__(self, x: jax.Array) -> jax.Array:
    x = x * self.scale[...] + self.bias[...]
    # ----------- dropout -------------------
    if not self.deterministic and self.dropout_rate > 0.0:
+       assert rngs is not None
      keep_prob = 1.0 - self.dropout_rate
-       mask = jax.random.bernoulli(self.rng(), keep_prob, x.shape)
+       mask = jax.random.bernoulli(rngs.dropout(), keep_prob, x.shape)
      x = jnp.where(mask, x / keep_prob, jnp.zeros_like(x))
    # ----------- activation ---------------
    x = jax.nn.gelu(x)
    return x


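
With this change the dropout key is passed in at call time instead of being stored on the module, so the Block no longer holds a forked RNG stream. A minimal usage sketch under the signatures shown above (the seeds and shapes are arbitrary, and `Block(din, dhidden, rngs=...)` is the constructor form Model uses below):

rngs = nnx.Rngs(params=0, dropout=1)
block = Block(1, 256, rngs=rngs)          # same constructor form Model uses below
y = block(jnp.ones((8, 1)), rngs=rngs)    # dropout key is drawn from rngs.dropout at call time
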
- # Trivial Variables subclasses are used to easily
- # query state groups. In this case Count will be used to hold
- # non-differentiable state containing the number of times the model
- # is called.
- class Count(nnx.Variable): ...
-
-
class Model(nnx.Module):
  __data__ = ('block_in', 'blocks', 'linear_out', 'count')

@@ -272,7 +147,7 @@ def __init__(
    use_scan: bool = True,
    rngs: nnx.Rngs,
  ):
-     self.count = Count(jnp.array(0))
+     self.count: nnx.MutableArray = nnx.mutable_array(jnp.array(0))
    self.block_in = Block(din, dhidden, rngs=rngs)
    self.linear_out = Linear(dhidden, dout, rngs=rngs)

@@ -283,29 +158,29 @@ def __init__(

      @jax.vmap
      def create_block(rngs, /):
-         return freeze(Block(dhidden, dhidden, rngs=rngs))
+         return nnx.freeze(Block(dhidden, dhidden, rngs=rngs))

-       self.blocks = mutable(create_block(fork_rngs(rngs, split=num_blocks)))
+       self.blocks = nnx.mutable(create_block(rngs.fork(split=num_blocks)))
    else:
      self.blocks = [
        Block(dhidden, dhidden, rngs=rngs) for i in range(num_blocks)
      ]

-   def __call__(self, x: jax.Array):
+   def __call__(self, x: jax.Array, *, rngs: nnx.Rngs | None = None):
    self.count[...] += 1
-     x = self.block_in(x)
+     x = self.block_in(x, rngs=rngs)

    # on the forward pass we either iterate over the block
    # list or use jax.lax.scan to apply the blocks, if we
    # had shared state we would use split and merge to
    # pass the shared state as a capture
    if isinstance(self.blocks, list):
      for block in self.blocks:
-         x = block(x)
+         x = block(x, rngs=rngs)
    else:

      def block_fw(x, block: Block):
-         x = block(x)
+         x = block(x, rngs=rngs)
        return x, None

      x, _ = jax.lax.scan(block_fw, x, self.blocks)
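
The comment above mentions the variant with shared state: the usual pattern is to split the shared part out once, close over it, and merge it back inside the scanned function so that only the per-block state is scanned over. A rough sketch of that idea, where `SharedParam` is a hypothetical Variable subclass used purely as a filter (not part of this example):

class SharedParam(nnx.Param): ...  # hypothetical tag for weights shared across blocks

graphdef, shared, per_block = nnx.split(self.blocks, SharedParam, ...)

def block_fw(x, block_state):
  # 'shared' is a capture: it is not stacked, so every block sees the same values
  block = nnx.merge(graphdef, shared, block_state)
  return block(x, rngs=rngs), None

x, _ = jax.lax.scan(block_fw, x, per_block)
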
@@ -314,8 +189,6 @@ def block_fw(x, block: Block):


# ## Optimizer
-
-
class OptState(nnx.Variable): ...


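OptState plays the same role Count played before this change: a trivial Variable subclass that exists only so the optimizer's buffers can be selected as a group with a filter. For example, something like the following should pull out just the momentum state (a usage sketch, not code from this file):

momentum_state = nnx.state(optimizer, OptState)   # select only OptState variables
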
@@ -338,18 +211,22 @@ def make_opt_state(x):
      return OptState(jnp.zeros_like(x))

    self.momentum = jax.tree.map(
-       make_opt_state, params, is_leaf=lambda x: isinstance(x, nnx.Variable | nnx.VariableState)
+       make_opt_state,
+       params,
+       is_leaf=lambda x: isinstance(x, nnx.Variable | nnx.VariableState),
    )

  # during the update we simply map over (params, momentum, grads),
  # for each triplet we implement the SGD update rule which updates
  # both the optimizer's state (momentum) and the params in place.
  def update(self, params, grads):
-     params = pure(params)
-     grads = pure(grads)
-     momentum = pure(self.momentum)
+     params = nnx.pure(params)
+     grads = nnx.pure(grads)
+     momentum = nnx.pure(self.momentum)

-     def update_fn(param: MutableArray, momentum: MutableArray, grad: jax.Array):
+     def update_fn(
+       param: nnx.MutableArray, momentum: nnx.MutableArray, grad: jax.Array
+     ):
      momentum[...] = self.decay * momentum[...] + (1 - self.decay) * grad[...]
      param[...] -= self.lr * momentum[...]

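
As a standalone illustration of the in-place update rule above, here is the same arithmetic on single values (lr=0.1 and decay=0.9 are just example numbers):

param = nnx.mutable_array(jnp.array(1.0))
momentum = nnx.mutable_array(jnp.array(0.0))
grad = jnp.array(2.0)

momentum[...] = 0.9 * momentum[...] + (1 - 0.9) * grad   # 0.9*0.0 + 0.1*2.0 = 0.2
param[...] -= 0.1 * momentum[...]                        # 1.0 - 0.1*0.2 = 0.98
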
@@ -362,32 +239,36 @@ def update_fn(param: MutableArray, momentum: MutableArray, grad: jax.Array):
# initialization easier, however this means we have to use 'mutable' to
# create the MutableArrays that will be updated during training.

-
+ rngs = nnx.Rngs(params=0, dropout=1)
model = Model(
-   num_blocks=3, din=1, dhidden=256, dout=1, use_scan=False, rngs=nnx.Rngs(0)
+   num_blocks=3, din=1, dhidden=256, dout=1, use_scan=False, rngs=rngs
)
optimizer = SGD(params=nnx.state(model, nnx.Param), lr=3e-3, decay=0.99)
+ # Create a copy of the model structure and set its attributes to eval mode.
+ # This works because they share the underlying MutableArrays, so both models
+ # will always be in sync.
eval_model = nnx.merge(*nnx.split(model))
eval_model.set_attributes(use_stats=True, deterministic=True)

+
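
The comment above is the important invariant: `nnx.split`/`nnx.merge` hand back the same MutableArrays rather than copies, so mutations made through either handle are visible through the other. A small illustration of what that sharing implies (purely illustrative):

before = eval_model.count[...]
model.count[...] += 1                        # mutate through the training handle
assert eval_model.count[...] == before + 1   # visible through the eval handle too
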
# The training step uses 'jax.jit' and receives the model and optimizer as arguments,
# which is supported because they are now pytrees. The first thing we do is group the model
# state into the params and the non-differentiable state using 'split'. We differentiate
# the loss function using 'jax.grad' with respect to the params only. Inside the loss
# function we merge the params and non-diff state back into a single model and then
# compute the loss by calling the model with the inputs.
@jax.jit
- def train_step(model: Model, optimizer: SGD, x, y):
+ def train_step(model: Model, optimizer: SGD, rngs: nnx.Rngs, x, y):
  treedef, params, nondiff = nnx.split(model, nnx.Param, ...)

  def loss_fn(params):
    model = nnx.merge(treedef, params, nondiff)
-     loss = jnp.mean((model(x) - y) ** 2)
+     loss = jnp.mean((model(x, rngs=rngs) - y) ** 2)
    return loss

  # For the time being we have to use 'freeze' to make the Variables immutable,
  # as 'jax.grad' doesn't support MutableArrays yet.
-   grads = jax.grad(loss_fn)(freeze(params))
+   grads = jax.grad(loss_fn)(nnx.freeze(params))
  # 'update' mutates the optimizer's state and the params in place
  # so we don't need to return anything 🚀
  optimizer.update(params, grads)
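
The freeze-then-differentiate pattern above can be shown in isolation; a minimal sketch, assuming `nnx.freeze` turns MutableArray leaves back into ordinary arrays the same way the removed `freeze` helper did:

w = nnx.mutable_array(jnp.array(3.0))

def loss(w_frozen):
  return w_frozen ** 2

g = jax.grad(loss)(nnx.freeze(w))   # differentiate w.r.t. the immutable copy -> 6.0
w[...] -= 0.1 * g                   # then apply the update in place on the MutableArray
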
@@ -402,7 +283,7 @@ def test_step(model: Model, x, y):
# minimalistic training loop
total_steps = 10_000
for step, (x, y) in enumerate(dataset(32)):
-   train_step(model, optimizer, x, y)
+   train_step(model, optimizer, rngs, x, y)

  if step % 1000 == 0:
    logs = test_step(eval_model, X, Y)