Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions docs_nnx/api_reference/flax.nnx/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,6 @@ transforms

.. automodule:: flax.nnx
.. currentmodule:: flax.nnx

.. autoclass:: Jit
:members:
.. autoclass:: Remat
:members:
.. autoclass:: Scan
:members:
.. autoclass:: Vmap
:members:

.. autofunction:: grad
.. autofunction:: jit
.. autofunction:: shard_map
Expand Down
2 changes: 0 additions & 2 deletions docs_nnx/api_reference/flax.nnx/variables.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ variables
:members:
.. autoclass:: Cache
:members:
.. autoclass:: Empty
:members:
.. autoclass:: Intermediate
:members:
.. autoclass:: Param
Expand Down
11 changes: 4 additions & 7 deletions examples/nnx_toy_examples/01_functional_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

from flax import nnx

X = np.linspace(0, 1, 100)[:, None]
Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape)
X = np.linspace(-jnp.pi, jnp.pi, 100)[:, None]
Y = 0.8 * jnp.sin(X) + 0.1 + np.random.normal(0, 0.1, size=X.shape)


def dataset(batch_size):
Expand Down Expand Up @@ -50,11 +50,8 @@ def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
self.linear2 = Linear(dhidden, dout, rngs=rngs)

def __call__(self, x):
self.count.value += 1
x = self.linear1(x)
x = jax.nn.relu(x)
x = self.linear2(x)
return x
self.count[...] += 1
return self.linear2(jax.nn.relu(self.linear1(x) * 0.5))


graphdef, params, counts = nnx.split(
Expand Down
45 changes: 5 additions & 40 deletions examples/nnx_toy_examples/mutable_array_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,42 +23,7 @@
import numpy as np

from flax import nnx
from flax.nnx.variablelib import is_mutable_array


def mutable_like(path, x):
return (
isinstance(x, nnx.Variable) and x.mutable
) or nnx.variablelib.is_mutable_array(x)


def freeze(x, only: nnx.filterlib.Filter = mutable_like):
freeze_filter = nnx.filterlib.to_predicate(only)
mutable_arrays: set[int] = set()

def check_mutable_array(path, x):
m_array_id = id(x)
if m_array_id in mutable_arrays:
path_str = jax.tree_util.keystr(path)
raise ValueError(
f'Found duplicate MutableArray found at path {path_str}: {x}'
)
mutable_arrays.add(m_array_id)

def _freeze_fn(jax_path, x):
path = tuple(nnx.graph._key_path_to_key(part) for part in jax_path)
if freeze_filter(path, x):
if isinstance(x, nnx.Variable):
check_mutable_array(jax_path, x.raw_value)
return x.from_metadata(x[...], x.get_metadata().copy())
elif nnx.variablelib.is_mutable_array(x):
check_mutable_array(jax_path, x)
return x[...]
return x

return jax.tree.map_with_path(
_freeze_fn, x, is_leaf=lambda x: isinstance(x, nnx.Variable)
)


X = np.linspace(-jnp.pi, jnp.pi, 100)[:, None]
Y = 0.8 * jnp.sin(X) + 0.1 + np.random.normal(0, 0.1, size=X.shape)
Expand Down Expand Up @@ -94,21 +59,21 @@ def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):

def __call__(self, x):
self.count[...] += 1
return self.linear2(jax.nn.gelu(self.linear1(x)) * 0.5)
return self.linear2(jax.nn.relu(self.linear1(x)) * 0.5)


model = MLP(din=1, dhidden=64, dout=1, rngs=nnx.Rngs(0))
model = MLP(din=1, dhidden=32, dout=1, rngs=nnx.Rngs(0))


@jax.jit
def train_step(model, x, y):
graphdef, params, counts = nnx.split(model, nnx.Param, Count)
graphdef, params, counts = nnx.pure(nnx.split(model, nnx.Param, Count))

def loss_fn(params):
model = nnx.merge(graphdef, params, counts)
return jnp.mean((y - model(x)) ** 2)

grads = jax.grad(loss_fn)(freeze(params))
grads = jax.grad(loss_fn)(nnx.freeze(params))

def sgd(w, g):
w[...] -= 0.1 * g[...]
Expand Down
Loading
Loading