Skip to content

Commit 6ef64e3

Browse files
committed
move freeze, mutable, and pure to graph.py
1 parent 901880a commit 6ef64e3

File tree

7 files changed

+527
-151
lines changed

7 files changed

+527
-151
lines changed

examples/nnx_toy_examples/mutable_array_demo.py

Lines changed: 0 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -31,88 +31,6 @@
3131
# # Utils
3232
A = TypeVar('A')
3333

34-
def mutable_like(path, x):
35-
return (isinstance(x, nnx.Variable) and x.mutable) or nnx.is_mutable_array(x)
36-
37-
38-
def freeze(x: A, only: nnx.filterlib.Filter = mutable_like) -> A:
39-
freeze_filter = nnx.filterlib.to_predicate(only)
40-
mutable_arrays: set[int] = set()
41-
42-
def check_mutable_array(path, x):
43-
m_array_id = id(x)
44-
if m_array_id in mutable_arrays:
45-
path_str = jax.tree_util.keystr(path)
46-
raise ValueError(
47-
f'Found duplicate MutableArray found at path {path_str}: {x}'
48-
)
49-
mutable_arrays.add(m_array_id)
50-
51-
def _freeze_fn(jax_path, x):
52-
path = tuple(nnx.graph._key_path_to_key(part) for part in jax_path)
53-
if freeze_filter(path, x):
54-
if isinstance(x, nnx.Variable):
55-
check_mutable_array(jax_path, x.raw_value)
56-
return x.from_metadata(x[...], x.get_metadata().copy())
57-
elif nnx.is_mutable_array(x):
58-
check_mutable_array(jax_path, x)
59-
return x[...]
60-
return x
61-
62-
return jax.tree.map_with_path(
63-
_freeze_fn, x, is_leaf=lambda x: isinstance(x, nnx.Variable)
64-
)
65-
66-
67-
def array_like(path, x):
68-
return (
69-
isinstance(x, nnx.Variable) and not x.mutable
70-
) or nnx.is_mutable_array(x)
71-
72-
73-
def mutable(x: A, only: nnx.filterlib.Filter = array_like) -> A:
74-
freeze_filter = nnx.filterlib.to_predicate(only)
75-
mutable_arrays: dict[int, Any] = {}
76-
77-
def get_mutable(x):
78-
m_array_id = id(x)
79-
if m_array_id in mutable_arrays:
80-
return mutable_arrays[m_array_id]
81-
82-
if isinstance(x, nnx.Variable):
83-
assert not x.mutable
84-
_mutable = x.from_metadata(
85-
nnx.mutable_array(x.raw_value),
86-
x.get_metadata().copy(),
87-
)
88-
mutable_arrays[m_array_id] = _mutable
89-
return _mutable
90-
elif isinstance(x, jax.Array):
91-
_mutable = nnx.mutable_array(x)
92-
mutable_arrays[m_array_id] = _mutable
93-
return _mutable
94-
return x
95-
96-
def _mutable_fn(jax_path, x):
97-
path = tuple(nnx.graph._key_path_to_key(part) for part in jax_path)
98-
if freeze_filter(path, x):
99-
return get_mutable(x)
100-
return x
101-
102-
return jax.tree.map_with_path(
103-
_mutable_fn, x, is_leaf=lambda x: isinstance(x, nnx.Variable)
104-
)
105-
106-
def pure(tree: A) -> A:
107-
def _pure_fn(x):
108-
if isinstance(x, nnx.Variable | nnx.VariableState):
109-
return x.raw_value
110-
return x
111-
112-
return jax.tree.map(
113-
_pure_fn, tree, is_leaf=lambda x: isinstance(x, nnx.Variable | nnx.VariableState)
114-
)
115-
11634
def fork_rngs(
11735
rngs: nnx.Rngs,
11836
/,

flax/nnx/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@
5353
from .graph import MergeContext as MergeContext
5454
from .graph import merge_context as merge_context
5555
from .graph import variables as variables
56+
from .graph import freeze as freeze
57+
from .graph import mutable as mutable
58+
from .graph import pure as pure
5659
from .graph import cached_partial as cached_partial
5760
from .nn import initializers as initializers
5861
from .nn.activations import celu as celu
@@ -168,6 +171,7 @@
168171
from .variablelib import variable_name_from_type as variable_name_from_type
169172
from .variablelib import register_variable_name as register_variable_name
170173
from .variablelib import mutable_array as mutable_array
174+
from .variablelib import MutableArray as MutableArray
171175
from .variablelib import is_mutable_array as is_mutable_array
172176
from .visualization import display as display
173177
from .extract import to_tree as to_tree

0 commit comments

Comments
 (0)