
Commit 50e043c

[nnx] add flax_pytree_module flag
1 parent 8447c03 commit 50e043c

File tree

10 files changed: +213 -87 lines changed

examples/nnx_toy_examples/mutable_array_demo.py

Lines changed: 6 additions & 11 deletions

@@ -13,17 +13,15 @@
 # limitations under the License.
 
 # %%
-import os
-
-os.environ['FLAX_MUTABLE_ARRAY'] = 'true'
-
 import jax
 import jax.numpy as jnp
 import matplotlib.pyplot as plt
 import numpy as np
 
 from flax import nnx
 
+# activate mutable arrays
+nnx.use_mutable_arrays(True)
 
 # ## Data
 # We create a simple dataset of points sampled from a parabola with some noise.
@@ -151,9 +149,7 @@ def create_block(rngs, /):
 
       self.blocks = nnx.mutable(create_block(rngs.fork(split=num_blocks)))
     else:
-      self.blocks = nnx.data(
-        [Block(dhidden, dhidden, rngs=rngs) for i in range(num_blocks)]
-      )
+      self.blocks = [Block(dhidden, dhidden, rngs=rngs) for i in range(num_blocks)]
 
   def __call__(self, x: jax.Array, *, rngs: nnx.Rngs | None = None):
     self.count[...] += 1
@@ -197,13 +193,11 @@ def make_opt_state(x):
       else:
         return OptState(jnp.zeros_like(x))
 
-    self.momentum = nnx.data(
-      jax.tree.map(
+    self.momentum = jax.tree.map(
       make_opt_state,
       params,
       is_leaf=lambda x: isinstance(x, nnx.Variable),
     )
-    )
 
   # during the update we simply map over (params, momentum, grads),
   # for each triplet we implement the SGD update rule which updates
@@ -226,11 +220,13 @@ def update_fn(
 # Variables are immutable (only contain Arrays) by default as it can make
 # initialization easier, however this means we have to use 'mutable' to
 # create the MutableArrays that will be updated during training.
+
 rngs = nnx.Rngs(params=0, dropout=1)
 model = Model(
   num_blocks=3, din=1, dhidden=256, dout=1, use_scan=False, rngs=rngs
 )
 optimizer = SGD(params=nnx.state(model, nnx.Param), lr=3e-3, decay=0.99)
+
 # Create a copy of the model structure and set its attributes to eval mode.
 # This works because they share the underlying MutableArrays so both models
 # will always be in sync.
@@ -260,7 +256,6 @@ def loss_fn(params):
   # so we don't need to return anything 🚀
   optimizer.update(params, grads)
 
-
 # simple test step that computes the loss
 @jax.jit
 def test_step(model: Model, x, y):
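Note: the demo's only API-visible change is how mutable arrays get enabled. A minimal sketch of the new idiom (the Linear layer and shapes below are illustrative, not part of this diff):

import jax.numpy as jnp
from flax import nnx

# replaces os.environ['FLAX_MUTABLE_ARRAY'] = 'true', and can be
# called after importing flax instead of before
nnx.use_mutable_arrays(True)

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))  # params backed by MutableArrays
y = model(jnp.ones((1, 2)))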

flax/configurations.py

Lines changed: 6 additions & 0 deletions

@@ -24,6 +24,7 @@
 class Config:
   flax_use_flaxlib: bool
   flax_mutable_array: bool
+  flax_pytree_module: bool
   flax_max_repr_depth: int | None
   # See https://google.github.io/pytype/faq.html.
   _HAS_DYNAMIC_ATTRIBUTES = True
@@ -272,6 +273,11 @@ def temp_flip_flag(var_name: str, var_value: bool):
   default=False,
   help='Whether to use mutable arrays.',
 )
+flax_pytree_module = bool_flag(
+  name='flax_pytree_module',
+  default=True,
+  help='Whether Modules are pytrees by default or not.',
+)
 
 flax_max_repr_depth = int_flag(
   name='flax_max_repr_depth',
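A sketch of how the new flag is consumed (the per-class pytree= override comes from the __init_subclass__ change in flax/nnx/object.py below; the class here is illustrative):

from flax import config, nnx

# the flag defaults to True, so nnx.Object subclasses register as pytrees
assert config.flax_pytree_module

# per-class override, independent of the global default
class NonPytreeModule(nnx.Module, pytree=False):
  pass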

flax/nnx/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -180,6 +180,8 @@
 from .variablelib import mutable_array as mutable_array
 from .variablelib import MutableArray as MutableArray
 from .variablelib import is_mutable_array as is_mutable_array
+from .variablelib import use_mutable_arrays as use_mutable_arrays
+from .variablelib import using_mutable_arrays as using_mutable_arrays
 from .visualization import display as display
 from .extract import to_tree as to_tree
 from .extract import from_tree as from_tree
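A short sketch of the two new exports. Only the re-exports appear in this diff, so this assumes use_mutable_arrays sets the global flag (as the demo above does) and that using_mutable_arrays() reads the current setting back:

import jax.numpy as jnp
from flax import nnx

nnx.use_mutable_arrays(True)
assert nnx.using_mutable_arrays()  # assumed getter counterpart

w = nnx.mutable_array(jnp.zeros((3,)))
assert nnx.is_mutable_array(w)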

flax/nnx/graph.py

Lines changed: 5 additions & 1 deletion

@@ -1612,7 +1612,11 @@ def create_static_cache(x):
       return node_cache
     return x
 
-  cached_args = jax.tree.map(create_static_cache, cached_args)
+  cached_args = jax.tree.map(
+    create_static_cache,
+    cached_args,
+    is_leaf=lambda x: is_graph_node(x) or isinstance(x, Variable),
+  )
 
   @functools.wraps(f)
   def cache_args_wrapper(*args, **kwargs):
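Why the added is_leaf matters, in a self-contained JAX sketch (plain containers stand in for NNX graph nodes and Variables):

import jax

tree = {'a': [1, 2], 'b': (3,)}

# without is_leaf, map only ever sees the innermost leaves: 1, 2, 3
jax.tree.map(lambda x: x, tree)

# with is_leaf, whole sub-structures are handed to the function intact,
# mirroring how create_static_cache must receive graph nodes and
# Variables as single objects rather than their flattened arrays
jax.tree.map(lambda x: x, tree, is_leaf=lambda x: isinstance(x, list))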

flax/nnx/nn/normalization.py

Lines changed: 2 additions & 2 deletions

@@ -18,7 +18,7 @@
 import jax.numpy as jnp
 from jax import lax
 
-from flax import nnx, config
+from flax import nnx
 from flax.nnx import rnglib
 from flax.nnx.module import Module, first_from
 from flax.nnx.nn import dtypes, initializers
@@ -355,7 +355,7 @@ def __call__(
       mask=mask,
     )
     # stop_gradient only for flax_mutable_array
-    if config.flax_mutable_array:
+    if self.mean.mutable or self.var.mutable:
       stop_gradient = jax.lax.stop_gradient
     else:
       stop_gradient = lambda x: x
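The change swaps a global-config check for a per-instance one: BatchNorm asks whether its own running statistics are mutable, which stays correct in mixed setups where only some modules use MutableArrays. A minimal sketch of the pattern (the .mutable attribute is taken from this diff; the helper itself is illustrative):

import jax

def pick_stop_gradient(mean, var):
  # cut gradients through running stats only when they are updated
  # in place (i.e. backed by MutableArrays) during __call__
  if mean.mutable or var.mutable:
    return jax.lax.stop_gradient
  return lambda x: x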

flax/nnx/object.py

Lines changed: 24 additions & 18 deletions

@@ -36,7 +36,7 @@
   visualization,
 )
 from flax import config
-from flax.nnx.variablelib import Variable, is_mutable_array
+from flax.nnx.variablelib import Variable
 from flax.typing import SizeBytes
 
 BUILDING_DOCS = 'FLAX_DOC_BUILD' in os.environ
@@ -104,6 +104,7 @@ def __init__(self):
   """
   return DataAttr(value)  # type: ignore[return-value]
 
+
 def register_data_type(type_: type, /) -> None:
   """Registers a type as pytree data type recognized by Object.
 
@@ -264,6 +265,7 @@ def __treescope_repr__(self, path, subtree_renderer):
     subtree_renderer=subtree_renderer,
   )
 
+
 def _flatten_object_state(state: ObjectState):
   return (), (state.initializing, state.is_setup)
 
@@ -279,6 +281,7 @@ def _unflatten_object_state(static: tuple[bool, bool], _):
   _unflatten_object_state,
 )
 
+
 class ObjectMeta(ABCMeta):
   if not tp.TYPE_CHECKING:
 
@@ -291,9 +294,18 @@ def _object_meta_construct(cls, self, *args, **kwargs):
 
 def _graph_node_meta_call(cls: tp.Type[O], *args, **kwargs) -> O:
   node = cls.__new__(cls, *args, **kwargs)
-  vars(node)['_object__state'] = ObjectState()
-  vars(node)['_object__nodes'] = cls._object__nodes
+  vars_obj = vars(node)
+  vars_obj['_object__state'] = ObjectState()
+  vars_obj['_object__nodes'] = cls._object__nodes
   cls._object_meta_construct(node, *args, **kwargs)
+  # register possible new data attributes after initialization
+  for name, value in vars_obj.items():
+    if name not in vars_obj['_object__nodes']:
+      if any(
+        is_data_type(leaf)
+        for leaf in jax.tree.leaves(value, is_leaf=is_data_type)
+      ):
+        vars_obj['_object__nodes'] = vars_obj['_object__nodes'].union((name,))
 
   return node
 
@@ -312,6 +324,7 @@ def __nnx_repr__(self):
     yield reprlib.Attr('shape', self.shape)
     yield reprlib.Attr('dtype', self.dtype)
 
+
 @dataclasses.dataclass(frozen=True, repr=False)
 class MutableArrayRepr(reprlib.Representable):
   shape: tp.Tuple[int, ...]
@@ -326,6 +339,7 @@ def __nnx_repr__(self):
     yield reprlib.Attr('shape', self.shape)
     yield reprlib.Attr('dtype', self.dtype)
 
+
 class Object(reprlib.Representable, metaclass=ObjectMeta):
   """Base class for all NNX objects."""
 
@@ -335,7 +349,7 @@ class Object(reprlib.Representable, metaclass=ObjectMeta):
   _object__state: ObjectState
 
   def __init_subclass__(
-    cls, *, pytree: bool = config.flax_mutable_array, **kwargs
+    cls, *, pytree: bool = config.flax_pytree_module, **kwargs
   ) -> None:
     super().__init_subclass__(**kwargs)
 
@@ -387,20 +401,12 @@ def _setattr(self, name: str, value: tp.Any) -> None:
       value = value.value
       if name not in self._object__nodes:
         self._object__nodes = self._object__nodes.union((name,))
-    elif is_data_type(value):
-      if name not in self._object__nodes:
-        self._object__nodes = self._object__nodes.union((name,))
-    elif type(self)._object__is_pytree and name not in self._object__nodes:
-      for leaf in jax.tree.leaves(value):
-        if isinstance(leaf, jax.Array) or is_mutable_array(leaf):
-          raise TypeError(
-            f"Trying to set '{name}' to a value containing one or more "
-            f"jax.Array, but '{name}' is not registered as data. "
-            f"Use 'obj.{name} = nnx.data(...)' to register the attribute as data "
-            f"on assignment, or add '{name}: nnx.Data[{type(value).__name__}]' "
-            f'to the class definition. '
-            f'Got value: {value}'
-          )
+    # any attribute that contains known data types will be registered as data
+    elif name not in self._object__nodes and any(
+      is_data_type(leaf)
+      for leaf in jax.tree.leaves(value, is_leaf=is_data_type)
+    ):
+      self._object__nodes = self._object__nodes.union((name,))
     object.__setattr__(self, name, value)
 
   def _check_valid_context(self, error_msg: tp.Callable[[], str]) -> None:
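The user-facing effect, visible in the demo diff above: attributes whose leaves are recognized data types no longer need an explicit nnx.data(...) wrapper, and the old TypeError for unregistered jax.Array attributes is gone. A hedged sketch (the exact coverage of is_data_type is not shown in this diff):

from flax import nnx

class MLP(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    # a plain list of Modules is now auto-registered as pytree data,
    # both on assignment (_setattr) and by the post-__init__ scan
    # added to _graph_node_meta_call
    self.layers = [nnx.Linear(4, 4, rngs=rngs) for _ in range(3)]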

flax/nnx/transforms/iteration.py

Lines changed: 1 addition & 1 deletion

@@ -652,7 +652,7 @@ def check_carry_same_references(key_path, arg, out):
     )
 
   jax.tree_util.tree_map_with_path(
-    check_carry_same_references, carry_arg, carry_arg_out
+    check_carry_same_references, carry_arg, carry_arg_out, is_leaf=graph.is_graph_node
   )
 
 def _extract_graphdefs(
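Same motivation as the graph.py change: scan's carry check compares graph nodes by identity as whole objects, so is_leaf keeps tree_map_with_path from descending into them (NNX graph nodes can be pytrees, so without it the check would only see their arrays). A generic, runnable sketch with a stand-in node class:

import jax

class Node:  # stands in for an NNX graph node
  pass

n = Node()
carry_in, carry_out = {'state': n}, {'state': n}

def check(path, a, b):
  # compare whole nodes by identity, not their flattened leaves
  if a is not b:
    raise ValueError(f'carry changed identity at {jax.tree_util.keystr(path)}')

jax.tree_util.tree_map_with_path(
  check, carry_in, carry_out, is_leaf=lambda x: isinstance(x, Node)
)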
