3636 visualization ,
3737)
3838from flax import config
39- from flax .nnx .variablelib import Variable , is_mutable_array
39+ from flax .nnx .variablelib import Variable
4040from flax .typing import SizeBytes
4141
4242BUILDING_DOCS = 'FLAX_DOC_BUILD' in os .environ
@@ -104,6 +104,7 @@ def __init__(self):
104104 """
105105 return DataAttr (value ) # type: ignore[return-value]
106106
107+
107108def register_data_type (type_ : type , / ) -> None :
108109 """Registers a type as pytree data type recognized by Object.
109110
@@ -264,6 +265,7 @@ def __treescope_repr__(self, path, subtree_renderer):
264265 subtree_renderer = subtree_renderer ,
265266 )
266267
268+
267269def _flatten_object_state (state : ObjectState ):
268270 return (), (state .initializing , state .is_setup )
269271
@@ -279,6 +281,7 @@ def _unflatten_object_state(static: tuple[bool, bool], _):
279281 _unflatten_object_state ,
280282)
281283
284+
282285class ObjectMeta (ABCMeta ):
283286 if not tp .TYPE_CHECKING :
284287
@@ -291,9 +294,18 @@ def _object_meta_construct(cls, self, *args, **kwargs):
291294
292295def _graph_node_meta_call (cls : tp .Type [O ], * args , ** kwargs ) -> O :
293296 node = cls .__new__ (cls , * args , ** kwargs )
294- vars (node )['_object__state' ] = ObjectState ()
295- vars (node )['_object__nodes' ] = cls ._object__nodes
297+ vars_obj = vars (node )
298+ vars_obj ['_object__state' ] = ObjectState ()
299+ vars_obj ['_object__nodes' ] = cls ._object__nodes
296300 cls ._object_meta_construct (node , * args , ** kwargs )
301+ # register possible new data attributes after initialization
302+ for name , value in vars_obj .items ():
303+ if name not in vars_obj ['_object__nodes' ]:
304+ if any (
305+ is_data_type (leaf )
306+ for leaf in jax .tree .leaves (value , is_leaf = is_data_type )
307+ ):
308+ vars_obj ['_object__nodes' ] = vars_obj ['_object__nodes' ].union ((name ,))
297309
298310 return node
299311
@@ -312,6 +324,7 @@ def __nnx_repr__(self):
312324 yield reprlib .Attr ('shape' , self .shape )
313325 yield reprlib .Attr ('dtype' , self .dtype )
314326
327+
315328@dataclasses .dataclass (frozen = True , repr = False )
316329class MutableArrayRepr (reprlib .Representable ):
317330 shape : tp .Tuple [int , ...]
@@ -326,6 +339,7 @@ def __nnx_repr__(self):
326339 yield reprlib .Attr ('shape' , self .shape )
327340 yield reprlib .Attr ('dtype' , self .dtype )
328341
342+
329343class Object (reprlib .Representable , metaclass = ObjectMeta ):
330344 """Base class for all NNX objects."""
331345
@@ -335,7 +349,7 @@ class Object(reprlib.Representable, metaclass=ObjectMeta):
335349 _object__state : ObjectState
336350
337351 def __init_subclass__ (
338- cls , * , pytree : bool = config .flax_mutable_array , ** kwargs
352+ cls , * , pytree : bool = config .flax_pytree_module , ** kwargs
339353 ) -> None :
340354 super ().__init_subclass__ (** kwargs )
341355
@@ -387,20 +401,12 @@ def _setattr(self, name: str, value: tp.Any) -> None:
387401 value = value .value
388402 if name not in self ._object__nodes :
389403 self ._object__nodes = self ._object__nodes .union ((name ,))
390- elif is_data_type (value ):
391- if name not in self ._object__nodes :
392- self ._object__nodes = self ._object__nodes .union ((name ,))
393- elif type (self )._object__is_pytree and name not in self ._object__nodes :
394- for leaf in jax .tree .leaves (value ):
395- if isinstance (leaf , jax .Array ) or is_mutable_array (leaf ):
396- raise TypeError (
397- f"Trying to set '{ name } ' to a value containing one or more "
398- f"jax.Array, but '{ name } ' is not a registered as data. "
399- f"Use 'obj.{ name } = nnx.data(...)' to register the attribute as data "
400- f"on assignment, or add '{ name } : nnx.Data[{ type (value ).__name__ } ]' "
401- f'to the class definition. '
402- f'Got value: { value } '
403- )
404+ # any attribute that contains known data types will be registered as data
405+ elif name not in self ._object__nodes and any (
406+ is_data_type (leaf )
407+ for leaf in jax .tree .leaves (value , is_leaf = is_data_type )
408+ ):
409+ self ._object__nodes = self ._object__nodes .union ((name ,))
404410 object .__setattr__ (self , name , value )
405411
406412 def _check_valid_context (self , error_msg : tp .Callable [[], str ]) -> None :
0 commit comments