Skip to content

Commit 408bc7f

Browse files
committed
[nnx] don't share Rngs
1 parent 535fc81 commit 408bc7f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+988
-1202
lines changed

docs_nnx/api_reference/flax.nnx/helpers.rst

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,7 @@ helpers
44
.. automodule:: flax.nnx
55
.. currentmodule:: flax.nnx
66

7-
.. autoclass:: Dict
8-
:members:
9-
.. autoclass:: List
10-
:members:
7+
118
.. autoclass:: Sequential
129
:members:
1310
.. autoclass:: TrainState

docs_nnx/api_reference/flax.nnx/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ Experimental API. See the `NNX page <https://flax.readthedocs.io/en/latest/nnx/i
77
:maxdepth: 3
88

99
graph
10+
object
1011
module
1112
nn/index
1213
rnglib
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
object
2+
------------------------
3+
4+
.. automodule:: flax.nnx
5+
.. currentmodule:: flax.nnx
6+
7+
.. autoclass:: Object
8+
:members:
9+
.. autofunction:: data
10+
.. autodata:: Data
11+
:annotation:
12+
.. autofunction:: is_data_type
13+
.. autofunction:: register_data_type

docs_nnx/api_reference/flax.nnx/training/optimizer.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,6 @@ Optimizer
66

77
.. autoclass:: Optimizer
88
:members: __init__, update
9+
10+
.. autoclass:: PytreeOptimizer
11+
:members: __init__, update

docs_nnx/api_reference/flax.nnx/variables.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,4 @@ variables
2323

2424
.. autofunction:: variable_name_from_type
2525
.. autofunction:: variable_type_from_name
26-
.. autofunction:: set_variable_name
26+
.. autofunction:: register_variable_name

docs_nnx/guides/filters_guide.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"id": "95b08e64",
66
"metadata": {},
77
"source": [
8-
"# Using Filters, grouping NNX variables \n",
8+
"# Filters\n",
99
"\n",
1010
"Flax NNX uses [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) extensively as a way to create [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) groups in APIs, such as [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), [`nnx.state()`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.state), and many of the [Flax NNX transformations (transforms)](https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html).\n",
1111
"\n",

docs_nnx/guides/filters_guide.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ jupytext:
88
jupytext_version: 1.13.8
99
---
1010

11-
# Using Filters, grouping NNX variables
11+
# Filters
1212

1313
Flax NNX uses [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) extensively as a way to create [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) groups in APIs, such as [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), [`nnx.state()`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.state), and many of the [Flax NNX transformations (transforms)](https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html).
1414

examples/nnx_toy_examples/06_scan_over_layers.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ def create_block(rngs: nnx.Rngs):
4747
self.layers = create_block(rngs)
4848

4949
def __call__(self, x: jax.Array) -> jax.Array:
50-
@nnx.split_rngs(splits=self.n_layers)
5150
@nnx.scan
5251
def scan_fn(x: jax.Array, block: Block):
5352
x = block(x)

examples/nnx_toy_examples/10_fsdp_and_optimizer.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,19 +84,29 @@ def init_optimizer_state(variable: nnx.Variable):
8484

8585
self.lr = lr
8686
self.params = params
87-
self.momentum: nnx.State = jax.tree.map(init_optimizer_state, self.params)
87+
self.momentum: nnx.State = jax.tree.map(
88+
init_optimizer_state,
89+
self.params,
90+
is_leaf=lambda x: isinstance(x, nnx.Variable | nnx.VariableState),
91+
)
8892
self.decay = decay
8993

9094
def update(self, grads: nnx.State):
9195
def update_fn(
9296
params: nnx.Variable, momentum: SGDState, grad: nnx.VariableState
9397
):
9498
# v_t = β * v_{t-1} + (1 - β) * ∇J(θ_t)
95-
momentum.value = self.decay * momentum + (1 - self.decay) * grad.value
99+
momentum[...] = self.decay * momentum[...] + (1 - self.decay) * grad[...]
96100
# θ_{t+1} = θ_t - α * v_t
97-
params.value -= self.lr * momentum
98-
99-
jax.tree.map(update_fn, self.params, self.momentum, grads)
101+
params[...] -= self.lr * momentum[...]
102+
103+
jax.tree.map(
104+
update_fn,
105+
self.params,
106+
self.momentum,
107+
grads,
108+
is_leaf=lambda x: isinstance(x, nnx.Variable | nnx.VariableState),
109+
)
100110

101111

102112
@nnx.jit

examples/nnx_toy_examples/mutable_array_basic.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,6 @@ def dataset(batch_size):
3535

3636

3737
class Linear(nnx.Module):
38-
__data__ = ('w', 'b')
39-
4038
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
4139
self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout)))
4240
self.b = nnx.Param(jnp.zeros((dout,)))
@@ -50,8 +48,6 @@ class Count(nnx.Variable[nnx.A]):
5048

5149

5250
class MLP(nnx.Module):
53-
__data__ = ('count', 'linear1', 'linear2')
54-
5551
def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
5652
self.count = Count(jnp.array(0))
5753
self.linear1 = Linear(din, dhidden, rngs=rngs)

0 commit comments

Comments
 (0)