Merged
99 changes: 99 additions & 0 deletions examples/nnx_toy_examples/07_array_leaves.py
@@ -0,0 +1,99 @@
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# %%
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax

from flax import nnx, struct

X = np.linspace(0, 1, 100)[:, None]
Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape)


def dataset(batch_size):
while True:
idx = np.random.choice(len(X), size=batch_size)
yield X[idx], Y[idx]

class Linear(nnx.Module):
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
self.w = jax.random.normal(rngs.params(), (din, dout))
self.b = jnp.zeros((dout,))

def __call__(self, x):
return x @ self.w + self.b


class MLP(nnx.Module):
def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
self.count = jnp.array(0)
self.linear1 = Linear(din, dhidden, rngs=rngs)
self.linear2 = Linear(dhidden, dout, rngs=rngs)

def __call__(self, x):
self.count += 1
return self.linear2(nnx.relu(self.linear1(x)))

def is_param(path, value):
key = path[-1]
return key == 'w' or key == 'b'

model = MLP(din=1, dhidden=32, dout=1, rngs=nnx.Rngs(0))
tx = optax.sgd(1e-3)
optimizer = nnx.Optimizer(model, tx, wrt=is_param)


@nnx.jit
def train_step(model: MLP, optimizer: nnx.Optimizer, batch):
x, y = batch

def loss_fn(model: MLP):
y_pred = model(x)
return jnp.mean((y - y_pred) ** 2)

diff_state = nnx.DiffState(0, is_param)
grads: nnx.State = nnx.grad(loss_fn, argnums=diff_state)(model)
optimizer.update(grads)


@nnx.jit
def test_step(model: MLP, batch):
x, y = batch
y_pred = model(x)
loss = jnp.mean((y - y_pred) ** 2)
return {'loss': loss}


total_steps = 10_000
for step, batch in enumerate(dataset(32)):
train_step(model, optimizer, batch)

if step % 1000 == 0:
logs = test_step(model, (X, Y))
print(f"step: {step}, loss: {logs['loss']}")

if step >= total_steps - 1:
break

print('times called:', model.count)

y_pred = model(X)

plt.scatter(X, Y, color='blue')
plt.plot(X, y_pred, color='black')
plt.show()
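
The toy example above stores count as a plain jax.Array attribute instead of wrapping it in an nnx.Variable; the graph.py changes below are what let such raw array leaves flow through nnx.split, nnx.merge, and the nnx transforms. The snippet below is a minimal sketch of that behavior, mirroring the new tests at the end of this diff; the Counter module is illustrative and not part of this PR.

# Minimal sketch (assumption: mirrors the behavior exercised by the new tests below).
import jax.numpy as jnp
from flax import nnx

class Counter(nnx.Module):  # illustrative module, not part of this PR
  def __init__(self):
    self.count = jnp.array(0)  # raw array leaf, no nnx.Variable wrapper
    self.name = 'counter'      # non-array attribute remains static

counter = Counter()
graphdef, state = nnx.split(counter)   # the raw array ends up in `state`
restored = nnx.merge(graphdef, state)  # and is restored on merge
assert restored.count == 0 and restored.name == 'counter'
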
69 changes: 48 additions & 21 deletions flax/nnx/graph.py
@@ -268,7 +268,7 @@ def __treescope_repr__(self, path, subtree_renderer):
@dataclasses.dataclass(frozen=True, repr=False)
class VariableDef(reprlib.Representable, tp.Generic[Node]):
type: type[Node]
index: int
index: int # TODO(cgarciae): make Optional instead of using -1
outer_index: int | None
metadata: HashableMapping[str, tp.Any]

@@ -320,7 +320,11 @@ class NodeDef(tp.Generic[Node], reprlib.Representable):
attributes: tuple[
tuple[
Key,
NodeDef[tp.Any] | VariableDef[tp.Any] | NodeRef[tp.Any] | Static[tp.Any],
NodeDef[tp.Any]
| VariableDef[tp.Any]
| NodeRef[tp.Any]
| Static[tp.Any]
| ArrayAttr,
],
...,
]
@@ -387,6 +391,7 @@ def __treescope_repr__(self, path, subtree_renderer):
subtree_renderer=subtree_renderer,
)

# TODO(cgarciae): remove this method
def apply(
self, state: GraphState, *states: GraphState
) -> ApplyCaller[tuple[GraphDef[Node], GraphState]]:
@@ -407,10 +412,16 @@ def _apply(

jax.tree_util.register_static(NodeDef)

@dataclasses.dataclass(frozen=True, slots=True)
class ArrayAttr:
pass


ARRAY_ATTR = ArrayAttr()

GraphDef = tp.Union[NodeDef[Node], NodeRef[Node], VariableDef[Node]]
PureState = tuple[GraphDef[Node], GraphState]


@tp.overload
def flatten(
node: Node,
@@ -494,7 +505,7 @@ def flatten(
if ref_index is None:
ref_index = RefMap()

leaves: list[StateLeaf | Variable[tp.Any]] = []
leaves: list[StateLeaf | Variable[tp.Any] | jax.Array | np.ndarray] = []
path: list[Key] | None = [] if with_paths else None
paths: list[PathParts] | None = [] if with_paths else None
node_impl = get_node_impl(node)
@@ -523,7 +534,7 @@ def _graph_flatten(
path: list[Key] | None,
ref_index: RefMap,
ref_outer_index: RefMap | None,
leaves: list[StateLeaf | Variable[tp.Any]],
leaves: list[StateLeaf | Variable[tp.Any] | jax.Array | np.ndarray],
paths: list[PathParts] | None,
return_variables: bool,
) -> NodeDef | NodeRef | VariableDef:
@@ -539,6 +550,7 @@
index = len(ref_index)
ref_index[node] = index
else:
# TODO(cgarciae): use None instead of -1
index = -1

if is_variable:
@@ -565,7 +577,14 @@
raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')

attributes: list[
tuple[Key, Static[tp.Any] | NodeDef[tp.Any] | VariableDef | NodeRef[tp.Any]]
tuple[
Key,
Static[tp.Any]
| ArrayAttr
| NodeDef[tp.Any]
| VariableDef
| NodeRef[tp.Any],
]
] = []

values, metadata = node_impl.flatten(node)
@@ -585,16 +604,12 @@
return_variables,
)
attributes.append((key, nodedef))
elif isinstance(value, (jax.Array, np.ndarray)):
if paths is not None:
paths.append(tuple(path)) # type: ignore
attributes.append((key, ARRAY_ATTR))
leaves.append(value)
else:
if isinstance(value, (jax.Array, np.ndarray)):
if path is not None:
path_str = '/'.join(map(str, path))
raise ValueError(
f'Arrays leaves are not supported, at {path_str!r}: {value}'
)
else:
raise ValueError(f'Arrays leaves are not supported, found {value}')
# static_fields.append((key, value))
attributes.append((key, Static(value)))

if path is not None:
@@ -695,9 +710,7 @@ def _graph_fingerprint(
append_fn(variable_index)
for key_value in value._var_metadata.items():
append_fn(key_value)
else:
if isinstance(value, (jax.Array, np.ndarray)):
raise ValueError(f'Arrays leaves are not supported: {value}')
elif not isinstance(value, (jax.Array, np.ndarray)):
append_fn(value)


@@ -961,6 +974,11 @@ def _get_children() -> list[tuple[Key, tp.Any]]:
for key, value in nodedef.attributes:
if type(value) is Static:
children.append((key, value.value))
elif type(value) is ArrayAttr:
if not leaves:
raise ValueError('Not enough leaves to unflatten the graph')
array = leaves.popleft()
children.append((key, array))
elif type(value) is NodeRef:
children.append((key, index_ref[value.index]))
elif type(value) is NodeDef:
@@ -1126,8 +1144,16 @@ def _update_variable(node: Variable, value):
raise ValueError(f'Expected a subgraph for {key!r}, but got: {value!r}')
_graph_update_dynamic(current_value, value)
else:
# case 3: state leaf is being updated
if not isinstance(current_value, Variable):
if isinstance(current_value, jax.Array | np.ndarray):
if isinstance(node_impl, PytreeNodeImpl):
raise ValueError(
f'Cannot set key {key!r} on immutable node of '
f'type {type(node).__name__}'
)
node_impl.set_key(node, key, value)
continue
elif not isinstance(current_value, Variable):
# case 3: state leaf is being updated
raise ValueError(
f'Trying to update a non-Variable attribute {key!r} with a Variable: '
f'{value!r}'
@@ -1255,7 +1281,8 @@ def _cached_partial(f: tp.Callable[..., tp.Any], *cached_args):
cached_ref_index: RefMap = RefMap()

def create_static_cache(x):
if is_graph_node(x):
# TODO(cgarciae): support Array attribute updates for graph nodes
if is_graph_node(x) or isinstance(x, Variable):
graphdef, flat_state = flatten(
x, with_paths=True, return_variables=True, ref_index=original_ref_index
)
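
In summary, the graph.py change records an ARRAY_ATTR marker in the static structure wherever an attribute holds a bare jax.Array or np.ndarray, appends the array itself to the flat leaves, and pops the leaves back in order during unflatten. The following self-contained sketch illustrates the idea with plain dictionaries; it is a simplification for explanation, not the actual flax implementation.

# Simplified sketch of the ArrayAttr mechanism (an illustration, not the real flax code).
from collections import deque
import numpy as np

ARRAY_ATTR = object()  # sentinel playing the role of flax's ArrayAttr

def flatten_attrs(attrs):
  structure, leaves = [], []
  for key, value in attrs.items():
    if isinstance(value, np.ndarray):
      structure.append((key, ARRAY_ATTR))  # marker goes into the static structure
      leaves.append(value)                 # the array becomes a dynamic leaf
    else:
      structure.append((key, value))       # everything else stays static
  return structure, leaves

def unflatten_attrs(structure, leaves):
  leaves = deque(leaves)  # leaves are consumed in the order they were appended
  return {
    key: leaves.popleft() if value is ARRAY_ATTR else value
    for key, value in structure
  }

structure, leaves = flatten_attrs({'a': np.array(1), 'b': 'yes'})
restored = unflatten_attrs(structure, leaves)
assert restored['a'] == 1 and restored['b'] == 'yes'
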
36 changes: 36 additions & 0 deletions tests/nnx/graph_utils_test.py
@@ -1011,6 +1011,42 @@ def stateful_linear(w, b, count, x):
self.assertEqual(count.value, 2)
self.assertEqual(y.shape, (1, 3))

def test_array_attributes(self):
class Foo(nnx.Module):
def __init__(self):
self.a = jnp.array(1)
self.b = 'yes'

m = Foo()

graphdef, state = nnx.split(m)

self.assertLen(state, 1)
self.assertIsInstance(state['a'], jax.Array)

m2 = nnx.merge(graphdef, state)

self.assertIsInstance(m2.a, jax.Array)
self.assertEqual(m2.a, 1)
self.assertEqual(m2.b, 'yes')

def test_transform_array_attributes(self):
class Foo(nnx.Module):
def __init__(self):
self.a = jnp.array(1)
self.b = 'yes'

m = Foo()

@nnx.jit
def f(m):
m.a += 1
self.assertEqual(m.b, 'yes')

f(m)

self.assertEqual(m.a, 2)


class SimpleModule(nnx.Module):
pass