|
31 | 31 | # # Utils |
32 | 32 | A = TypeVar('A') |
33 | 33 |
|
34 | | -def mutable_like(path, x): |
35 | | - return (isinstance(x, nnx.Variable) and x.mutable) or nnx.is_mutable_array(x) |
36 | | - |
37 | | - |
38 | | -def freeze(x: A, only: nnx.filterlib.Filter = mutable_like) -> A: |
39 | | - freeze_filter = nnx.filterlib.to_predicate(only) |
40 | | - mutable_arrays: set[int] = set() |
41 | | - |
42 | | - def check_mutable_array(path, x): |
43 | | - m_array_id = id(x) |
44 | | - if m_array_id in mutable_arrays: |
45 | | - path_str = jax.tree_util.keystr(path) |
46 | | - raise ValueError( |
47 | | - f'Found duplicate MutableArray found at path {path_str}: {x}' |
48 | | - ) |
49 | | - mutable_arrays.add(m_array_id) |
50 | | - |
51 | | - def _freeze_fn(jax_path, x): |
52 | | - path = tuple(nnx.graph._key_path_to_key(part) for part in jax_path) |
53 | | - if freeze_filter(path, x): |
54 | | - if isinstance(x, nnx.Variable): |
55 | | - check_mutable_array(jax_path, x.raw_value) |
56 | | - return x.from_metadata(x[...], x.get_metadata().copy()) |
57 | | - elif nnx.is_mutable_array(x): |
58 | | - check_mutable_array(jax_path, x) |
59 | | - return x[...] |
60 | | - return x |
61 | | - |
62 | | - return jax.tree.map_with_path( |
63 | | - _freeze_fn, x, is_leaf=lambda x: isinstance(x, nnx.Variable) |
64 | | - ) |
65 | | - |
66 | | - |
67 | | -def array_like(path, x): |
68 | | - return ( |
69 | | - isinstance(x, nnx.Variable) and not x.mutable |
70 | | - ) or nnx.is_mutable_array(x) |
71 | | - |
72 | | - |
73 | | -def mutable(x: A, only: nnx.filterlib.Filter = array_like) -> A: |
74 | | - freeze_filter = nnx.filterlib.to_predicate(only) |
75 | | - mutable_arrays: dict[int, Any] = {} |
76 | | - |
77 | | - def get_mutable(x): |
78 | | - m_array_id = id(x) |
79 | | - if m_array_id in mutable_arrays: |
80 | | - return mutable_arrays[m_array_id] |
81 | | - |
82 | | - if isinstance(x, nnx.Variable): |
83 | | - assert not x.mutable |
84 | | - _mutable = x.from_metadata( |
85 | | - nnx.mutable_array(x.raw_value), |
86 | | - x.get_metadata().copy(), |
87 | | - ) |
88 | | - mutable_arrays[m_array_id] = _mutable |
89 | | - return _mutable |
90 | | - elif isinstance(x, jax.Array): |
91 | | - _mutable = nnx.mutable_array(x) |
92 | | - mutable_arrays[m_array_id] = _mutable |
93 | | - return _mutable |
94 | | - return x |
95 | | - |
96 | | - def _mutable_fn(jax_path, x): |
97 | | - path = tuple(nnx.graph._key_path_to_key(part) for part in jax_path) |
98 | | - if freeze_filter(path, x): |
99 | | - return get_mutable(x) |
100 | | - return x |
101 | | - |
102 | | - return jax.tree.map_with_path( |
103 | | - _mutable_fn, x, is_leaf=lambda x: isinstance(x, nnx.Variable) |
104 | | - ) |
105 | | - |
106 | | -def pure(tree: A) -> A: |
107 | | - def _pure_fn(x): |
108 | | - if isinstance(x, nnx.Variable | nnx.VariableState): |
109 | | - return x.raw_value |
110 | | - return x |
111 | | - |
112 | | - return jax.tree.map( |
113 | | - _pure_fn, tree, is_leaf=lambda x: isinstance(x, nnx.Variable | nnx.VariableState) |
114 | | - ) |
115 | | - |
116 | 34 | def fork_rngs( |
117 | 35 | rngs: nnx.Rngs, |
118 | 36 | /, |
|
0 commit comments