@@ -268,7 +268,7 @@ def __treescope_repr__(self, path, subtree_renderer):
268268@dataclasses .dataclass (frozen = True , repr = False )
269269class VariableDef (reprlib .Representable , tp .Generic [Node ]):
270270 type : type [Node ]
271- index : int
271+ index : int # TODO(cgarciae): make Optional instead of using -1
272272 outer_index : int | None
273273 metadata : HashableMapping [str , tp .Any ]
274274
@@ -320,7 +320,11 @@ class NodeDef(tp.Generic[Node], reprlib.Representable):
320320 attributes : tuple [
321321 tuple [
322322 Key ,
323- NodeDef [tp .Any ] | VariableDef [tp .Any ] | NodeRef [tp .Any ] | Static [tp .Any ],
323+ NodeDef [tp .Any ]
324+ | VariableDef [tp .Any ]
325+ | NodeRef [tp .Any ]
326+ | Static [tp .Any ]
327+ | ArrayAttr ,
324328 ],
325329 ...,
326330 ]
@@ -387,6 +391,7 @@ def __treescope_repr__(self, path, subtree_renderer):
387391 subtree_renderer = subtree_renderer ,
388392 )
389393
394+ # TODO(cgarciae): remove this method
390395 def apply (
391396 self , state : GraphState , * states : GraphState
392397 ) -> ApplyCaller [tuple [GraphDef [Node ], GraphState ]]:
@@ -407,10 +412,16 @@ def _apply(
407412
408413jax .tree_util .register_static (NodeDef )
409414
415+ @dataclasses .dataclass (frozen = True , slots = True )
416+ class ArrayAttr :
417+ pass
418+
419+
420+ ARRAY_ATTR = ArrayAttr ()
421+
410422GraphDef = tp .Union [NodeDef [Node ], NodeRef [Node ], VariableDef [Node ]]
411423PureState = tuple [GraphDef [Node ], GraphState ]
412424
413-
414425@tp .overload
415426def flatten (
416427 node : Node ,
@@ -494,7 +505,7 @@ def flatten(
494505 if ref_index is None :
495506 ref_index = RefMap ()
496507
497- leaves : list [StateLeaf | Variable [tp .Any ]] = []
508+ leaves : list [StateLeaf | Variable [tp .Any ] | jax . Array | np . ndarray ] = []
498509 path : list [Key ] | None = [] if with_paths else None
499510 paths : list [PathParts ] | None = [] if with_paths else None
500511 node_impl = get_node_impl (node )
@@ -523,7 +534,7 @@ def _graph_flatten(
523534 path : list [Key ] | None ,
524535 ref_index : RefMap ,
525536 ref_outer_index : RefMap | None ,
526- leaves : list [StateLeaf | Variable [tp .Any ]],
537+ leaves : list [StateLeaf | Variable [tp .Any ] | jax . Array | np . ndarray ],
527538 paths : list [PathParts ] | None ,
528539 return_variables : bool ,
529540) -> NodeDef | NodeRef | VariableDef :
@@ -539,6 +550,7 @@ def _graph_flatten(
539550 index = len (ref_index )
540551 ref_index [node ] = index
541552 else :
553+ # TODO(cgarciae): use None instead of -1
542554 index = - 1
543555
544556 if is_variable :
@@ -565,7 +577,14 @@ def _graph_flatten(
565577 raise RuntimeError (f'Unsupported type: { type (node )} , this is a bug.' )
566578
567579 attributes : list [
568- tuple [Key , Static [tp .Any ] | NodeDef [tp .Any ] | VariableDef | NodeRef [tp .Any ]]
580+ tuple [
581+ Key ,
582+ Static [tp .Any ]
583+ | ArrayAttr
584+ | NodeDef [tp .Any ]
585+ | VariableDef
586+ | NodeRef [tp .Any ],
587+ ]
569588 ] = []
570589
571590 values , metadata = node_impl .flatten (node )
@@ -585,16 +604,12 @@ def _graph_flatten(
585604 return_variables ,
586605 )
587606 attributes .append ((key , nodedef ))
607+ elif isinstance (value , (jax .Array , np .ndarray )):
608+ if paths is not None :
609+ paths .append (tuple (path )) # type: ignore
610+ attributes .append ((key , ARRAY_ATTR ))
611+ leaves .append (value )
588612 else :
589- if isinstance (value , (jax .Array , np .ndarray )):
590- if path is not None :
591- path_str = '/' .join (map (str , path ))
592- raise ValueError (
593- f'Arrays leaves are not supported, at { path_str !r} : { value } '
594- )
595- else :
596- raise ValueError (f'Arrays leaves are not supported, found { value } ' )
597- # static_fields.append((key, value))
598613 attributes .append ((key , Static (value )))
599614
600615 if path is not None :
@@ -695,9 +710,7 @@ def _graph_fingerprint(
695710 append_fn (variable_index )
696711 for key_value in value ._var_metadata .items ():
697712 append_fn (key_value )
698- else :
699- if isinstance (value , (jax .Array , np .ndarray )):
700- raise ValueError (f'Arrays leaves are not supported: { value } ' )
713+ elif not isinstance (value , (jax .Array , np .ndarray )):
701714 append_fn (value )
702715
703716
@@ -961,6 +974,11 @@ def _get_children() -> list[tuple[Key, tp.Any]]:
961974 for key , value in nodedef .attributes :
962975 if type (value ) is Static :
963976 children .append ((key , value .value ))
977+ elif type (value ) is ArrayAttr :
978+ if not leaves :
979+ raise ValueError ('Not enough leaves to unflatten the graph' )
980+ array = leaves .popleft ()
981+ children .append ((key , array ))
964982 elif type (value ) is NodeRef :
965983 children .append ((key , index_ref [value .index ]))
966984 elif type (value ) is NodeDef :
@@ -1126,8 +1144,16 @@ def _update_variable(node: Variable, value):
11261144 raise ValueError (f'Expected a subgraph for { key !r} , but got: { value !r} ' )
11271145 _graph_update_dynamic (current_value , value )
11281146 else :
1129- # case 3: state leaf is being updated
1130- if not isinstance (current_value , Variable ):
1147+ if isinstance (current_value , jax .Array | np .ndarray ):
1148+ if isinstance (node_impl , PytreeNodeImpl ):
1149+ raise ValueError (
1150+ f'Cannot set key { key !r} on immutable node of '
1151+ f'type { type (node ).__name__ } '
1152+ )
1153+ node_impl .set_key (node , key , value )
1154+ continue
1155+ elif not isinstance (current_value , Variable ):
1156+ # case 3: state leaf is being updated
11311157 raise ValueError (
11321158 f'Trying to update a non-Variable attribute { key !r} with a Variable: '
11331159 f'{ value !r} '
@@ -1255,7 +1281,8 @@ def _cached_partial(f: tp.Callable[..., tp.Any], *cached_args):
12551281 cached_ref_index : RefMap = RefMap ()
12561282
12571283 def create_static_cache (x ):
1258- if is_graph_node (x ):
1284+ # TODO(cgarciae): support Array attribute updates for graph nodes
1285+ if is_graph_node (x ) or isinstance (x , Variable ):
12591286 graphdef , flat_state = flatten (
12601287 x , with_paths = True , return_variables = True , ref_index = original_ref_index
12611288 )
0 commit comments