@@ -2652,10 +2652,22 @@ def clone(node: Node) -> Node:
26522652 graphdef , state = split (node )
26532653 return merge (graphdef , state )
26542654
2655+ def find_duplicates (tree ) -> tuple [str , str ] | None :
2656+ mutable_arrays : dict [int , str ] = {}
2657+ paths_leaves = jax .tree .leaves_with_path (tree )
2658+ for path , x in paths_leaves :
2659+ m_array_id = id (x )
2660+ if m_array_id in mutable_arrays :
2661+ current_path_str = jax .tree_util .keystr (path )
2662+ previous_path_str = mutable_arrays [m_array_id ]
2663+ return current_path_str , previous_path_str
2664+ mutable_arrays [m_array_id ] = jax .tree_util .keystr (path )
2665+
2666+ return None
26552667
26562668def _mutable_like (path , x ):
26572669 return (
2658- isinstance (x , Variable ) and x .mutable
2670+ isinstance (x , Variable | VariableState ) and x .mutable
26592671 ) or variablelib .is_mutable_array (x )
26602672
26612673
@@ -2681,7 +2693,7 @@ def freeze(tree: A, /, only: filterlib.Filter = _mutable_like) -> A:
26812693 ... nnx.freeze(tree)
26822694 ... except ValueError as e:
26832695 ... print(e)
2684- Found duplicate MutableArray found at path [1] and [0] .. .
2696+ Found duplicate at path ' [1]' and ' [0]' .
26852697
26862698 ``only`` is a `Filter <https://flax.readthedocs.io/en/latest/guides/filters_guide.html>`__
26872699 that can be used to specify which mutable arrays to freeze::
@@ -2698,45 +2710,36 @@ def freeze(tree: A, /, only: filterlib.Filter = _mutable_like) -> A:
26982710 Returns:
26992711 A pytree with the frozen arrays.
27002712 """
2713+ if (duplicate := find_duplicates (tree )) is not None :
2714+ current_path_str , previous_path_str = duplicate
2715+ raise ValueError (
2716+ f"Found duplicate at path '{ current_path_str } ' "
2717+ f"and '{ previous_path_str } '."
2718+ )
27012719 freeze_filter = filterlib .to_predicate (only )
2702- mutable_arrays : dict [int , str ] = {}
2703-
2704- def check_mutable_array (path , x ):
2705- m_array_id = id (x )
2706- if m_array_id in mutable_arrays :
2707- current_path_str = jax .tree_util .keystr (path )
2708- previous_path_str = mutable_arrays [m_array_id ]
2709- raise ValueError (
2710- f'Found duplicate MutableArray found at path { current_path_str } '
2711- f'and { previous_path_str } at object { x } .'
2712- )
2713- mutable_arrays [m_array_id ] = jax .tree_util .keystr (path )
27142720
27152721 def _freeze_fn (jax_path , x ):
2716- path = tuple ( _key_path_to_key ( part ) for part in jax_path )
2722+ path = jax_to_nnx_path ( jax_path )
27172723 if freeze_filter (path , x ):
2718- if isinstance (x , Variable ):
2719- check_mutable_array (jax_path , x .raw_value )
2720- return x .from_metadata (x [...], x .get_metadata ().copy ())
2721- elif variablelib .is_mutable_array (x ):
2722- check_mutable_array (jax_path , x )
2723- return x [...]
2724+ x = jax .tree .map (lambda x : x [...], x )
2725+ elif isinstance (x , Variable | VariableState ):
2726+ x = jax .tree .map (lambda x : x , x )
27242727 return x
27252728
27262729 tree = jax .tree .map_with_path (
2727- _freeze_fn , tree , is_leaf = lambda x : isinstance (x , Variable )
2730+ _freeze_fn , tree , is_leaf = lambda x : isinstance (x , Variable | VariableState )
27282731 )
27292732 return tree
27302733
27312734
27322735def _array_like (path , x ):
27332736 return (
2734- isinstance (x , Variable ) and isinstance ( x . raw_value , jax . Array )
2737+ isinstance (x , Variable | VariableState ) and not x . mutable
27352738 ) or isinstance (x , jax .Array )
27362739
27372740
27382741def mutable (tree : A , / , only : filterlib .Filter = _array_like ) -> A :
2739- """Converts a pytree of arrays to mutable arrays.
2742+ """Converts a tree of arrays to mutable arrays.
27402743
27412744 Example::
27422745
@@ -2757,7 +2760,7 @@ def mutable(tree: A, /, only: filterlib.Filter = _array_like) -> A:
27572760 ... nnx.mutable(tree)
27582761 ... except ValueError as e:
27592762 ... print(e)
2760- Found duplicate Array found at path [1] and [0] .. .
2763+ Found duplicate at path ' [1]' and ' [0]' .
27612764
27622765 ``only`` is a `Filter <https://flax.readthedocs.io/en/latest/guides/filters_guide.html>`__
27632766 that can be used to specify which arrays to convert to mutable arrays.
@@ -2774,34 +2777,24 @@ def mutable(tree: A, /, only: filterlib.Filter = _array_like) -> A:
27742777 Returns:
27752778 A pytree with the mutable arrays.
27762779 """
2780+ if (duplicate := find_duplicates (tree )) is not None :
2781+ current_path_str , previous_path_str = duplicate
2782+ raise ValueError (
2783+ f"Found duplicate at path '{ current_path_str } ' "
2784+ f"and '{ previous_path_str } '."
2785+ )
27772786 mutable_filter = filterlib .to_predicate (only )
2778- arrays : dict [int , str ] = {}
2779-
2780- def check_array (path , x ):
2781- m_array_id = id (x )
2782- if m_array_id in arrays :
2783- current_path_str = jax .tree_util .keystr (path )
2784- previous_path_str = arrays [m_array_id ]
2785- raise ValueError (
2786- f'Found duplicate Array found at path { current_path_str } '
2787- f'and { previous_path_str } at object { x } .'
2788- )
2789- arrays [m_array_id ] = jax .tree_util .keystr (path )
27902787
27912788 def _mutable_fn (jax_path , x ):
2792- path = tuple ( _key_path_to_key ( part ) for part in jax_path )
2789+ path = jax_to_nnx_path ( jax_path )
27932790 if mutable_filter (path , x ):
2794- if isinstance (x , Variable ) and isinstance (x .raw_value , jax .Array ):
2795- check_array (jax_path , x .raw_value )
2796- mutable_array = variablelib .mutable_array (x .raw_value )
2797- return x .from_metadata (mutable_array , x .get_metadata ().copy ())
2798- elif isinstance (x , jax .Array ):
2799- check_array (jax_path , x )
2800- return variablelib .mutable_array (x )
2791+ x = jax .tree .map (variablelib .mutable_array , x )
2792+ elif isinstance (x , Variable | VariableState ):
2793+ x = jax .tree .map (lambda x : x , x )
28012794 return x
28022795
28032796 return jax .tree .map_with_path (
2804- _mutable_fn , tree , is_leaf = lambda x : isinstance (x , Variable )
2797+ _mutable_fn , tree , is_leaf = lambda x : isinstance (x , Variable | VariableState )
28052798 )
28062799
28072800
@@ -3047,6 +3040,11 @@ def _key_path_to_key(key: tp.Any) -> Key:
30473040 else :
30483041 return str (key )
30493042
3043+
3044+ def jax_to_nnx_path (jax_path : tuple , / ):
3045+ return tuple (_key_path_to_key (part ) for part in jax_path )
3046+
3047+
30503048class IndexesPytreeDef (tp .NamedTuple ):
30513049 key_index : HashableMapping [Key , int ]
30523050 treedef : jax .tree_util .PyTreeDef
0 commit comments