Commit f29c0b4: [nnx] don't share Rngs
1 parent: 535fc81

34 files changed: +678 −767

docs_nnx/api_reference/flax.nnx/helpers.rst (1 addition, 4 deletions)

```diff
@@ -4,10 +4,7 @@ helpers
 .. automodule:: flax.nnx
 .. currentmodule:: flax.nnx
 
-.. autoclass:: Dict
-  :members:
-.. autoclass:: List
-  :members:
+
 .. autoclass:: Sequential
   :members:
 .. autoclass:: TrainState
```

docs_nnx/api_reference/flax.nnx/training/optimizer.rst (3 additions, 0 deletions)

```diff
@@ -6,3 +6,6 @@ Optimizer
 
 .. autoclass:: Optimizer
   :members: __init__, update
+
+.. autoclass:: MutableArrayOptimizer
+  :members: __init__, update
```
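For context, a minimal sketch of the documented `nnx.Optimizer` surface; the new `MutableArrayOptimizer` is assumed to mirror the same `__init__`/`update` members listed above. The model, data, and learning rate are illustrative, and exact signatures have shifted across Flax releases:

```python
import jax.numpy as jnp
import optax
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.sgd(1e-2))  # wraps an optax transform

x, y = jnp.ones((4, 2)), jnp.zeros((4, 3))

def loss_fn(model):
  return jnp.mean((model(x) - y) ** 2)

loss, grads = nnx.value_and_grad(loss_fn)(model)
optimizer.update(grads)  # applies the optax update to the params in place
```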

docs_nnx/guides/filters_guide.ipynb (1 addition, 1 deletion)

```diff
@@ -5,7 +5,7 @@
    "id": "95b08e64",
    "metadata": {},
    "source": [
-    "# Using Filters, grouping NNX variables \n",
+    "# Filters\n",
     "\n",
     "Flax NNX uses [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) extensively as a way to create [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) groups in APIs, such as [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), [`nnx.state()`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.state), and many of the [Flax NNX transformations (transforms)](https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html).\n",
     "\n",
```

docs_nnx/guides/filters_guide.md (1 addition, 1 deletion)

```diff
@@ -8,7 +8,7 @@ jupytext:
     jupytext_version: 1.13.8
 ---
 
-# Using Filters, grouping NNX variables
+# Filters
 
 Flax NNX uses [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) extensively as a way to create [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) groups in APIs, such as [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), [`nnx.state()`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.state), and many of the [Flax NNX transformations (transforms)](https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html).
 
```
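To ground the guide's claim about filters creating `nnx.State` groups, a minimal sketch using the public API (the `nnx.Linear` model is illustrative):

```python
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

# Filters (here nnx.Param and the catch-all ...) decide which variables
# land in which State group.
graphdef, params, rest = nnx.split(model, nnx.Param, ...)

# merge reassembles the module from the graph definition and the groups.
model = nnx.merge(graphdef, params, rest)
```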

examples/nnx_toy_examples/06_scan_over_layers.py (0 additions, 1 deletion)

```diff
@@ -47,7 +47,6 @@ def create_block(rngs: nnx.Rngs):
     self.layers = create_block(rngs)
 
   def __call__(self, x: jax.Array) -> jax.Array:
-    @nnx.split_rngs(splits=self.n_layers)
     @nnx.scan
     def scan_fn(x: jax.Array, block: Block):
       x = block(x)
```
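The call site no longer splits the shared `Rngs`, matching the commit's theme. A hedged sketch of where `nnx.split_rngs` typically still appears, at layer creation, following the `create_block` pattern named in the hunk header (the body shown is an assumption, not this file's exact code):

```python
from flax import nnx

n_layers = 5

@nnx.split_rngs(splits=n_layers)        # one independent stream per layer
@nnx.vmap(in_axes=(0,), out_axes=0)     # build the stacked Block states
def create_block(rngs: nnx.Rngs):
  return Block(rngs=rngs)               # Block as defined in this example
```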

examples/nnx_toy_examples/10_fsdp_and_optimizer.py (15 additions, 5 deletions)

```diff
@@ -84,19 +84,29 @@ def init_optimizer_state(variable: nnx.Variable):
 
     self.lr = lr
     self.params = params
-    self.momentum: nnx.State = jax.tree.map(init_optimizer_state, self.params)
+    self.momentum: nnx.State = jax.tree.map(
+      init_optimizer_state,
+      self.params,
+      is_leaf=lambda x: isinstance(x, nnx.Variable | nnx.VariableState),
+    )
     self.decay = decay
 
   def update(self, grads: nnx.State):
     def update_fn(
       params: nnx.Variable, momentum: SGDState, grad: nnx.VariableState
     ):
       # v_t = β * v_{t-1} + (1 - β) * ∇J(θ_t)
-      momentum.value = self.decay * momentum + (1 - self.decay) * grad.value
+      momentum[...] = self.decay * momentum[...] + (1 - self.decay) * grad[...]
       # θ_{t+1} = θ_t - α * v_t
-      params.value -= self.lr * momentum
-
-    jax.tree.map(update_fn, self.params, self.momentum, grads)
+      params[...] -= self.lr * momentum[...]
+
+    jax.tree.map(
+      update_fn,
+      self.params,
+      self.momentum,
+      grads,
+      is_leaf=lambda x: isinstance(x, nnx.Variable | nnx.VariableState),
+    )
 
 
 @nnx.jit
```
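The new `is_leaf` predicate stops `jax.tree.map` from recursing into `nnx.Variable` objects, which appears to be needed once Variables themselves participate as pytree nodes; the `[...]` indexing reads and writes the wrapped array, exactly as the diff above does. A hedged standalone illustration (the toy tree is made up):

```python
import jax
import jax.numpy as jnp
from flax import nnx

params = {'w': nnx.Param(jnp.ones((2, 2))), 'b': nnx.Param(jnp.zeros(2))}

# With is_leaf, the callback receives each whole Variable rather than
# whatever tree.map would find by descending into it; v[...] accesses
# the underlying array.
halved = jax.tree.map(
  lambda v: v[...] * 0.5,
  params,
  is_leaf=lambda x: isinstance(x, nnx.Variable),
)
```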

examples/nnx_toy_examples/mutable_array_demo.py (2 additions, 2 deletions)

```diff
@@ -259,10 +259,10 @@ def update_fn(
 # compute the loss by calling the model with the inputs.
 @jax.jit
 def train_step(model: Model, optimizer: SGD, rngs: nnx.Rngs, x, y):
-  treedef, params, nondiff = nnx.split(model, nnx.Param, ...)
+  graphdef, params, nondiff = nnx.split(model, nnx.Param, ...)
 
   def loss_fn(params):
-    model = nnx.merge(treedef, params, nondiff)
+    model = nnx.merge(graphdef, params, nondiff)
     loss = jnp.mean((model(x, rngs=rngs) - y) ** 2)
     return loss
 
```
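The rename is cosmetic but aligns with NNX naming: `nnx.split` returns a `GraphDef` plus one `State` per filter, not a JAX `treedef`. A hedged sketch of the differentiable step this function builds (names taken from the diff; the gradient call is an assumed completion):

```python
import jax

graphdef, params, nondiff = nnx.split(model, nnx.Param, ...)

def loss_fn(params):
  model = nnx.merge(graphdef, params, nondiff)   # rebuild the module
  return jnp.mean((model(x, rngs=rngs) - y) ** 2)

grads = jax.grad(loss_fn)(params)  # differentiate w.r.t. params only
```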

flax/nnx/__init__.py (1 addition, 2 deletions)

```diff
@@ -31,7 +31,6 @@
 from .graph import GraphState as GraphState
 from .graph import PureState as PureState
 from .object import Object as Object
-from .helpers import Dict as Dict
 from .helpers import Sequential as Sequential
 from .helpers import TrainState as TrainState
 from .module import M as M
@@ -139,7 +138,7 @@
 from .training.metrics import Metric as Metric
 from .training.metrics import MultiMetric as MultiMetric
 from .training.optimizer import Optimizer as Optimizer
-from .training.optimizer import OptaxOptimizer as OptaxOptimizer
+from .training.optimizer import MutableArrayOptimizer as MutableArrayOptimizer
 from .transforms.autodiff import DiffState as DiffState
 from .transforms.autodiff import grad as grad
 from .transforms.autodiff import value_and_grad as value_and_grad
```

flax/nnx/bridge/module.py (1 addition, 2 deletions)

```diff
@@ -224,8 +224,7 @@ class ModuleBase:
 @tpe.dataclass_transform(field_specifiers=(dataclasses.field,))  # type: ignore[not-supported-yet]
 class Module(nnx_module.Module, ModuleBase, metaclass=ModuleMeta):
   def __init_subclass__(cls) -> None:
-    cls.__data__ = 'auto'
-    super().__init_subclass__()
+    super().__init_subclass__(pytree=False)
 
     cls = dataclasses.dataclass(repr=False)(cls)
     cls.__hash__ = object.__hash__  # type: ignore[method-assign]
```
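This replaces the internal `__data__ = 'auto'` marker with the `pytree=False` class keyword, the same keyword the `ToNNX` change below uses. A hedged sketch of what the keyword looks like at a subclass site (the class body is illustrative, and the opt-out-of-pytree semantics is inferred from this commit's usage, not documented here):

```python
from flax import nnx

# pytree=False is assumed to opt the class out of automatic pytree handling.
class Wrapper(nnx.Module, pytree=False):
  def __init__(self, din: int, dout: int, rngs: nnx.Rngs):
    self.linear = nnx.Linear(din, dout, rngs=rngs)
```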

flax/nnx/bridge/wrappers.py (1 addition, 6 deletions)

```diff
@@ -30,7 +30,6 @@
 from flax.nnx.statelib import State
 import jax
 from jax import tree_util as jtu
-from flax import config
 
 M = tp.TypeVar('M', bound=Module)
 
@@ -87,9 +86,7 @@ def lazy_init(fn: Module | tp.Callable[..., tp.Any], *args, **kwargs):
     _set_initializing(module, False)
   return fn
 
-PYTREE_DEFAULT = 'auto' if config.flax_mutable_array else None
-
-class ToNNX(Module):
+class ToNNX(Module, pytree=False):
   """A wrapper to turn any Linen module into an NNX module.
 
   The result NNX module can be used standalone with all NNX APIs, or as a submodule of
@@ -119,8 +116,6 @@ class ToNNX(Module):
     A stateful NNX module that behaves the same as the wrapped Linen module.
   """
 
-  __data__ = 'auto'
-
   def __init__(
     self,
     module: linen.Module,
```
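For orientation, a hedged sketch of the `ToNNX` wrapper this hunk touches, based on the public bridge API (the Linen module and shapes are illustrative; exact behavior may vary by Flax version):

```python
import jax.numpy as jnp
from flax import linen as nn
from flax import nnx
from flax.nnx import bridge

x = jnp.ones((1, 32))
model = bridge.ToNNX(nn.Dense(features=64), rngs=nnx.Rngs(0))
bridge.lazy_init(model, x)   # materializes the Linen variables on first trace
y = model(x)
```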
