Commit 097fffc

Flax Authors committed
Merge pull request #4612 from google:nnx-array-leaves
PiperOrigin-RevId: 740103018
2 parents ba59c33 + dfd1788 commit 097fffc
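
In short, this merge teaches NNX's graph flatten/unflatten machinery to accept raw `jax.Array` / `np.ndarray` module attributes as state leaves instead of rejecting them. A minimal sketch of the resulting behavior, mirroring the new tests at the bottom of this diff and using the public `nnx.split` / `nnx.merge` API (the `Counter` module here is a hypothetical example, not part of the commit):

import jax
import jax.numpy as jnp
from flax import nnx

class Counter(nnx.Module):  # hypothetical example module
  def __init__(self):
    self.count = jnp.array(0)  # plain array attribute, no nnx.Variable wrapper

graphdef, state = nnx.split(Counter())
assert isinstance(state['count'], jax.Array)  # previously raised "Arrays leaves are not supported"
counter = nnx.merge(graphdef, state)          # round-trips back to a jax.Array attribute
assert isinstance(counter.count, jax.Array)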

3 files changed, +183 −21 lines
Lines changed: 99 additions & 0 deletions

@@ -0,0 +1,99 @@
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# %%
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax

from flax import nnx, struct

X = np.linspace(0, 1, 100)[:, None]
Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape)


def dataset(batch_size):
  while True:
    idx = np.random.choice(len(X), size=batch_size)
    yield X[idx], Y[idx]

class Linear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    self.w = jax.random.normal(rngs.params(), (din, dout))
    self.b = jnp.zeros((dout,))

  def __call__(self, x):
    return x @ self.w + self.b


class MLP(nnx.Module):
  def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
    self.count = jnp.array(0)
    self.linear1 = Linear(din, dhidden, rngs=rngs)
    self.linear2 = Linear(dhidden, dout, rngs=rngs)

  def __call__(self, x):
    self.count += 1
    return self.linear2(nnx.relu(self.linear1(x)))

def is_param(path, value):
  key = path[-1]
  return key == 'w' or key == 'b'

model = MLP(din=1, dhidden=32, dout=1, rngs=nnx.Rngs(0))
tx = optax.sgd(1e-3)
optimizer = nnx.Optimizer(model, tx, wrt=is_param)


@nnx.jit
def train_step(model: MLP, optimizer: nnx.Optimizer, batch):
  x, y = batch

  def loss_fn(model: MLP):
    y_pred = model(x)
    return jnp.mean((y - y_pred) ** 2)

  diff_state = nnx.DiffState(0, is_param)
  grads: nnx.State = nnx.grad(loss_fn, argnums=diff_state)(model)
  optimizer.update(grads)


@nnx.jit
def test_step(model: MLP, batch):
  x, y = batch
  y_pred = model(x)
  loss = jnp.mean((y - y_pred) ** 2)
  return {'loss': loss}


total_steps = 10_000
for step, batch in enumerate(dataset(32)):
  train_step(model, optimizer, batch)

  if step % 1000 == 0:
    logs = test_step(model, (X, Y))
    print(f"step: {step}, loss: {logs['loss']}")

  if step >= total_steps - 1:
    break

print('times called:', model.count)

y_pred = model(X)

plt.scatter(X, Y, color='blue')
plt.plot(X, y_pred, color='black')
plt.show()
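
The `count` attribute above is exactly the kind of plain-array state this commit enables: a raw `jnp.array(0)` mutated under `nnx.jit`, with no `nnx.Variable` wrapper. A quick illustrative check, assuming the script has run (this snippet is not part of the commit):

graphdef, state = nnx.split(model)
print('count' in state)  # expected True: `count` travels as an ordinary array leaf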

flax/nnx/graph.py

Lines changed: 48 additions & 21 deletions

@@ -268,7 +268,7 @@ def __treescope_repr__(self, path, subtree_renderer):
 @dataclasses.dataclass(frozen=True, repr=False)
 class VariableDef(reprlib.Representable, tp.Generic[Node]):
   type: type[Node]
-  index: int
+  index: int  # TODO(cgarciae): make Optional instead of using -1
   outer_index: int | None
   metadata: HashableMapping[str, tp.Any]

@@ -320,7 +320,11 @@ class NodeDef(tp.Generic[Node], reprlib.Representable):
   attributes: tuple[
     tuple[
       Key,
-      NodeDef[tp.Any] | VariableDef[tp.Any] | NodeRef[tp.Any] | Static[tp.Any],
+      NodeDef[tp.Any]
+      | VariableDef[tp.Any]
+      | NodeRef[tp.Any]
+      | Static[tp.Any]
+      | ArrayAttr,
     ],
     ...,
   ]

@@ -387,6 +391,7 @@ def __treescope_repr__(self, path, subtree_renderer):
       subtree_renderer=subtree_renderer,
     )

+  # TODO(cgarciae): remove this method
   def apply(
     self, state: GraphState, *states: GraphState
   ) -> ApplyCaller[tuple[GraphDef[Node], GraphState]]:

@@ -407,10 +412,16 @@ def _apply(

 jax.tree_util.register_static(NodeDef)

+@dataclasses.dataclass(frozen=True, slots=True)
+class ArrayAttr:
+  pass
+
+
+ARRAY_ATTR = ArrayAttr()
+
 GraphDef = tp.Union[NodeDef[Node], NodeRef[Node], VariableDef[Node]]
 PureState = tuple[GraphDef[Node], GraphState]

-
 @tp.overload
 def flatten(
   node: Node,

@@ -494,7 +505,7 @@ def flatten(
   if ref_index is None:
     ref_index = RefMap()

-  leaves: list[StateLeaf | Variable[tp.Any]] = []
+  leaves: list[StateLeaf | Variable[tp.Any] | jax.Array | np.ndarray] = []
   path: list[Key] | None = [] if with_paths else None
   paths: list[PathParts] | None = [] if with_paths else None
   node_impl = get_node_impl(node)

@@ -523,7 +534,7 @@ def _graph_flatten(
   path: list[Key] | None,
   ref_index: RefMap,
   ref_outer_index: RefMap | None,
-  leaves: list[StateLeaf | Variable[tp.Any]],
+  leaves: list[StateLeaf | Variable[tp.Any] | jax.Array | np.ndarray],
   paths: list[PathParts] | None,
   return_variables: bool,
 ) -> NodeDef | NodeRef | VariableDef:

@@ -539,6 +550,7 @@
     index = len(ref_index)
     ref_index[node] = index
   else:
+    # TODO(cgarciae): use None instead of -1
     index = -1

   if is_variable:

@@ -565,7 +577,14 @@
     raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')

   attributes: list[
-    tuple[Key, Static[tp.Any] | NodeDef[tp.Any] | VariableDef | NodeRef[tp.Any]]
+    tuple[
+      Key,
+      Static[tp.Any]
+      | ArrayAttr
+      | NodeDef[tp.Any]
+      | VariableDef
+      | NodeRef[tp.Any],
+    ]
   ] = []

   values, metadata = node_impl.flatten(node)

@@ -585,16 +604,12 @@
         return_variables,
       )
       attributes.append((key, nodedef))
+    elif isinstance(value, (jax.Array, np.ndarray)):
+      if paths is not None:
+        paths.append(tuple(path))  # type: ignore
+      attributes.append((key, ARRAY_ATTR))
+      leaves.append(value)
     else:
-      if isinstance(value, (jax.Array, np.ndarray)):
-        if path is not None:
-          path_str = '/'.join(map(str, path))
-          raise ValueError(
-            f'Arrays leaves are not supported, at {path_str!r}: {value}'
-          )
-        else:
-          raise ValueError(f'Arrays leaves are not supported, found {value}')
-      # static_fields.append((key, value))
       attributes.append((key, Static(value)))

   if path is not None:

@@ -695,9 +710,7 @@ def _graph_fingerprint(
         append_fn(variable_index)
         for key_value in value._var_metadata.items():
           append_fn(key_value)
-      else:
-        if isinstance(value, (jax.Array, np.ndarray)):
-          raise ValueError(f'Arrays leaves are not supported: {value}')
+      elif not isinstance(value, (jax.Array, np.ndarray)):
         append_fn(value)

@@ -961,6 +974,11 @@ def _get_children() -> list[tuple[Key, tp.Any]]:
   for key, value in nodedef.attributes:
     if type(value) is Static:
       children.append((key, value.value))
+    elif type(value) is ArrayAttr:
+      if not leaves:
+        raise ValueError('Not enough leaves to unflatten the graph')
+      array = leaves.popleft()
+      children.append((key, array))
     elif type(value) is NodeRef:
       children.append((key, index_ref[value.index]))
     elif type(value) is NodeDef:

@@ -1126,8 +1144,16 @@ def _update_variable(node: Variable, value):
        raise ValueError(f'Expected a subgraph for {key!r}, but got: {value!r}')
      _graph_update_dynamic(current_value, value)
    else:
-      # case 3: state leaf is being updated
-      if not isinstance(current_value, Variable):
+      if isinstance(current_value, jax.Array | np.ndarray):
+        if isinstance(node_impl, PytreeNodeImpl):
+          raise ValueError(
+            f'Cannot set key {key!r} on immutable node of '
+            f'type {type(node).__name__}'
+          )
+        node_impl.set_key(node, key, value)
+        continue
+      elif not isinstance(current_value, Variable):
+        # case 3: state leaf is being updated
        raise ValueError(
          f'Trying to update a non-Variable attribute {key!r} with a Variable: '
          f'{value!r}'

@@ -1255,7 +1281,8 @@ def _cached_partial(f: tp.Callable[..., tp.Any], *cached_args):
   cached_ref_index: RefMap = RefMap()

   def create_static_cache(x):
-    if is_graph_node(x):
+    # TODO(cgarciae): support Array attribute updates for graph nodes
+    if is_graph_node(x) or isinstance(x, Variable):
       graphdef, flat_state = flatten(
         x, with_paths=True, return_variables=True, ref_index=original_ref_index
       )
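
The core of the change is the `ArrayAttr` sentinel: `flatten` records `ARRAY_ATTR` in the graphdef's attributes and pushes the raw array onto `leaves`, and unflatten pops leaves back off in the same order. A self-contained toy sketch of this bookkeeping pattern (plain Python, not the Flax internals; only the names mirror the diff):

from collections import deque
import dataclasses

@dataclasses.dataclass(frozen=True)
class ArrayAttr:  # sentinel: marks an attribute whose value lives in `leaves`
  pass

ARRAY_ATTR = ArrayAttr()

def toy_flatten(obj):
  attributes, leaves = [], []
  for key, value in obj.items():
    if isinstance(value, (int, float)):  # stand-in for jax.Array / np.ndarray
      attributes.append((key, ARRAY_ATTR))  # the graphdef records only a marker
      leaves.append(value)                  # the value itself travels as a leaf
    else:
      attributes.append((key, ('static', value)))  # static data stays in the graphdef
  return attributes, leaves

def toy_unflatten(attributes, leaves):
  obj = {}
  for key, value in attributes:
    if isinstance(value, ArrayAttr):
      if not leaves:
        raise ValueError('Not enough leaves to unflatten the graph')
      obj[key] = leaves.popleft()  # leaves are consumed in flattening order
    else:
      obj[key] = value[1]
  return obj

attrs, leaves = toy_flatten({'a': 1.0, 'b': 'yes'})
assert toy_unflatten(attrs, deque(leaves)) == {'a': 1.0, 'b': 'yes'}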

tests/nnx/graph_utils_test.py

Lines changed: 36 additions & 0 deletions

@@ -1011,6 +1011,42 @@ def stateful_linear(w, b, count, x):
     self.assertEqual(count.value, 2)
     self.assertEqual(y.shape, (1, 3))

+  def test_array_attributes(self):
+    class Foo(nnx.Module):
+      def __init__(self):
+        self.a = jnp.array(1)
+        self.b = 'yes'
+
+    m = Foo()
+
+    graphdef, state = nnx.split(m)
+
+    self.assertLen(state, 1)
+    self.assertIsInstance(state['a'], jax.Array)
+
+    m2 = nnx.merge(graphdef, state)
+
+    self.assertIsInstance(m2.a, jax.Array)
+    self.assertEqual(m2.a, 1)
+    self.assertEqual(m2.b, 'yes')
+
+  def test_transform_array_attributes(self):
+    class Foo(nnx.Module):
+      def __init__(self):
+        self.a = jnp.array(1)
+        self.b = 'yes'
+
+    m = Foo()
+
+    @nnx.jit
+    def f(m):
+      m.a += 1
+      self.assertEqual(m.b, 'yes')
+
+    f(m)
+
+    self.assertEqual(m.a, 2)
+

 class SimpleModule(nnx.Module):
   pass
