# limitations under the License.

# %%
- import os
-
- os.environ['FLAX_MUTABLE_ARRAY'] = 'true'
-
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from flax import nnx

+ # activate mutable arrays
+ nnx.use_mutable_arrays(True)

# ## Data
# We create a simple dataset of points sampled from a parabola with some noise.
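# (Illustrative sketch, not part of the diff: the dataset code itself sits
# outside this hunk. One plausible way to sample noisy parabola points with
# the NumPy import above; the coefficients and sizes here are assumptions.)
X = np.random.uniform(-2, 2, size=(256, 1))
Y = 0.8 * X**2 + 0.1 + np.random.normal(0.0, 0.1, size=X.shape)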
@@ -151,9 +149,7 @@ def create_block(rngs, /):

      self.blocks = nnx.mutable(create_block(rngs.fork(split=num_blocks)))
    else:
-      self.blocks = nnx.data(
-        [Block(dhidden, dhidden, rngs=rngs) for i in range(num_blocks)]
-      )
+      self.blocks = [Block(dhidden, dhidden, rngs=rngs) for i in range(num_blocks)]

  def __call__(self, x: jax.Array, *, rngs: nnx.Rngs | None = None):
    self.count[...] += 1
@@ -197,13 +193,11 @@ def make_opt_state(x):
      else:
        return OptState(jnp.zeros_like(x))

-    self.momentum = nnx.data(
-      jax.tree.map(
+    self.momentum = jax.tree.map(
      make_opt_state,
      params,
      is_leaf=lambda x: isinstance(x, nnx.Variable),
    )
-    )

  # during the update we simply map over (params, momentum, grads),
  # for each triplet we implement the SGD update rule which updates
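# (Illustrative sketch, not part of the diff: the real `update_fn` below may
# differ. With mutable arrays enabled, an SGD-with-momentum step can update
# each (param, momentum, grad) triplet in place; `lr` and `decay` stand for
# the hyperparameters passed to SGD.)
def sgd_momentum_step(
  param: nnx.Variable, momentum: nnx.Variable, grad: jax.Array,
  *, lr: float, decay: float,
):
  momentum[...] = decay * momentum[...] + grad  # accumulate velocity
  param[...] -= lr * momentum[...]  # in-place parameter update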
@@ -226,11 +220,13 @@ def update_fn(
# Variables are immutable (they only contain Arrays) by default, which can make
# initialization easier; however, this means we have to use 'mutable' to
# create the MutableArrays that will be updated during training.
+
rngs = nnx.Rngs(params=0, dropout=1)
model = Model(
  num_blocks=3, din=1, dhidden=256, dout=1, use_scan=False, rngs=rngs
)
optimizer = SGD(params=nnx.state(model, nnx.Param), lr=3e-3, decay=0.99)
+
# Create a copy of the model structure and set its attributes to eval mode.
# This works because they share the underlying MutableArrays so both models
# will always be in sync.
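# (Illustrative sketch, not part of the diff: the actual copy code sits just
# past this hunk. One plausible way to build such an eval-mode view, assuming
# nnx.split/nnx.merge reuse the same underlying MutableArrays and that the
# blocks expose a `deterministic` flag:)
eval_model = nnx.merge(*nnx.split(model))
eval_model.set_attributes(deterministic=True)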
@@ -260,7 +256,6 @@ def loss_fn(params):
  # so we don't need to return anything 🚀
  optimizer.update(params, grads)

-
# simple test step that computes the loss
@jax.jit
def test_step(model: Model, x, y):
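  # (Hypothetical body, not part of the diff: the real implementation sits
  # past this hunk. A stand-in that just reports the mean squared error of
  # the eval-mode model on the test batch:)
  return jnp.mean((model(x) - y) ** 2)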