2121from collections import namedtuple
2222from contextlib import contextmanager
2323from copy import deepcopy
24+ from enum import Enum
2425from functools import reduce
2526from typing import TYPE_CHECKING , Any , Callable , Tuple , Union
2627
4950from ...symbolic_shape .operators import SYMBOLIC_BINARY_OPS , SYMBOLIC_UNARY_OPS
5051from ...utils import (
5152 ENV_SOT_ALLOW_DYNAMIC_SHAPE ,
53+ NUMPY_API_SUPPORTED_DICT ,
5254 NameGenerator ,
5355 SIRToCodeMap ,
5456 SotUndefinedVar ,
8688 GlobalVariable ,
8789 ListVariable ,
8890 NullVariable ,
91+ NumpyArrayVariable ,
8992 PaddleLayerVariable ,
9093 ParameterVariable ,
9194 SymbolicVariable ,
99102if TYPE_CHECKING :
100103 import types
101104
105+ GraphNodeVariableType : TypeAlias = Union [
106+ TensorVariable , SymbolicVariable , NumpyArrayVariable
107+ ]
108+
102109
103110CompileGraphResult : TypeAlias = Tuple [
104111 Callable [..., Any ],
108115 OrderedSet [Union [TensorVariable , SymbolicVariable ]],
109116 ],
110117]
118+ GraphNodeVariableClasses = (
119+ TensorVariable ,
120+ SymbolicVariable ,
121+ NumpyArrayVariable ,
122+ )
111123
112124
113125def convert_to_meta (inputs : Any ):
@@ -116,7 +128,7 @@ def convert_to_meta(inputs: Any):
116128 """
117129
118130 def func (x ):
119- if isinstance (x , ( TensorVariable , SymbolicVariable ) ):
131+ if isinstance (x , GraphNodeVariableClasses ):
120132 return x .meta
121133 if isinstance (x , VariableBase ):
122134 return x .get_py_value ()
@@ -131,7 +143,7 @@ def convert_to_symbol(inputs: Any):
131143 """
132144
133145 def func (x ):
134- if isinstance (x , ( TensorVariable , SymbolicVariable ) ):
146+ if isinstance (x , GraphNodeVariableClasses ):
135147 return x .get_symbol ()
136148 if isinstance (x , VariableBase ):
137149 return x .get_py_value ()
@@ -155,7 +167,7 @@ def record_symbols(SIR, *args, **kwargs):
155167 non_params = set ()
156168
157169 def fn (value ):
158- if isinstance (value , ( TensorVariable , SymbolicVariable ) ):
170+ if isinstance (value , GraphNodeVariableClasses ):
159171 symbol_meta_map [value .get_symbol ()] = value .meta
160172 if isinstance (value , ParameterVariable ):
161173 params .add (value .get_symbol ())
@@ -190,6 +202,12 @@ def func(x):
190202 return map_variables (func , inputs , restore_variable = True )
191203
192204
205+ class APIType (Enum ):
206+ PADDLE = 0
207+ SYMBOLIC = 1
208+ NUMPY = 2
209+
210+
193211class VariableLoader :
194212 def __init__ (self , store_var_info , pycode_gen ):
195213 self ._store_var_info = store_var_info
@@ -541,7 +559,34 @@ def message_handler(*args, **kwargs):
541559 InferMetaCache (),
542560 self .sir_builder .call_API ,
543561 func ,
544- False ,
562+ APIType .PADDLE ,
563+ * args ,
564+ ** kwargs ,
565+ )
566+
567+ def call_numpy_api (
568+ self ,
569+ func : Callable [..., Any ],
570+ * args : VariableBase ,
571+ ** kwargs : VariableBase ,
572+ ):
573+ """
574+ Record Numpy API to SIR
575+
576+ Args:
577+ func: numpy api
578+ """
579+ assert func in NUMPY_API_SUPPORTED_DICT .values ()
580+ log (3 , f"call numpy.api : { func .__name__ } " , "\n " )
581+
582+ def message_handler (* args , ** kwargs ):
583+ return f"Call numpy api error: { func .__name__ } , may be not a operator api?"
584+
585+ return inner_error_default_handler (self .symbolic_call , message_handler )(
586+ InferMetaCache (),
587+ self .sir_builder .call_API ,
588+ func ,
589+ APIType .NUMPY ,
545590 * args ,
546591 ** kwargs ,
547592 )
@@ -562,7 +607,7 @@ def message_handler(*args, **kwargs):
562607 InferMetaCache (),
563608 self .sir_builder .call_API ,
564609 op ,
565- True ,
610+ APIType . SYMBOLIC ,
566611 * args ,
567612 ** kwargs ,
568613 )
@@ -584,7 +629,7 @@ def message_handler(*args, **kwargs):
584629 InferMetaCache (),
585630 self .sir_builder .call_METHOD ,
586631 method_name ,
587- False ,
632+ APIType . PADDLE ,
588633 * args ,
589634 ** kwargs ,
590635 )
@@ -619,7 +664,7 @@ def message_handler(*args, **kwargs):
619664 return f"Call paddle layer error: { layer } , may be not a valid paddle layer?"
620665
621666 return inner_error_default_handler (self .symbolic_call , message_handler )(
622- infer_meta_fn , compute_fn , layer , False , * args , ** kwargs
667+ infer_meta_fn , compute_fn , layer , APIType . PADDLE , * args , ** kwargs
623668 )
624669
625670 def call_ast (
@@ -653,7 +698,7 @@ def message_handler(*args, **kwargs):
653698 ast_infer_meta ,
654699 compute_fn ,
655700 static_function ,
656- False ,
701+ APIType . PADDLE ,
657702 * args ,
658703 ** kwargs ,
659704 )
@@ -662,7 +707,7 @@ def message_handler(*args, **kwargs):
662707 return None
663708
664709 def symbolic_call (
665- self , infer_meta_fn , compute_fn , func , is_symbolic_var , * args , ** kwargs
710+ self , infer_meta_fn , compute_fn , func , api_type , * args , ** kwargs
666711 ):
667712 """
668713 Using infer_meta_fn and compute_fn convert func to symbolic function.
@@ -763,11 +808,14 @@ def try_infer_meta_fn(args, kwargs) -> Any:
763808
764809 log (3 , f" inputs : { inputs_symbols } " , "\n " )
765810
766- if is_symbolic_var :
811+ if api_type == APIType . SYMBOLIC :
767812 var_cls = SymbolicVariable
768813 tracker = SymbolicOperationTracker (
769814 list (args ) + list (kwargs .values ()), func
770815 )
816+ elif api_type == APIType .NUMPY :
817+ var_cls = NumpyArrayVariable
818+ tracker = DummyTracker (list (args ) + list (kwargs .values ()))
771819 else :
772820 var_cls = TensorVariable
773821 tracker = DummyTracker (list (args ) + list (kwargs .values ()))
@@ -807,7 +855,7 @@ def try_infer_meta_fn(args, kwargs) -> Any:
807855 stmt_stacks ,
808856 ) # symbolic only contain symbols.
809857 self ._put_inner (outputs )
810- if is_symbolic_var :
858+ if api_type == APIType . SYMBOLIC :
811859 # compute_fn should be call_method
812860 tracker = SymbolicOperationTracker (
813861 list (args ) + list (kwargs .values ()), func
@@ -892,13 +940,13 @@ def remove_global_guarded_variable(self, variable: VariableBase):
892940
893941 def _find_tensor_inputs (
894942 self , input_names : list [str ]
895- ) -> OrderedSet [TensorVariable | SymbolicVariable ]:
896- inputs : OrderedSet [TensorVariable | SymbolicVariable ] = OrderedSet ()
943+ ) -> OrderedSet [GraphNodeVariableType ]:
944+ inputs : OrderedSet [GraphNodeVariableType ] = OrderedSet ()
897945 for name in input_names :
898946 found = False
899947 for variable in self .input_variables :
900948 if (
901- isinstance (variable , ( TensorVariable , SymbolicVariable ) )
949+ isinstance (variable , GraphNodeVariableClasses )
902950 and variable .get_symbol ().name == name
903951 ):
904952 inputs .add (variable )
@@ -908,30 +956,37 @@ def _find_tensor_inputs(
908956 assert len (inputs ) == len (input_names ), "Number of inputs not match."
909957 return inputs
910958
911- def gen_load_inputs (
912- self , inputs : OrderedSet [TensorVariable | SymbolicVariable ]
913- ):
959+ def gen_load_inputs (self , inputs : OrderedSet [GraphNodeVariableType ]):
914960 for input_var in inputs :
915- # For SymbolicVariable, we use paddle.full([], value, "int64")
916- # to convert it to a Tensor
917961 if isinstance (input_var , SymbolicVariable ):
962+ # For SymbolicVariable, we use paddle.full([], value, "int64")
963+ # to convert it to a Tensor
918964 self .pycode_gen .gen_load_object (
919965 paddle .full ,
920966 "___paddle_full" ,
921967 )
922968 self .pycode_gen .gen_build_list (0 )
923- input_var .tracker .gen_instructions (self .pycode_gen )
924- if isinstance (input_var , SymbolicVariable ):
969+ input_var .tracker .gen_instructions (self .pycode_gen )
925970 self .pycode_gen .gen_load_const ("int64" )
926971 self .pycode_gen .gen_call_function (3 )
972+ elif isinstance (input_var , NumpyArrayVariable ):
973+ # For NumpyArrayVariable, we use paddle.to_tensor(value) to convert it to a Tensor
974+ self .pycode_gen .gen_load_object (
975+ paddle .to_tensor ,
976+ "___paddle_to_tensor" ,
977+ )
978+ input_var .tracker .gen_instructions (self .pycode_gen )
979+ self .pycode_gen .gen_call_function (1 )
980+ else :
981+ input_var .tracker .gen_instructions (self .pycode_gen )
927982
928983 @staticmethod
929984 def _is_graph_output (
930985 var ,
931- ) -> TypeGuard [TensorVariable | SymbolicVariable ]:
986+ ) -> TypeGuard [GraphNodeVariableType ]:
932987 return isinstance (
933988 var .tracker , (DummyTracker , SymbolicOperationTracker )
934- ) and isinstance (var , ( TensorVariable , SymbolicVariable ) )
989+ ) and isinstance (var , GraphNodeVariableClasses )
935990
936991 @staticmethod
937992 def _collect_related_dummy_tensor (var ):
@@ -949,17 +1004,15 @@ def _collect_related_dummy_tensor(var):
9491004
9501005 def _find_tensor_outputs (
9511006 self , outputs : list [VariableBase ]
952- ) -> OrderedSet [TensorVariable | SymbolicVariable ]:
1007+ ) -> OrderedSet [GraphNodeVariableType ]:
9531008 """
9541009 Return all TensorVariable. find TensorVariables participating in networking from the output Variables
9551010
9561011 Args:
9571012 outputs: output variables
9581013 """
9591014
960- output_tensors : OrderedSet [TensorVariable | SymbolicVariable ] = (
961- OrderedSet ()
962- )
1015+ output_tensors : OrderedSet [GraphNodeVariableType ] = OrderedSet ()
9631016 # Find Tensor Variables from outputs.
9641017 for output in outputs :
9651018 if isinstance (
0 commit comments