Commit 18d3750

[nnx] support Array leaves in graph nodes
1 parent 2894b94 commit 18d3750
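For context: this change lets graph nodes hold plain jax.Array / np.ndarray attributes that are not wrapped in nnx.Variable, and lets updates made to them inside transforms propagate back to the original object. A minimal sketch of the enabled pattern (hypothetical Counter/tick names, mirroring the MLP.count attribute in the example file added below):

import jax.numpy as jnp
from flax import nnx

class Counter(nnx.Module):
  def __init__(self):
    # a raw Array leaf, stored directly rather than wrapped in nnx.Variable
    self.count = jnp.array(0)

  def __call__(self):
    self.count += 1

@nnx.jit
def tick(counter: Counter):
  counter()

counter = Counter()
tick(counter)
print(counter.count)  # 1: the update made inside nnx.jit propagates back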

File tree

2 files changed: +112 −11

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# %%
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax

from flax import nnx, struct

X = np.linspace(0, 1, 100)[:, None]
Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape)


def dataset(batch_size):
  while True:
    idx = np.random.choice(len(X), size=batch_size)
    yield X[idx], Y[idx]

class Linear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    self.w = jax.random.normal(rngs.params(), (din, dout))
    self.b = jnp.zeros((dout,))

  def __call__(self, x):
    return x @ self.w + self.b


class MLP(nnx.Module):
  def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
    self.count = jnp.array(0)
    self.linear1 = Linear(din, dhidden, rngs=rngs)
    self.linear2 = Linear(dhidden, dout, rngs=rngs)

  def __call__(self, x):
    self.count += 1
    return self.linear2(nnx.relu(self.linear1(x)))

def is_param(path, value):
  key = path[-1]
  return key == 'w' or key == 'b'

model = MLP(din=1, dhidden=32, dout=1, rngs=nnx.Rngs(0))
tx = optax.sgd(1e-3)
optimizer = nnx.Optimizer(model, tx, wrt=is_param)


@nnx.jit
def train_step(model: MLP, optimizer: nnx.Optimizer, batch):
  x, y = batch

  def loss_fn(model: MLP):
    y_pred = model(x)
    return jnp.mean((y - y_pred) ** 2)

  diff_state = nnx.DiffState(0, is_param)
  grads: nnx.State = nnx.grad(loss_fn, argnums=diff_state)(model)
  optimizer.update(grads)


@nnx.jit
def test_step(model: MLP, batch):
  x, y = batch
  y_pred = model(x)
  loss = jnp.mean((y - y_pred) ** 2)
  return {'loss': loss}


total_steps = 10_000
for step, batch in enumerate(dataset(32)):
  train_step(model, optimizer, batch)

  if step % 1000 == 0:
    logs = test_step(model, (X, Y))
    print(f"step: {step}, loss: {logs['loss']}")

  if step >= total_steps - 1:
    break

print('times called:', model.count)

y_pred = model(X)

plt.scatter(X, Y, color='blue')
plt.plot(X, y_pred, color='black')
plt.show()
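Note the self.count = jnp.array(0) attribute in MLP above: it is a plain jax.Array leaf rather than an nnx.Variable, which is exactly the case this commit adds support for, and the final 'times called' printout checks that the increments made under nnx.jit were written back to the model. Because the parameters w and b are likewise raw arrays, the example selects trainable state with the path-based is_param filter (via nnx.Optimizer(..., wrt=is_param) and nnx.DiffState) instead of the usual nnx.Param filter.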

flax/nnx/graph.py

Lines changed: 13 additions & 11 deletions
@@ -710,9 +710,7 @@ def _graph_fingerprint(
           append_fn(variable_index)
           for key_value in value._var_metadata.items():
             append_fn(key_value)
-      else:
-        if isinstance(value, (jax.Array, np.ndarray)):
-          raise ValueError(f'Arrays leaves are not supported: {value}')
+      elif not isinstance(value, (jax.Array, np.ndarray)):
         append_fn(value)


@@ -1146,8 +1144,16 @@ def _update_variable(node: Variable, value):
         raise ValueError(f'Expected a subgraph for {key!r}, but got: {value!r}')
       _graph_update_dynamic(current_value, value)
     else:
-      # case 3: state leaf is being updated
-      if not isinstance(current_value, Variable):
+      if isinstance(current_value, jax.Array | np.ndarray):
+        if isinstance(node_impl, PytreeNodeImpl):
+          raise ValueError(
+            f'Cannot set key {key!r} on immutable node of '
+            f'type {type(node).__name__}'
+          )
+        node_impl.set_key(node, key, value)
+        continue
+      elif not isinstance(current_value, Variable):
+        # case 3: state leaf is being updated
         raise ValueError(
           f'Trying to update a non-Variable attribute {key!r} with a Variable: '
           f'{value!r}'
@@ -1275,7 +1281,8 @@ def _cached_partial(f: tp.Callable[..., tp.Any], *cached_args):
   cached_ref_index: RefMap = RefMap()

   def create_static_cache(x):
-    if is_graph_node(x):
+    # TODO(cgarciae): support Array attribute updates for graph nodes
+    if is_graph_node(x) or isinstance(x, Variable):
       graphdef, flat_state = flatten(
         x, with_paths=True, return_variables=True, ref_index=original_ref_index
       )
@@ -1284,11 +1291,6 @@ def create_static_cache(x):
       # clone but keep the same variable references
       node_cache = unflatten(graphdef, flat_state, index_ref=index_ref)
       cached_new_ref_index = RefMap()
-      _fp = fingerprint(
-        node_cache,
-        ref_index=cached_ref_index,
-        new_ref_index=cached_new_ref_index,
-      )
       cached_ref_index.update(cached_new_ref_index)
       cache[node_cache] = StaticCache.create(
         graphdef, paths, variables, cached_new_ref_index
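The _graph_update_dynamic hunk above is the piece that lets state updates write raw Array leaves back onto mutable nodes, while immutable pytree nodes (PytreeNodeImpl) now fail with an explicit error. A hypothetical sketch of that path through the public API, assuming nnx.split / nnx.update carry raw Array leaves in the state:

import jax.numpy as jnp
from flax import nnx

class Counter(nnx.Module):
  def __init__(self):
    self.count = jnp.array(0)  # raw Array leaf, not an nnx.Variable

counter = Counter()
graphdef, state = nnx.split(counter)
state['count'] = jnp.array(42)  # swap the Array leaf in the state
nnx.update(counter, state)      # before this commit, updating a non-Variable leaf raised
print(counter.count)            # 42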
