Skip to content

Commit 539d9f0

Browse files
committed
Fixed .type deprecated property usage
1 parent f179e8e commit 539d9f0

File tree

3 files changed

+8
-8
lines changed

3 files changed

+8
-8
lines changed

flax/nnx/bridge/variables.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def to_linen_var(vs: variablelib.Variable) -> meta.AxisMetadata:
9797
return linen_type(vs.value, **metadata)
9898
if is_vanilla_variable(vs):
9999
return vs.value
100-
return NNXMeta(vs.type, vs.value, metadata)
100+
return NNXMeta(type(vs), vs.value, metadata)
101101

102102

103103
def get_col_name(keypath: tp.Sequence[Any]) -> str:

flax/nnx/bridge/wrappers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ def _update_variables(self, module):
349349

350350
# group state by collection
351351
for path, leaf in nnx.to_flat_state(state):
352-
type_ = leaf.type if isinstance(leaf, nnx.Variable) else type(leaf)
352+
type_ = type(leaf)
353353
collection = variablelib.variable_name_from_type(
354354
type_, allow_register=True
355355
)

tests/nnx/transforms_test.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -628,9 +628,9 @@ def f(m: Dict):
628628
assert m.a[0] is m.b
629629
assert isinstance(grads, nnx.State)
630630
assert grads['a']['0'].value == 2.0
631-
assert issubclass(grads['a']['0'].type, nnx.Variable)
631+
assert issubclass(type(grads['a']['0']), nnx.Variable)
632632
assert grads['a']['1'].value == 1.0
633-
assert issubclass(grads['a']['1'].type, nnx.Variable)
633+
assert issubclass(type(grads['a']['1']), nnx.Variable)
634634
assert len(nnx.to_flat_state(grads)) == 2
635635

636636
nnx.update(m, grads)
@@ -659,7 +659,7 @@ def f(m: Dict):
659659

660660
assert isinstance(grads, nnx.State)
661661
assert grads['a']['0'].value == 1.0
662-
assert issubclass(grads['a']['0'].type, nnx.Param)
662+
assert issubclass(type(grads['a']['0']), nnx.Param)
663663
assert len(grads) == 2
664664

665665
nnx.update(m, grads)
@@ -687,7 +687,7 @@ def f(m: Dict):
687687

688688
assert isinstance(grads, nnx.State)
689689
assert grads['a']['1'].value == 1.0
690-
assert issubclass(grads['a']['1'].type, nnx.BatchStat)
690+
assert issubclass(type(grads['a']['1']), nnx.BatchStat)
691691
assert len(grads) == 1
692692

693693
nnx.update(m, grads)
@@ -843,9 +843,9 @@ def f(m: dict):
843843

844844
assert m['a'][0] is m['b']
845845
assert isinstance(grads, dict)
846-
assert issubclass(grads['a'][0].type, nnx.Variable)
846+
assert issubclass(type(grads['a'][0]), nnx.Variable)
847847
assert grads['a'][1].value == 1.0
848-
assert issubclass(grads['a'][1].type, nnx.Variable)
848+
assert issubclass(type(grads['a'][1]), nnx.Variable)
849849
assert len(jax.tree.leaves(grads)) == 2
850850

851851
jax.tree.map(

0 commit comments

Comments
 (0)