17 changes: 6 additions & 11 deletions examples/nnx_toy_examples/mutable_array_demo.py
@@ -13,17 +13,15 @@
 # limitations under the License.

 # %%
-import os
-
-os.environ['FLAX_MUTABLE_ARRAY'] = 'true'
-
 import jax
 import jax.numpy as jnp
 import matplotlib.pyplot as plt
 import numpy as np

 from flax import nnx

+# activate mutable arrays
+nnx.use_mutable_arrays(True)

 # ## Data
 # We create a simple dataset of points sampled from a parabola with some noise.
@@ -151,9 +149,7 @@ def create_block(rngs, /):

       self.blocks = nnx.mutable(create_block(rngs.fork(split=num_blocks)))
     else:
-      self.blocks = nnx.data(
-        [Block(dhidden, dhidden, rngs=rngs) for i in range(num_blocks)]
-      )
+      self.blocks = [Block(dhidden, dhidden, rngs=rngs) for i in range(num_blocks)]

   def __call__(self, x: jax.Array, *, rngs: nnx.Rngs | None = None):
     self.count[...] += 1
@@ -197,13 +193,11 @@ def make_opt_state(x):
       else:
         return OptState(jnp.zeros_like(x))

-    self.momentum = nnx.data(
-      jax.tree.map(
+    self.momentum = jax.tree.map(
       make_opt_state,
       params,
       is_leaf=lambda x: isinstance(x, nnx.Variable),
     )
-    )

   # during the update we simply map over (params, momentum, grads),
   # for each triplet we implement the SGD update rule which updates
@@ -226,11 +220,13 @@ def update_fn(
 # Variables are immutable (only contain Arrays) by default as it can make
 # initialization easier; however, this means we have to use 'mutable' to
 # create the MutableArrays that will be updated during training.
+
 rngs = nnx.Rngs(params=0, dropout=1)
 model = Model(
   num_blocks=3, din=1, dhidden=256, dout=1, use_scan=False, rngs=rngs
 )
 optimizer = SGD(params=nnx.state(model, nnx.Param), lr=3e-3, decay=0.99)
+
 # Create a copy of the model structure and set its attributes to eval mode.
 # This works because they share the underlying MutableArrays so both models
 # will always be in sync.
@@ -260,7 +256,6 @@ def loss_fn(params):
   # so we don't need to return anything 🚀
   optimizer.update(params, grads)

-
 # simple test step that computes the loss
 @jax.jit
 def test_step(model: Model, x, y):
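
With this change the demo no longer has to set FLAX_MUTABLE_ARRAY in the environment before importing flax; nnx.use_mutable_arrays(True) toggles the behavior at runtime. A minimal sketch of the new flow (the Linear layer and sizes are illustrative, not taken from the demo):

import jax.numpy as jnp
from flax import nnx

nnx.use_mutable_arrays(True)  # enable MutableArray-backed Variables at runtime

linear = nnx.Linear(1, 16, rngs=nnx.Rngs(0))  # params are created as MutableArrays
y = linear(jnp.ones((4, 1)))                  # forward pass is unchanged
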
6 changes: 6 additions & 0 deletions flax/configurations.py
@@ -24,6 +24,7 @@
 class Config:
   flax_use_flaxlib: bool
   flax_mutable_array: bool
+  flax_pytree_module: bool
   flax_max_repr_depth: int | None
   # See https://google.github.io/pytype/faq.html.
   _HAS_DYNAMIC_ATTRIBUTES = True
@@ -272,6 +273,11 @@ def temp_flip_flag(var_name: str, var_value: bool):
     default=False,
     help='Whether to use mutable arrays.',
 )
+flax_pytree_module = bool_flag(
+    name='flax_pytree_module',
+    default=True,
+    help='Whether Modules are pytrees by default or not.',
+)

 flax_max_repr_depth = int_flag(
     name='flax_max_repr_depth',
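
Because flax_pytree_module defaults to True, Modules are now registered as pytrees independently of the mutable-array flag. The per-class escape hatch is the pytree keyword that Object.__init_subclass__ already accepts (see the flax/nnx/object.py change below); a sketch, with a hypothetical class name:

from flax import nnx

class LegacyModule(nnx.Module, pytree=False):
  # opts this class out of pytree flattening regardless of the global flag
  pass
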
2 changes: 2 additions & 0 deletions flax/nnx/__init__.py
@@ -180,6 +180,8 @@
 from .variablelib import mutable_array as mutable_array
 from .variablelib import MutableArray as MutableArray
 from .variablelib import is_mutable_array as is_mutable_array
+from .variablelib import use_mutable_arrays as use_mutable_arrays
+from .variablelib import using_mutable_arrays as using_mutable_arrays
 from .visualization import display as display
 from .extract import to_tree as to_tree
 from .extract import from_tree as from_tree
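
The two new re-exports pair a setter with a getter. Assuming using_mutable_arrays() simply reports the current setting (its body is not shown in this diff), usage would look like:

from flax import nnx

nnx.use_mutable_arrays(True)
assert nnx.using_mutable_arrays()  # assumption: returns the current flag value
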
6 changes: 5 additions & 1 deletion flax/nnx/graph.py
@@ -1612,7 +1612,11 @@ def create_static_cache(x):
       return node_cache
     return x

-  cached_args = jax.tree.map(create_static_cache, cached_args)
+  cached_args = jax.tree.map(
+    create_static_cache,
+    cached_args,
+    is_leaf=lambda x: is_graph_node(x) or isinstance(x, Variable),
+  )

   @functools.wraps(f)
   def cache_args_wrapper(*args, **kwargs):
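
The fix passes is_leaf so that jax.tree.map hands whole graph nodes and Variables to create_static_cache instead of recursing into their internals. The same is_leaf mechanic in isolation, with a plain list standing in for a graph node:

import jax

tree = {'a': [1, 2], 'b': [3]}
# without is_leaf, jax.tree.map would visit the ints inside each list;
# with it, each list is passed to the function as a single leaf
out = jax.tree.map(lambda x: len(x), tree, is_leaf=lambda x: isinstance(x, list))
# out == {'a': 2, 'b': 1}
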
4 changes: 2 additions & 2 deletions flax/nnx/nn/normalization.py
@@ -18,7 +18,7 @@
 import jax.numpy as jnp
 from jax import lax

-from flax import nnx, config
+from flax import nnx
 from flax.nnx import rnglib
 from flax.nnx.module import Module, first_from
 from flax.nnx.nn import dtypes, initializers
@@ -355,7 +355,7 @@ def __call__(
       mask=mask,
     )
     # stop_gradient only for flax_mutable_array
-    if config.flax_mutable_array:
+    if self.mean.mutable or self.var.mutable:
       stop_gradient = jax.lax.stop_gradient
     else:
       stop_gradient = lambda x: x
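
Checking self.mean.mutable / self.var.mutable instead of the global config makes BatchNorm decide per instance, so models built with mutable arrays and plain immutable models can coexist in one process. A sketch of the pattern, assuming Variable exposes the boolean .mutable property this diff relies on:

import jax

def pick_stop_gradient(stat):
  # stat is an nnx.Variable holding a running statistic; gradients are
  # stopped only when it is backed by a MutableArray (stat.mutable is True)
  return jax.lax.stop_gradient if stat.mutable else (lambda x: x)
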
42 changes: 24 additions & 18 deletions flax/nnx/object.py
@@ -36,7 +36,7 @@
   visualization,
 )
 from flax import config
-from flax.nnx.variablelib import Variable, is_mutable_array
+from flax.nnx.variablelib import Variable
 from flax.typing import SizeBytes

 BUILDING_DOCS = 'FLAX_DOC_BUILD' in os.environ
@@ -104,6 +104,7 @@ def __init__(self):
   """
   return DataAttr(value)  # type: ignore[return-value]

+
 def register_data_type(type_: type, /) -> None:
   """Registers a type as pytree data type recognized by Object.

@@ -264,6 +265,7 @@ def __treescope_repr__(self, path, subtree_renderer):
       subtree_renderer=subtree_renderer,
     )

+
 def _flatten_object_state(state: ObjectState):
   return (), (state.initializing, state.is_setup)

@@ -279,6 +281,7 @@ def _unflatten_object_state(static: tuple[bool, bool], _):
   _unflatten_object_state,
 )

+
 class ObjectMeta(ABCMeta):
   if not tp.TYPE_CHECKING:

@@ -291,9 +294,18 @@ def _object_meta_construct(cls, self, *args, **kwargs):

 def _graph_node_meta_call(cls: tp.Type[O], *args, **kwargs) -> O:
   node = cls.__new__(cls, *args, **kwargs)
-  vars(node)['_object__state'] = ObjectState()
-  vars(node)['_object__nodes'] = cls._object__nodes
+  vars_obj = vars(node)
+  vars_obj['_object__state'] = ObjectState()
+  vars_obj['_object__nodes'] = cls._object__nodes
   cls._object_meta_construct(node, *args, **kwargs)
+  # register possible new data attributes after initialization
+  for name, value in vars_obj.items():
+    if name not in vars_obj['_object__nodes']:
+      if any(
+        is_data_type(leaf)
+        for leaf in jax.tree.leaves(value, is_leaf=is_data_type)
+      ):
+        vars_obj['_object__nodes'] = vars_obj['_object__nodes'].union((name,))

   return node
@@ -312,6 +324,7 @@ def __nnx_repr__(self):
     yield reprlib.Attr('shape', self.shape)
     yield reprlib.Attr('dtype', self.dtype)

+
 @dataclasses.dataclass(frozen=True, repr=False)
 class MutableArrayRepr(reprlib.Representable):
   shape: tp.Tuple[int, ...]
@@ -326,6 +339,7 @@ def __nnx_repr__(self):
     yield reprlib.Attr('shape', self.shape)
     yield reprlib.Attr('dtype', self.dtype)

+
 class Object(reprlib.Representable, metaclass=ObjectMeta):
   """Base class for all NNX objects."""

@@ -335,7 +349,7 @@ class Object(reprlib.Representable, metaclass=ObjectMeta):
   _object__state: ObjectState

   def __init_subclass__(
-    cls, *, pytree: bool = config.flax_mutable_array, **kwargs
+    cls, *, pytree: bool = config.flax_pytree_module, **kwargs
   ) -> None:
     super().__init_subclass__(**kwargs)

@@ -387,20 +401,12 @@ def _setattr(self, name: str, value: tp.Any) -> None:
       value = value.value
       if name not in self._object__nodes:
         self._object__nodes = self._object__nodes.union((name,))
-    elif is_data_type(value):
-      if name not in self._object__nodes:
-        self._object__nodes = self._object__nodes.union((name,))
-    elif type(self)._object__is_pytree and name not in self._object__nodes:
-      for leaf in jax.tree.leaves(value):
-        if isinstance(leaf, jax.Array) or is_mutable_array(leaf):
-          raise TypeError(
-            f"Trying to set '{name}' to a value containing one or more "
-            f"jax.Array, but '{name}' is not a registered as data. "
-            f"Use 'obj.{name} = nnx.data(...)' to register the attribute as data "
-            f"on assignment, or add '{name}: nnx.Data[{type(value).__name__}]' "
-            f'to the class definition. '
-            f'Got value: {value}'
-          )
+    # any attribute that contains known data types will be registered as data
+    elif name not in self._object__nodes and any(
+      is_data_type(leaf)
+      for leaf in jax.tree.leaves(value, is_leaf=is_data_type)
+    ):
+      self._object__nodes = self._object__nodes.union((name,))
     object.__setattr__(self, name, value)

   def _check_valid_context(self, error_msg: tp.Callable[[], str]) -> None:
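
This rewrite replaces the old TypeError with automatic registration: any assigned attribute whose leaves contain known data types is promoted to pytree data. That is what lets the demo above drop its nnx.data(...) wrappers. A sketch of what now works without explicit registration:

from flax import nnx

class MLP(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    # a plain list of sub-modules is detected and registered as data,
    # no nnx.data(...) wrapper needed
    self.layers = [nnx.Linear(4, 4, rngs=rngs) for _ in range(3)]
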
2 changes: 1 addition & 1 deletion flax/nnx/transforms/iteration.py
@@ -652,7 +652,7 @@ def check_carry_same_references(key_path, arg, out):
     )

   jax.tree_util.tree_map_with_path(
-    check_carry_same_references, carry_arg, carry_arg_out
+    check_carry_same_references, carry_arg, carry_arg_out, is_leaf=graph.is_graph_node
   )

   def _extract_graphdefs(
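
Adding is_leaf=graph.is_graph_node makes the scan carry check compare graph nodes (e.g. Modules) as single units, which is what a reference-identity check needs. The tree_map_with_path plus is_leaf mechanic in isolation:

import jax

tree = {'w': [1.0, 2.0], 'b': [0.5]}
labeled = jax.tree_util.tree_map_with_path(
  lambda path, x: (jax.tree_util.keystr(path), x),
  tree,
  is_leaf=lambda x: isinstance(x, list),  # whole lists are treated as leaves
)
# labeled == {'w': ("['w']", [1.0, 2.0]), 'b': ("['b']", [0.5])}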