diff --git a/examples/nnx_toy_examples/07_array_leaves.py b/examples/nnx_toy_examples/07_array_leaves.py
new file mode 100644
index 000000000..20b6a7d9e
--- /dev/null
+++ b/examples/nnx_toy_examples/07_array_leaves.py
@@ -0,0 +1,99 @@
+# Copyright 2024 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# %%
+import jax
+import jax.numpy as jnp
+import matplotlib.pyplot as plt
+import numpy as np
+import optax
+
+from flax import nnx, struct
+
+X = np.linspace(0, 1, 100)[:, None]
+Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape)
+
+
+def dataset(batch_size):
+  while True:
+    idx = np.random.choice(len(X), size=batch_size)
+    yield X[idx], Y[idx]
+
+class Linear(nnx.Module):
+  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
+    self.w = jax.random.normal(rngs.params(), (din, dout))
+    self.b = jnp.zeros((dout,))
+
+  def __call__(self, x):
+    return x @ self.w + self.b
+
+
+class MLP(nnx.Module):
+  def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
+    self.count = jnp.array(0)
+    self.linear1 = Linear(din, dhidden, rngs=rngs)
+    self.linear2 = Linear(dhidden, dout, rngs=rngs)
+
+  def __call__(self, x):
+    self.count += 1
+    return self.linear2(nnx.relu(self.linear1(x)))
+
+def is_param(path, value):
+  key = path[-1]
+  return key == 'w' or key == 'b'
+
+model = MLP(din=1, dhidden=32, dout=1, rngs=nnx.Rngs(0))
+tx = optax.sgd(1e-3)
+optimizer = nnx.Optimizer(model, tx, wrt=is_param)
+
+
+@nnx.jit
+def train_step(model: MLP, optimizer: nnx.Optimizer, batch):
+  x, y = batch
+
+  def loss_fn(model: MLP):
+    y_pred = model(x)
+    return jnp.mean((y - y_pred) ** 2)
+
+  diff_state = nnx.DiffState(0, is_param)
+  grads: nnx.State = nnx.grad(loss_fn, argnums=diff_state)(model)
+  optimizer.update(grads)
+
+
+@nnx.jit
+def test_step(model: MLP, batch):
+  x, y = batch
+  y_pred = model(x)
+  loss = jnp.mean((y - y_pred) ** 2)
+  return {'loss': loss}
+
+
+total_steps = 10_000
+for step, batch in enumerate(dataset(32)):
+  train_step(model, optimizer, batch)
+
+  if step % 1000 == 0:
+    logs = test_step(model, (X, Y))
+    print(f"step: {step}, loss: {logs['loss']}")
+
+  if step >= total_steps - 1:
+    break
+
+print('times called:', model.count)
+
+y_pred = model(X)
+
+plt.scatter(X, Y, color='blue')
+plt.plot(X, y_pred, color='black')
+plt.show()
diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py
index c2ce79695..b4484d730 100644
--- a/flax/nnx/graph.py
+++ b/flax/nnx/graph.py
@@ -268,7 +268,7 @@ def __treescope_repr__(self, path, subtree_renderer):
 @dataclasses.dataclass(frozen=True, repr=False)
 class VariableDef(reprlib.Representable, tp.Generic[Node]):
   type: type[Node]
-  index: int
+  index: int  # TODO(cgarciae): make Optional instead of using -1
   outer_index: int | None
   metadata: HashableMapping[str, tp.Any]
 
@@ -320,7 +320,11 @@ class NodeDef(tp.Generic[Node], reprlib.Representable):
   attributes: tuple[
     tuple[
       Key,
-      NodeDef[tp.Any] | VariableDef[tp.Any] | NodeRef[tp.Any] | Static[tp.Any],
+      NodeDef[tp.Any]
+      | VariableDef[tp.Any]
+      | NodeRef[tp.Any]
+      | Static[tp.Any]
+      | ArrayAttr,
     ],
     ...,
   ]
@@ -387,6 +391,7 @@ def __treescope_repr__(self, path, subtree_renderer):
       subtree_renderer=subtree_renderer,
     )
 
+  # TODO(cgarciae): remove this method
   def apply(
     self, state: GraphState, *states: GraphState
   ) -> ApplyCaller[tuple[GraphDef[Node], GraphState]]:
