@@ -628,9 +628,9 @@ def f(m: Dict):
628628 assert m .a [0 ] is m .b
629629 assert isinstance (grads , nnx .State )
630630 assert grads ['a' ]['0' ].value == 2.0
631- assert issubclass (grads ['a' ]['0' ]. type , nnx .Variable )
631+ assert issubclass (type ( grads ['a' ]['0' ]) , nnx .Variable )
632632 assert grads ['a' ]['1' ].value == 1.0
633- assert issubclass (grads ['a' ]['1' ]. type , nnx .Variable )
633+ assert issubclass (type ( grads ['a' ]['1' ]) , nnx .Variable )
634634 assert len (nnx .to_flat_state (grads )) == 2
635635
636636 nnx .update (m , grads )
@@ -659,7 +659,7 @@ def f(m: Dict):
659659
660660 assert isinstance (grads , nnx .State )
661661 assert grads ['a' ]['0' ].value == 1.0
662- assert issubclass (grads ['a' ]['0' ]. type , nnx .Param )
662+ assert issubclass (type ( grads ['a' ]['0' ]) , nnx .Param )
663663 assert len (grads ) == 2
664664
665665 nnx .update (m , grads )
@@ -687,7 +687,7 @@ def f(m: Dict):
687687
688688 assert isinstance (grads , nnx .State )
689689 assert grads ['a' ]['1' ].value == 1.0
690- assert issubclass (grads ['a' ]['1' ]. type , nnx .BatchStat )
690+ assert issubclass (type ( grads ['a' ]['1' ]) , nnx .BatchStat )
691691 assert len (grads ) == 1
692692
693693 nnx .update (m , grads )
@@ -843,9 +843,9 @@ def f(m: dict):
843843
844844 assert m ['a' ][0 ] is m ['b' ]
845845 assert isinstance (grads , dict )
846- assert issubclass (grads ['a' ][0 ]. type , nnx .Variable )
846+ assert issubclass (type ( grads ['a' ][0 ]) , nnx .Variable )
847847 assert grads ['a' ][1 ].value == 1.0
848- assert issubclass (grads ['a' ][1 ]. type , nnx .Variable )
848+ assert issubclass (type ( grads ['a' ][1 ]) , nnx .Variable )
849849 assert len (jax .tree .leaves (grads )) == 2
850850
851851 jax .tree .map (
0 commit comments