@@ -43,7 +43,7 @@ def _variable_parents_count(t: type):
4343
4444
4545class NNXMeta (struct .PyTreeNode , meta .AxisMetadata [A ]):
46- """Default Flax metadata class for `nnx.VariableState `."""
46+ """Default Flax metadata class for `nnx.Variable `."""
4747
4848 var_type : type [variablelib .Variable [tp .Any ]] = struct .field (pytree_node = False )
4949 value : Any = struct .field (pytree_node = True )
@@ -65,15 +65,17 @@ def remove_axis(self, index: int, params: dict[Any, Any]) -> 'NNXMeta[A]':
6565
6666 def get_partition_spec (self ) -> jax .sharding .PartitionSpec :
6767 """Returns the ``Partitionspec`` for this partitioned value."""
68- nnx_var = self .to_nnx_variable ().to_state ()
69- return spmd .get_partition_spec (nnx_var ).raw_value
68+ nnx_var = self .to_nnx_variable ()
69+ spec = spmd .get_partition_spec (nnx_var )
70+ assert isinstance (spec , jax .sharding .PartitionSpec )
71+ return spec
7072
7173 def to_nnx_variable (self ) -> variablelib .Variable :
7274 return self .var_type (self .value , ** self .metadata )
7375
7476
75- def is_vanilla_variable (vs : variablelib .VariableState ) -> bool :
76- """A variables state is vanilla if its metadata is essentially blank.
77+ def is_vanilla_variable (vs : variablelib .Variable ) -> bool :
78+ """A variable is vanilla if its metadata is essentially blank.
7779
7880 Returns False only if it has non-empty hooks or any non-built-in attribute.
7981 """
@@ -86,7 +88,7 @@ def is_vanilla_variable(vs: variablelib.VariableState) -> bool:
8688 return True
8789
8890
89- def to_linen_var (vs : variablelib .VariableState ) -> meta .AxisMetadata :
91+ def to_linen_var (vs : variablelib .Variable ) -> meta .AxisMetadata :
9092 metadata = vs .get_metadata ()
9193 if 'linen_meta_type' in metadata :
9294 linen_type = metadata ['linen_meta_type' ]
@@ -145,14 +147,11 @@ def linen_vars_to_nnx_attrs(variables: tp.Mapping[str, Any]) -> dict[str, Any]:
145147
146148
147149def nnx_attrs_to_linen_vars (nnx_attrs : dict ) -> dict :
148- """Convert a dict of NNX variables (or variable states) to Linen-style variables."""
150+ """Convert a dict of NNX variables to Linen-style variables."""
149151 linen_structured = {}
150152 for kp , v in traversals .flatten_mapping (nnx_attrs ).items ():
151153 if isinstance (v , variablelib .Variable ):
152154 col_name = variablelib .variable_name_from_type (type (v ))
153- v = to_linen_var (v .to_state ())
154- elif isinstance (v , variablelib .VariableState ):
155- col_name = variablelib .variable_name_from_type (v .type )
156155 v = to_linen_var (v )
157156 elif isinstance (v , graph .GraphDef ):
158157 col_name = 'nnx' # an nnx.GraphDef for some ToLinen submodule
0 commit comments