@@ -407,10 +412,16 @@ def _apply(
 jax.tree_util.register_static(NodeDef)
 
 
+@dataclasses.dataclass(frozen=True, slots=True)
+class ArrayAttr:
+  pass
+
+
+ARRAY_ATTR = ArrayAttr()
+
 GraphDef = tp.Union[NodeDef[Node], NodeRef[Node], VariableDef[Node]]
 PureState = tuple[GraphDef[Node], GraphState]
-
 
 @tp.overload
 def flatten(
   node: Node,
@@ -494,7 +505,7 @@ def flatten(
   if ref_index is None:
     ref_index = RefMap()
 
-  leaves: list[StateLeaf | Variable[tp.Any]] = []
+  leaves: list[StateLeaf | Variable[tp.Any] | jax.Array | np.ndarray] = []
   path: list[Key] | None = [] if with_paths else None
   paths: list[PathParts] | None = [] if with_paths else None
   node_impl = get_node_impl(node)
@@ -523,7 +534,7 @@ def _graph_flatten(
   path: list[Key] | None,
   ref_index: RefMap,
   ref_outer_index: RefMap | None,
-  leaves: list[StateLeaf | Variable[tp.Any]],
+  leaves: list[StateLeaf | Variable[tp.Any] | jax.Array | np.ndarray],
   paths: list[PathParts] | None,
   return_variables: bool,
 ) -> NodeDef | NodeRef | VariableDef:
@@ -539,6 +550,7 @@
     index = len(ref_index)
     ref_index[node] = index
   else:
+    # TODO(cgarciae): use None instead of -1
    index = -1
 
   if is_variable:
@@ -565,7 +577,14 @@
     raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')
 
   attributes: list[
-    tuple[Key, Static[tp.Any] | NodeDef[tp.Any] | VariableDef | NodeRef[tp.Any]]
+    tuple[
+      Key,
+      Static[tp.Any]
+      | ArrayAttr
+      | NodeDef[tp.Any]
+      | VariableDef
+      | NodeRef[tp.Any],
+    ]
   ] = []
 
   values, metadata = node_impl.flatten(node)
@@ -585,16 +604,12 @@
         return_variables,
       )
       attributes.append((key, nodedef))
+    elif isinstance(value, (jax.Array, np.ndarray)):
+      if paths is not None:
+        paths.append(tuple(path))  # type: ignore
+      attributes.append((key, ARRAY_ATTR))
+      leaves.append(value)
     else:
-      if isinstance(value, (jax.Array, np.ndarray)):
-        if path is not None:
-          path_str = '/'.join(map(str, path))
-          raise ValueError(
-            f'Arrays leaves are not supported, at {path_str!r}: {value}'
-          )
-        else:
-          raise ValueError(f'Arrays leaves are not supported, found {value}')
-      # static_fields.append((key, value))
       attributes.append((key, Static(value)))
 
     if path is not None:
@@ -695,9 +710,7 @@ def _graph_fingerprint(
         append_fn(variable_index)
         for key_value in value._var_metadata.items():
           append_fn(key_value)
-      else:
-        if isinstance(value, (jax.Array, np.ndarray)):
-          raise ValueError(f'Arrays leaves are not supported: {value}')
+      elif not isinstance(value, (jax.Array, np.ndarray)):
         append_fn(value)
 
 
@@ -961,6 +974,11 @@ def _get_children() -> list[tuple[Key, tp.Any]]:
     for key, value in nodedef.attributes:
       if type(value) is Static:
         children.append((key, value.value))
+      elif type(value) is ArrayAttr:
+        if not leaves:
+          raise ValueError('Not enough leaves to unflatten the graph')
+        array = leaves.popleft()
+        children.append((key, array))
       elif type(value) is NodeRef:
         children.append((key, index_ref[value.index]))
       elif type(value) is NodeDef:
@@ -1126,8 +1144,16 @@ def _update_variable(node: Variable, value):
         raise ValueError(f'Expected a subgraph for {key!r}, but got: {value!r}')
       _graph_update_dynamic(current_value, value)
     else:
-      # case 3: state leaf is being updated
-      if not isinstance(current_value, Variable):
+      if isinstance(current_value, jax.Array | np.ndarray):
+        if isinstance(node_impl, PytreeNodeImpl):
+          raise ValueError(
+            f'Cannot set key {key!r} on immutable node of '
+            f'type {type(node).__name__}'
+          )
+        node_impl.set_key(node, key, value)
+        continue
+      elif not isinstance(current_value, Variable):
+        # case 3: state leaf is being updated
         raise ValueError(
           f'Trying to update a non-Variable attribute {key!r} with a Variable: '
           f'{value!r}'
@@ -1255,7 +1281,8 @@ def _cached_partial(f: tp.Callable[..., tp.Any], *cached_args):
   cached_ref_index: RefMap = RefMap()
 
   def create_static_cache(x):
-    if is_graph_node(x):
+    # TODO(cgarciae): support Array attribute updates for graph nodes
+    if is_graph_node(x) or isinstance(x, Variable):
       graphdef, flat_state = flatten(
         x, with_paths=True, return_variables=True, ref_index=original_ref_index
       )
diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py
index af4c5f53d..739177139 100644
--- a/tests/nnx/graph_utils_test.py
+++ b/tests/nnx/graph_utils_test.py
@@ -1011,6 +1011,42 @@ def stateful_linear(w, b, count, x):
     self.assertEqual(count.value, 2)
     self.assertEqual(y.shape, (1, 3))
 
+  def test_array_attributes(self):
+    class Foo(nnx.Module):
+      def __init__(self):
+        self.a = jnp.array(1)
+        self.b = 'yes'
+
+    m = Foo()
+
+    graphdef, state = nnx.split(m)
+
+    self.assertLen(state, 1)
+    self.assertIsInstance(state['a'], jax.Array)
+
+    m2 = nnx.merge(graphdef, state)
+
+    self.assertIsInstance(m2.a, jax.Array)
+    self.assertEqual(m2.a, 1)
+    self.assertEqual(m2.b, 'yes')
+
+  def test_transform_array_attributes(self):
+    class Foo(nnx.Module):
+      def __init__(self):
+        self.a = jnp.array(1)
+        self.b = 'yes'
+
+    m = Foo()
+
+    @nnx.jit
+    def f(m):
+      m.a += 1
+      self.assertEqual(m.b, 'yes')
+
+    f(m)
+
+    self.assertEqual(m.a, 2)
+
 
 class SimpleModule(nnx.Module):
   pass