Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,5 +204,5 @@ def __getattr__(name):
)
if name not in globals():
raise AttributeError(f"Module {__name__} has no attribute '{name}'")

return globals()[name]
2 changes: 1 addition & 1 deletion flax/nnx/bridge/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def to_linen_var(vs: variablelib.Variable) -> meta.AxisMetadata:
return linen_type(vs.value, **metadata)
if is_vanilla_variable(vs):
return vs.value
return NNXMeta(vs.type, vs.value, metadata)
return NNXMeta(type(vs), vs.value, metadata)


def get_col_name(keypath: tp.Sequence[Any]) -> str:
Expand Down
2 changes: 1 addition & 1 deletion flax/nnx/bridge/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def _update_variables(self, module):

# group state by collection
for path, leaf in nnx.to_flat_state(state):
type_ = leaf.type if isinstance(leaf, nnx.Variable) else type(leaf)
type_ = type(leaf)
collection = variablelib.variable_name_from_type(
type_, allow_register=True
)
Expand Down
12 changes: 6 additions & 6 deletions tests/nnx/transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,9 +628,9 @@ def f(m: Dict):
assert m.a[0] is m.b
assert isinstance(grads, nnx.State)
assert grads['a']['0'].value == 2.0
assert issubclass(grads['a']['0'].type, nnx.Variable)
assert issubclass(type(grads['a']['0']), nnx.Variable)
assert grads['a']['1'].value == 1.0
assert issubclass(grads['a']['1'].type, nnx.Variable)
assert issubclass(type(grads['a']['1']), nnx.Variable)
assert len(nnx.to_flat_state(grads)) == 2

nnx.update(m, grads)
Expand Down Expand Up @@ -659,7 +659,7 @@ def f(m: Dict):

assert isinstance(grads, nnx.State)
assert grads['a']['0'].value == 1.0
assert issubclass(grads['a']['0'].type, nnx.Param)
assert issubclass(type(grads['a']['0']), nnx.Param)
assert len(grads) == 2

nnx.update(m, grads)
Expand Down Expand Up @@ -687,7 +687,7 @@ def f(m: Dict):

assert isinstance(grads, nnx.State)
assert grads['a']['1'].value == 1.0
assert issubclass(grads['a']['1'].type, nnx.BatchStat)
assert issubclass(type(grads['a']['1']), nnx.BatchStat)
assert len(grads) == 1

nnx.update(m, grads)
Expand Down Expand Up @@ -843,9 +843,9 @@ def f(m: dict):

assert m['a'][0] is m['b']
assert isinstance(grads, dict)
assert issubclass(grads['a'][0].type, nnx.Variable)
assert issubclass(type(grads['a'][0]), nnx.Variable)
assert grads['a'][1].value == 1.0
assert issubclass(grads['a'][1].type, nnx.Variable)
assert issubclass(type(grads['a'][1]), nnx.Variable)
assert len(jax.tree.leaves(grads)) == 2

jax.tree.map(
Expand Down
Loading