diff --git a/python/paddle/jit/sot/infer_meta.py b/python/paddle/jit/sot/infer_meta.py
index ea2cbff8b1cc1b..7f90468bdf4b09 100644
--- a/python/paddle/jit/sot/infer_meta.py
+++ b/python/paddle/jit/sot/infer_meta.py
@@ -261,14 +261,39 @@ def infer_meta_for_layer(layer, *args, **kwargs):
     ) = layer.forward.get_concrete_program(*args_, **kwargs_)
 
     out = partial_program_layer._restore_out(
-        paddle.utils.flatten(
-            convert_variable_to_meta_info(concrete_program.outputs)
-        )
+        [
+            x
+            for x in paddle.utils.flatten(
+                convert_variable_to_meta_info(concrete_program.outputs)
+            )
+            if isinstance(x, MetaInfo)
+        ]
     )
     layer.forward.rollback()
     return out
 
 
+def ast_infer_meta(static_function, *args, **kwargs):
+    args_, kwargs_ = convert_meta_to_input_spec((args, kwargs))
+
+    (
+        concrete_program,
+        partial_program_layer,
+    ) = static_function.get_concrete_program(*args_, **kwargs_)
+
+    out = partial_program_layer._restore_out(
+        [
+            x
+            for x in paddle.utils.flatten(
+                convert_variable_to_meta_info(concrete_program.outputs)
+            )
+            if isinstance(x, MetaInfo)
+        ]
+    )
+
+    return out
+
+
 @Singleton
 class SpecialInferMeta:
     """
diff --git a/python/paddle/jit/sot/opcode_translator/__init__.py b/python/paddle/jit/sot/opcode_translator/__init__.py
index 64fda66a2747d8..dec41c8bba1721 100644
--- a/python/paddle/jit/sot/opcode_translator/__init__.py
+++ b/python/paddle/jit/sot/opcode_translator/__init__.py
@@ -13,6 +13,6 @@
 # limitations under the License.
 
 from .skip_files import setup_skip_files
-from .transform import eval_frame_callback  # noqa: F401
+from .eval_frame_callback import eval_frame_callback  # noqa: F401
 
 setup_skip_files()
diff --git a/python/paddle/jit/sot/opcode_translator/transform.py b/python/paddle/jit/sot/opcode_translator/eval_frame_callback.py
similarity index 90%
rename from python/paddle/jit/sot/opcode_translator/transform.py
rename to python/paddle/jit/sot/opcode_translator/eval_frame_callback.py
index 4f6ad8e43e90c5..d454bb43aa035e 100644
--- a/python/paddle/jit/sot/opcode_translator/transform.py
+++ b/python/paddle/jit/sot/opcode_translator/eval_frame_callback.py
@@ -58,7 +58,9 @@ def eval_frame_callback(frame, **kwargs) -> CustomCode:
         )
 
         log_do(4, partial(print_locals, frame))
-        log_format(3, "[transform] OriginCode: {}\n", frame.f_code.co_name)
+        log_format(
+            3, "[eval_frame_callback] OriginCode: {}\n", frame.f_code.co_name
+        )
         log_do(3, lambda: dis.dis(frame.f_code))
 
         custom_code = OpcodeExecutorCache()(frame, **kwargs)
@@ -66,13 +68,13 @@ def eval_frame_callback(frame, **kwargs) -> CustomCode:
         if custom_code.code is None:
             log_format(
                 3,
-                "[transform] NewCode (same as origin code): {}\n",
+                "[eval_frame_callback] NewCode (same as origin code): {}\n",
                 frame.f_code.co_name,
             )
         else:
             log_format(
                 3,
-                "[transform] NewCode: {}\n",
+                "[eval_frame_callback] NewCode: {}\n",
                 custom_code.code.co_name,
             )
             log_do(3, lambda: dis.dis(custom_code.code))
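Note: both `infer_meta_for_layer` and the new `ast_infer_meta` now filter the flattened outputs down to `MetaInfo` entries before calling `_restore_out`. A minimal sketch of that filter-after-flatten pattern, with simplified stand-ins for `MetaInfo` and `paddle.utils.flatten`:

    class MetaInfo:  # stand-in for paddle.jit.sot.infer_meta.MetaInfo
        pass

    def flatten(nested):  # simplified stand-in for paddle.utils.flatten
        for item in nested:
            if isinstance(item, (list, tuple)):
                yield from flatten(item)
            else:
                yield item

    outputs = [MetaInfo(), None, [MetaInfo(), "stop_gradient"]]
    # Non-tensor entries (e.g. None for unused outputs) would otherwise leak
    # into _restore_out; keep only the real metas.
    metas = [x for x in flatten(outputs) if isinstance(x, MetaInfo)]
    assert len(metas) == 2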
diff --git a/python/paddle/jit/sot/opcode_translator/executor/function_graph.py b/python/paddle/jit/sot/opcode_translator/executor/function_graph.py
index a2fb85734c7be4..9135e0cbdfa604 100644
--- a/python/paddle/jit/sot/opcode_translator/executor/function_graph.py
+++ b/python/paddle/jit/sot/opcode_translator/executor/function_graph.py
@@ -24,7 +24,12 @@
 from functools import cached_property
 from typing import Any, Callable
 
-from ...infer_meta import InferMetaCache, LayerInferMetaCache, MetaInfo
+from ...infer_meta import (
+    InferMetaCache,
+    LayerInferMetaCache,
+    MetaInfo,
+    ast_infer_meta,
+)
 from ...profiler import EventGuard, event_register
 from ...symbolic.statement_ir import Reference, Symbol
 from ...symbolic.symbolic_context import SymbolicTraceContext
@@ -56,6 +61,7 @@
 )
 from .tracker import BuiltinTracker, DummyTracker
 from .variables import (
+    ConstantVariable,
     DictVariable,
     GlobalVariable,
     ListVariable,
@@ -99,6 +105,18 @@ def func(x):
     return map_variables(func, inputs)
 
 
+def get_symbol_meta_map(inputs):
+    output = {}
+
+    def func(x):
+        if isinstance(x, TensorVariable):
+            output[x.get_symbol()] = x.meta
+        return x
+
+    map_variables(func, inputs)
+    return output
+
+
 class FunctionGraph:
     """
     A Graph representation corresponding to each FunctionFrame
@@ -129,7 +147,6 @@ def __init__(self, frame, **kwargs):
         self._global_guarded_variables: OrderedSet[VariableBase] = OrderedSet()
         self._print_variables = []
         self._inplace_tensors = OrderedSet()
-        self.build_strategy = kwargs.get('build_strategy', None)
         self._kwargs = kwargs
 
     @cached_property
@@ -291,7 +308,7 @@ def load(self, var):
 
         return VariableLoader(store_var_info, self.pycode_gen)
 
-    def _build_compile_fn_with_name_store(self, ret_vars, to_store_vars):
+    def _build_compile_fn_with_name_store(self, to_store_vars):
         class VariableLoader:
             def __init__(self, index_for_load, pycode_gen):
                 self._index_for_load = index_for_load
@@ -308,7 +325,7 @@ def load(self, var, allow_push_null=True):
         to_store_vars = list(
             filter(lambda x: not isinstance(x, NullVariable), to_store_vars)
         )
-        self.start_compile(*(ret_vars + to_store_vars))
+        self.start_compile(*to_store_vars)
 
         name_gen = NameGenerator("__start_compile_saved_")
 
         for var in to_store_vars[::-1]:
@@ -326,6 +343,22 @@ def _log_fn():
 
         return VariableLoader(index_for_load, self.pycode_gen)
 
+    def get_compiled_fn(self, *ret_vars):
+        ret_items = [
+            ret_item
+            for ret_var in ret_vars
+            for ret_item in ret_var.flatten_items()
+        ]
+
+        tensor_items = self._find_tensor_outputs(ret_items)
+
+        compiled_fn, _ = self.sir_ctx.compile_fn(
+            [Symbol(tensor_var.var_name) for tensor_var in tensor_items],
+            **self._kwargs,
+        )
+
+        return compiled_fn
+
     @event_register("start_compile", event_level=2)
     def start_compile(self, *ret_vars: VariableBase):
         """
@@ -440,36 +473,6 @@ def message_handler(*args, **kwargs):
             **kwargs,
         )
 
-    @staticmethod
-    def get_opcode_executor_stack():
-        # NOTE: only for debug.
-        # dependent on OpcodeExecutor.
-        from .opcode_executor import OpcodeExecutorBase
-
-        if len(OpcodeExecutorBase.call_stack) == 0:
-            # In test case, we can meet this senario.
-            return []
-        current_executor = OpcodeExecutorBase.call_stack[-1]
-        current_line = current_executor._current_line
-        filename = current_executor._code.co_filename
-        source_lines, start_line = inspect.getsourcelines(
-            current_executor._code
-        )
-        # TODO(SigureMo): In 3.11, lineno maybe changed after multiple breakgraph,
-        # We need to find a way to fix this.
-        line_idx = min(current_line - start_line, len(source_lines) - 1)
-        code_line = source_lines[line_idx]
-        stack = []
-        stack.append(
-            '  File "{}", line {}, in {}'.format(
-                filename,
-                current_line,
-                current_executor._code.co_name,
-            )
-        )
-        stack.append(f'    {code_line}')
-        return stack
-
     def call_layer(
         self,
         layer: PaddleLayerVariable,
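Note: `get_symbol_meta_map` (added above) relies on `map_variables` visiting every leaf variable, with the visitor recording one symbol-to-meta entry per tensor. A rough self-contained sketch of that contract, using stand-ins rather than the real `TensorVariable`/`map_variables`:

    def map_variables(fn, inputs):  # simplified stand-in
        for x in inputs:
            if isinstance(x, (list, tuple)):
                map_variables(fn, x)
            else:
                fn(x)

    class TensorVar:  # stand-in for TensorVariable
        def __init__(self, symbol, meta):
            self.symbol, self.meta = symbol, meta

    def get_symbol_meta_map(inputs):
        output = {}

        def visit(x):
            if isinstance(x, TensorVar):
                output[x.symbol] = x.meta

        map_variables(visit, inputs)
        return output

    print(get_symbol_meta_map([TensorVar("var_0", "m0"), 1, [TensorVar("var_1", "m1")]]))
    # {'var_0': 'm0', 'var_1': 'm1'}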
@@ -503,14 +506,46 @@ def message_handler(*args, **kwargs):
             infer_meta_fn, compute_fn, layer, *args, **kwargs
         )
 
+    def call_ast(
+        self,
+        static_function: tuple,
+        *args: VariableBase,
+        **kwargs: VariableBase,
+    ):
+        """
+        Call an AST-transformed static function, start symbolic trace.
+
+        Args:
+            static_function: the static function to call
+        """
+
+        def compute_fn(static_function, inputs, outputs, stacks):
+            self.sir_ctx.call_AST(
+                static_function,
+                inputs=inputs,
+                outputs=outputs,
+                stacks=stacks,
+            )
+
+        def message_handler(*args, **kwargs):
+            return "Call ast failed"
+
+        try:
+            return inner_error_default_handler(
+                self.symbolic_call, message_handler
+            )(ast_infer_meta, compute_fn, static_function, *args, **kwargs)
+        except Exception as e:
+            log(3, f"[call AST] {e}")
+            return None
+
     def symbolic_call(self, infer_meta_fn, compute_fn, func, *args, **kwargs):
         """
         Using infer_meta_fn and compute_fn convert func to symbolic function.
 
         Args:
             infer_meta_fn: function for infer meta, (func, metas, kwmetas) -> output_metas
-            compute_fn   : function for sir compile, (func, input_symbols, outputs_symbols) -> None
-            func         : symbolic function
+            compute_fn   : function that adds a statement to the SIR, (func, input_symbols, outputs_symbols, stacks) -> None
+            func         : the logical function which will be represented as a statement
         """
         self.collect_input_variables(list(args))
         self.collect_input_variables(list(kwargs.values()))
@@ -522,6 +557,10 @@ def symbolic_call(self, infer_meta_fn, compute_fn, func, *args, **kwargs):
             convert_to_symbol(args),
             convert_to_symbol(kwargs),
         )
+
+        self.sir_ctx.TOS.set_symbol_meta_map(get_symbol_meta_map(args))
+        self.sir_ctx.TOS.set_symbol_meta_map(get_symbol_meta_map(kwargs))
+
         log(3, f"         inputs : {inputs_symbols}", "\n")
 
         outputs = map_if(
@@ -564,7 +603,37 @@ def symbolic_call(self, infer_meta_fn, compute_fn, func, *args, **kwargs):
                 outputs, self, DummyTracker(list(args) + list(kwargs.values()))
             )
         else:
-            return None
+            return ConstantVariable.wrap_literal(None, self)
+
+    @staticmethod
+    def get_opcode_executor_stack():
+        # NOTE: only for debug.
+        # dependent on OpcodeExecutor.
+        from .opcode_executor import OpcodeExecutorBase
+
+        if len(OpcodeExecutorBase.call_stack) == 0:
+            # In a test case, we can meet this scenario.
+            return []
+        current_executor = OpcodeExecutorBase.call_stack[-1]
+        current_line = current_executor._current_line
+        filename = current_executor._code.co_filename
+        source_lines, start_line = inspect.getsourcelines(
+            current_executor._code
+        )
+        # TODO(SigureMo): In 3.11, lineno may be changed after multiple breakgraph,
+        # We need to find a way to fix this.
+        line_idx = max(min(current_line - start_line, len(source_lines) - 1), 0)
+        code_line = source_lines[line_idx]
+        stack = []
+        stack.append(
+            '  File "{}", line {}, in {}'.format(
+                filename,
+                current_line,
+                current_executor._code.co_name,
+            )
+        )
+        stack.append(f'    {code_line}')
+        return stack
 
     def _put_inner(self, vars: VariableBase):
         """
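Note: `call_ast` above is deliberately an optimistic fast path: any exception is logged and reported as `None`, so callers fall back to bytecode simulation. A small sketch of that shape (the names here are illustrative, not the real API):

    def run_with_fallback(ast_path, bytecode_path, *args):
        try:
            result = ast_path(*args)
        except Exception as e:  # mirrors the broad except in call_ast
            print(f"[call AST] {e}")
            result = None
        if result is None:  # AST path declined or failed
            result = bytecode_path(*args)
        return result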
diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py
index fb3c49b526577d..da9c21eeabe61c 100644
--- a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py
+++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py
@@ -36,6 +36,7 @@
     InnerError,
     OrderedSet,
     SotUndefinedVar,
+    get_static_function,
     log,
     log_do,
 )
@@ -180,7 +181,7 @@ def pop_jump_if_op_wrapper(fns: list[Callable[[Any], Any]]):
     """
 
-    @jump_break_graph_decorator
+    @if_break_graph_decorator
    def inner(self: OpcodeExecutorBase, instr: Instruction):
        """
        Inner function that represents the wrapped POP_JUMP_IF opcode operation.
@@ -214,7 +215,7 @@ def inner(self: OpcodeExecutorBase, instr: Instruction):
     return inner
 
 
-def jump_break_graph_decorator(normal_jump: Callable):
+def if_break_graph_decorator(normal_jump: Callable):
     """
     A decorator function that breaks off the graph when a JUMP-related instruction is encountered.
 
@@ -231,8 +232,8 @@ def inner(self: OpcodeExecutor, instr: Instruction):
         if isinstance(result, TensorVariable):
             # fallback when in OpcodeExecutor
             # raise error in OpcodeInlineExecutor
-            log(3, "[BreakGraph] jump break graph, because if tensor\n")
-            self._break_graph_in_jump(result, instr)
+            log(3, "[BreakGraph] break graph for `if` on a tensor\n")
+            self._break_graph_when_if(result, instr)
             return Stop(state="BreakGraph")
         else:
             return normal_jump(self, instr)
@@ -265,7 +266,7 @@ def wrapper(self: OpcodeExecutor, instr: Instruction):
             )
             if isinstance(self, OpcodeExecutor):
                 log(3, f"[BreakGraph] call function Break graph: {e}\n")
-                self._break_graph_in_call(origin_stack, instr, push_n)
+                self._break_graph_when_call(origin_stack, instr, push_n)
                 return Stop(state="BreakGraph")
             else:
                 raise e
@@ -390,7 +391,7 @@ def _prepare_virtual_env(self):
         """
         raise NotImplementedError("Please implement virtual_env.")
 
-    def _break_graph_in_jump(self, result, instr: Instruction):
+    def _break_graph_when_if(self, result, instr: Instruction):
         """
         Breaks the graph in JUMP instructions.
@@ -512,7 +513,7 @@ def run(self):
         """
         Executes the opcode.
         """
-        log(3, f"start execute opcode: {self._code}\n")
+        log(3, f"[EXECUTOR RUN] Start execute opcode: {self._code}\n")
         self._lasti = 0
         while True:
             if self._lasti >= len(self._instructions):
@@ -524,6 +525,7 @@ def run(self):
                 self.stop_state = is_stop.state
                 self.pop_call_stack_until_self()
                 break
+        log(3, f"[EXECUTOR RUN] End execute opcode: {self._code}\n")
 
     def step(self, instr: Instruction):
         """
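Note: the rename to `if_break_graph_decorator` matches what the decorator actually does: only a tensor predicate forces a graph break, everything else takes the normal jump. A toy, self-contained sketch of that dispatch:

    class Tensorish:  # stand-in for TensorVariable
        pass

    def if_break_graph(normal_jump):
        def inner(stack, instr):
            if isinstance(stack[-1], Tensorish):
                return "BreakGraph"  # data-dependent `if`: cannot be simulated
            return normal_jump(stack, instr)
        return inner

    @if_break_graph
    def jump(stack, instr):
        return "Jump"

    assert jump([Tensorish()], "POP_JUMP_IF_FALSE") == "BreakGraph"
    assert jump([True], "POP_JUMP_IF_FALSE") == "Jump"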
@@ -1255,7 +1257,7 @@ def CONTAINS_OP(self, instr: Instruction):
             )(left, right)
         )
 
-    @jump_break_graph_decorator
+    @if_break_graph_decorator
     def JUMP_IF_FALSE_OR_POP(self, instr: Instruction):
         pred_obj = self.stack.top
         if isinstance(pred_obj, (ConstantVariable, ContainerVariable)):
@@ -1271,7 +1273,7 @@ def JUMP_IF_FALSE_OR_POP(self, instr: Instruction):
             "Currently don't support predicate a non-const / non-tensor obj."
         )
 
-    @jump_break_graph_decorator
+    @if_break_graph_decorator
     def JUMP_IF_TRUE_OR_POP(self, instr: Instruction):
         pred_obj = self.stack.top
         if isinstance(pred_obj, (ConstantVariable, ContainerVariable)):
@@ -1535,22 +1537,25 @@ def gen_compute_in_break_with_name_store(self, restore_names, instr_idx):
             instr_idx: the index for branch 1 to find the boundary and copy origin opcode
         """
-        if self._graph.sir_ctx.TOS.graph_size() < ENV_MIN_GRAPH_SIZE.get():
-            store_var_info = {}
-            for name in restore_names:
-                _var = self.get_var(name)
-                if _var not in self.stack:
-                    store_var_info[_var.id] = name
+        # To get the compiled function without running the AST transform twice,
+        # we must pass get_compiled_fn exactly the same return values as start_compile.
+        store_vars = list(self.stack)
+        store_var_info = {}
+
+        for name in restore_names:
+            _var = self.get_var(name)
+            if _var not in self.stack:
+                store_vars.append(_var)
+                store_var_info[_var.id] = name
+
+        compile_fn = self._graph.get_compiled_fn(*store_vars)
+
+        if compile_fn.graph_size() < ENV_MIN_GRAPH_SIZE.get():
             return self._graph._restore_origin_opcode(
                 list(self.stack), store_var_info, instr_idx
             )
         else:
-            store_vars = list(self.stack)
-            for name in restore_names:
-                _var = self.get_var(name)
-                if _var not in self.stack:
-                    store_vars.append(_var)
-            return self._graph._build_compile_fn_with_name_store([], store_vars)
+            return self._graph._build_compile_fn_with_name_store(store_vars)
 
     def _create_resume_fn(self, index, stack_size):
         """
@@ -1569,7 +1574,7 @@ def _create_resume_fn(self, index, stack_size):
         return fn, inputs
 
     @fallback_when_occur_error
-    def _break_graph_in_jump(self, result: TensorVariable, instr: Instruction):
+    def _break_graph_when_if(self, result: TensorVariable, instr: Instruction):
         """
         Break the graph at a JUMP instruction.
 
@@ -1644,7 +1649,7 @@ def _break_graph_in_jump(self, result: TensorVariable, instr: Instruction):
         self.guard_fn = self._graph.guard_fn
 
     @fallback_when_occur_error
-    def _break_graph_in_call(
+    def _break_graph_when_call(
         self,
         origin_stack: VariableStack,
         instr: Instruction,
@@ -1708,6 +1713,22 @@ def _break_graph_in_call(
         self.guard_fn = self._graph.guard_fn
 
     def transform(self):
+        static_function = get_static_function(self._frame, "eval_frame")
+        if static_function is not None:
+            code = self._frame.f_code
+            inputs = []
+            for i in range(code.co_argcount):
+                arg_name = code.co_varnames[i]
+                value = self._locals[arg_name]
+                inputs.append(value)
+            output = self._graph.call_ast(static_function, *inputs)
+            if output is not None:
+                self.stack.push(output)
+                self.RETURN_VALUE(None)
+                return (
+                    CustomCode(self.new_code, self.new_code is None),
+                    self.guard_fn,
+                )
         self.run()
         if self.new_code is self.empty_code:
             raise InnerError("OpExecutor return a empty new_code.")
@@ -1769,7 +1790,7 @@ def _gen_loop_body_between(
         return pycode_gen.create_fn_with_inputs(inputs)
 
     @fallback_when_occur_error
-    def _break_graph_in_for_loop(
+    def _break_graph_when_for_loop(
        self, iterator: VariableBase, for_iter: Instruction
    ):
        '''
@@ -2040,7 +2061,7 @@ def FOR_ITER(self, instr):
             iterator.idx = backup_iter_idx
         self._graph.remove_global_guarded_variable(iterator)
         self.stack.push(iterator)
-        self._break_graph_in_for_loop(iterator, instr)
+        self._break_graph_when_for_loop(iterator, instr)
         return Stop(state="BreakGraph")
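Note: the new fast path in `transform()` gathers the frame's positional arguments by index before handing them to `call_ast`. The lookup itself is plain CPython introspection, roughly:

    import inspect

    def collect_positional_args(frame):
        # co_varnames starts with the positional parameters, in order.
        code = frame.f_code
        return [frame.f_locals[code.co_varnames[i]] for i in range(code.co_argcount)]

    def demo(a, b, c=3):
        return collect_positional_args(inspect.currentframe())

    print(demo(1, 2))  # [1, 2, 3]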
@@ -2048,7 +2069,8 @@ def RETURN_VALUE(self, instr: Instruction):
         assert (
             len(self.stack) == 1
         ), f"Stack must have one element, but get {len(self.stack)} elements."
         ret_val = self.stack.pop()
-        if self._graph.sir_ctx.TOS.graph_size() < ENV_MIN_GRAPH_SIZE.get():
+        compile_fn = self._graph.get_compiled_fn(ret_val)
+        if compile_fn.graph_size() < ENV_MIN_GRAPH_SIZE.get():
             self.new_code = None
         else:
             self._graph.start_compile(ret_val)
diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py
index c24e94b07ffb26..9d6488dc4447a3 100644
--- a/python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py
+++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py
@@ -283,7 +283,7 @@ def RETURN_VALUE(self, instr: Instruction):
         self.return_value = self.stack.pop()
         return Stop(state="Return")
 
-    def _break_graph_in_jump(self, result, instr: Instruction):
+    def _break_graph_when_if(self, result, instr: Instruction):
         """
         Helper method to raise a BreakGraphError when breaking the graph in a jump operation.
 
@@ -292,7 +292,7 @@ def _break_graph_in_jump(self, result, instr: Instruction):
             instr (Instruction): The jump instruction.
         """
         raise BreakGraphError(
-            "OpcodeInlineExecutor want call _break_graph_in_jump."
+            "OpcodeInlineExecutor wants to break graph when simulating `if`."
         )
 
     def _create_resume_fn(self, index: int, stack_size: int = 0):
diff --git a/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py b/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py
index 1e28a9402b6ab6..4edf14e5ca0d96 100644
--- a/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py
+++ b/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py
@@ -25,6 +25,7 @@
 from .... import psdb
 from ....profiler import EventGuard
 from ....utils import (
+    get_static_function,
     is_break_graph_api,
     is_break_graph_tensor_methods,
     is_builtin_fn,
@@ -177,6 +178,13 @@ def call_function(self, /, *args, **kwargs) -> VariableBase:
             return result
 
         checkpoint = self.graph.save_memo()
+
+        static_function = get_static_function(self.value, "inline_call")
+        if static_function is not None:
+            output = self.graph.call_ast(static_function, *args, **kwargs)
+            if output is not None:
+                return output
+
         try:
             inline_executor = OpcodeInlineExecutor(self, *args, **kwargs)
             with EventGuard(
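Note: `RETURN_VALUE` above now asks the compiled function for a real op count instead of counting SIR statements, and still skips compilation below `ENV_MIN_GRAPH_SIZE`. The gate reduces to the following (a trivial sketch; the threshold value is illustrative):

    MIN_GRAPH_SIZE = 10  # stand-in for ENV_MIN_GRAPH_SIZE.get()

    def should_compile(op_count: int) -> bool:
        # -1 (the multi-block case) also skips compilation under this rule.
        return op_count >= MIN_GRAPH_SIZE

    assert not should_compile(3)   # too small: restore the origin opcode
    assert should_compile(42)      # worth compiling the subgraph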
diff --git a/python/paddle/jit/sot/symbolic/compile_cache.py b/python/paddle/jit/sot/symbolic/compile_cache.py
index 465de3f6adf505..90cdf1bc366997 100644
--- a/python/paddle/jit/sot/symbolic/compile_cache.py
+++ b/python/paddle/jit/sot/symbolic/compile_cache.py
@@ -20,8 +20,9 @@
 import paddle
 from paddle.amp.auto_cast import amp_state
 from paddle.base.data_feeder import convert_dtype
-from paddle.framework import _dygraph_tracer
+from paddle.framework import _dygraph_tracer, use_pir_api
 
+from ..infer_meta import convert_meta_to_input_spec
 from ..profiler import EventGuard
 from ..utils import (
     Cache,
@@ -86,6 +87,23 @@ def amp_cast_inputs(self, args, kwargs):
             false_fn=lambda x: x,
         )
 
+    def graph_size(self):
+        if self.partial_program is None:
+            input_spec = convert_meta_to_input_spec(
+                [self.SIR.symbol_meta_map[symbol] for symbol in self.SIR.inputs]
+            )
+            (
+                self.concrete_program,
+                self.partial_program,
+            ) = self.compiled_fn.get_concrete_program(input_spec)
+            self.partial_program.training = self.is_training
+        if use_pir_api():
+            return len(self.partial_program.program.program.global_block().ops)
+        else:
+            if self.partial_program.program.num_blocks > 1:
+                return -1
+            return len(self.partial_program.program.block(0).ops)
+
     def __call__(self, *args, **kwargs):
         with EventGuard(f"FallbackWrapper: {self.SIR.name}"):
             if StepInfoManager().need_back_trace:
diff --git a/python/paddle/jit/sot/symbolic/interpreter.py b/python/paddle/jit/sot/symbolic/interpreter.py
index ac243e98ec41fd..ec49ecaec39a68 100644
--- a/python/paddle/jit/sot/symbolic/interpreter.py
+++ b/python/paddle/jit/sot/symbolic/interpreter.py
@@ -155,6 +155,10 @@ def layer(self, stmt, inputs):
         assert layer is not None, "SIR bound layer is None."
         return layer(*args, **kwargs)
 
+    def AST(self, stmt, inputs):
+        args, kwargs = inputs
+        return stmt.converted_func(*args, **kwargs)
+
 
 def compile_sir(context: SymbolicTraceContext, name: str):
     """
diff --git a/python/paddle/jit/sot/symbolic/statement_ir.py b/python/paddle/jit/sot/symbolic/statement_ir.py
index 1e0ab465e0bd84..edf2ab4aed16d9 100644
--- a/python/paddle/jit/sot/symbolic/statement_ir.py
+++ b/python/paddle/jit/sot/symbolic/statement_ir.py
@@ -19,10 +19,12 @@
 """
 from __future__ import annotations
 
+import functools
 import weakref
 from typing import Any, Callable
 
-from paddle.utils import is_sequence, map_structure
+import paddle
+from paddle.utils import flatten, map_structure
 
 from ..utils import NameGenerator, OrderedSet, Singleton, flatten_extend
 
@@ -85,7 +87,7 @@ def __init__(
         outputs: list[Symbol],
         stacks: list[str],
     ):
-        assert type in ["call", "api", "method", "layer"]
+        assert type in ["call", "api", "method", "layer", "AST"]
         self.name = name
         self.inputs = inputs  # (list of Symbols, dict of Symbols)
         self.outputs = outputs  # list of Symbol | PythonObj
@@ -96,9 +98,9 @@ def __init__(
 
     def __str__(self):
         def to_string(inps):
-            if isinstance(inps, str) or not is_sequence(inps):
-                return inps.__str__()
-            inps = (x.__str__() for x in inps)
+            inps = [x.__str__() for x in flatten(inps) if isinstance(x, Symbol)]
+            if len(inps) == 0:
+                return "(Empty)"
             return ", ".join(inps)
 
         return "{} || {} = {} ({}) ".format(
@@ -158,12 +160,44 @@ def __init__(
         outputs: list[Symbol],
         stacks: list[str],
     ):
+        if isinstance(layer, Reference):
+            name = layer().__class__.__name__
+        else:
+            name = layer.__class__.__name__
         super().__init__(
-            "layer", layer.__class__.__name__, inputs, outputs, stacks
+            "layer",
+            name,
+            inputs,
+            outputs,
+            stacks,
         )
         self.layer = layer
 
 
+class ASTStatement(Statement):
+    def __init__(
+        self,
+        static_function,
+        inputs: list[Symbol],
+        outputs: list[Symbol],
+        stacks: list[str],
+    ):
+        # this dygraph_function always has the attr `__code__`, which is checked beforehand
+        dygraph_func = static_function.dygraph_function
+        super().__init__(
+            "AST",
+            dygraph_func.__code__.co_name,
+            inputs,
+            outputs,
+            stacks,
+        )
+        converted_func = paddle.jit.dy2static.convert_to_static(dygraph_func)
+        func_self = getattr(dygraph_func, '__self__', None)
+        if func_self is not None:
+            converted_func = functools.partial(converted_func, func_self)
+        self.converted_func = converted_func
+
+
 class StatementIR:
     """
     StatementIR is the carrier that records the code for building the neural network model.It is
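Note: `ASTStatement` re-binds a method's instance by hand because `convert_to_static` is applied to the plain function. The `functools.partial` trick is standard Python; a self-contained sketch (using `__func__` as a stand-in for the converted function):

    import functools

    class Net:
        def forward(self, x):
            return x * 2

    bound = Net().forward
    func_self = getattr(bound, "__self__", None)  # the Net instance
    plain = bound.__func__                        # stand-in for convert_to_static(...)
    if func_self is not None:
        plain = functools.partial(plain, func_self)  # pin `self` as the first arg
    print(plain(21))  # 42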
@@ -181,6 +215,8 @@ def __init__(self, name: str):
         self.outputs = []  # list of Symbol | PythonObj
         self.statements = []  # list of Statement
 
+        self.symbol_meta_map = {}
+
     def __len__(self):
         return len(self.statements)
 
@@ -189,8 +225,14 @@ def __deepcopy__(self, memo=None):
         new_sir.inputs = list(self.inputs)
         new_sir.outputs = list(self.outputs)
         new_sir.statements = list(self.statements)
+        new_sir.symbol_meta_map = dict(self.symbol_meta_map.items())
         return new_sir
 
+    def set_symbol_meta_map(self, meta_map):
+        # If the meta of an input symbol is changed inplace, we should keep
+        # the origin meta as the input of the SIR.
+        meta_map.update(self.symbol_meta_map)
+        self.symbol_meta_map = meta_map
+
     def add_input(self, input):
         self.inputs.append(input)
 
@@ -230,10 +272,6 @@ def __str__(self):
     def __repr__(self):
         return self.__str__()
 
-    def graph_size(self):
-        call_layers = [x for x in self.statements if x.type == "layer"]
-        return len(self.statements) + len(call_layers)
-
 
 @Singleton
 class StatementIRFactory:
diff --git a/python/paddle/jit/sot/symbolic/symbolic_context.py b/python/paddle/jit/sot/symbolic/symbolic_context.py
index 47f40bbcc9ec74..931586645149ad 100644
--- a/python/paddle/jit/sot/symbolic/symbolic_context.py
+++ b/python/paddle/jit/sot/symbolic/symbolic_context.py
@@ -18,6 +18,7 @@
 from .compile_cache import CompileSIRCache
 from .statement_ir import (
     ApiStatement,
+    ASTStatement,
     CallStatement,
     LayerStatement,
     MethodStatement,
@@ -69,7 +70,6 @@ def call_API(self, api, inputs, outputs, stacks):
         """
         Call a paddle api.
         """
-        assert callable(api), "call_API must receive a paddle api."
         stmt = ApiStatement(api, inputs, outputs, stacks)
         self.TOS.add_statement(stmt)
 
@@ -94,6 +94,10 @@ def call_LAYER(self, layer, inputs, outputs, stacks):
         stmt = LayerStatement(layer, inputs, outputs, stacks)
         self.TOS.add_statement(stmt)
 
+    def call_AST(self, static_function, inputs, outputs, stacks):
+        stmt = ASTStatement(static_function, inputs, outputs, stacks)
+        self.TOS.add_statement(stmt)
+
     def get_sir(self, name: str):
         """
         Get a SIR from statement_factory.
@@ -130,14 +134,18 @@ def compile_do_nothing(self, ret_vals):
             ret_vals (list[Symbol]): the return values of the function.
         """
 
-        def dummy_func(*args, **kwargs):
-            return []
+        class DummyFunc:
+            def __call__(self, *args, **kwargs):
+                return []
+
+            def graph_size(self):
+                return 0
 
         # return None function
         dummy_stmt_ir = StatementIR("dummy_func")
         dummy_stmt_ir.outputs = []
         dummy_stmt_ir.inputs = []
-        return dummy_func, dummy_stmt_ir
+        return DummyFunc(), dummy_stmt_ir
 
     def compile_fn(self, ret_vals, **kwargs):
         """
diff --git a/python/paddle/jit/sot/utils/__init__.py b/python/paddle/jit/sot/utils/__init__.py
index 307ef1c21b8008..16e2cd5b1afe52 100644
--- a/python/paddle/jit/sot/utils/__init__.py
+++ b/python/paddle/jit/sot/utils/__init__.py
@@ -19,9 +19,11 @@
     ENV_SHOW_TRACKERS,
     ENV_SOT_LOG_LEVEL,
     ENV_STRICT_MODE,
+    ENV_SOT_WITH_CONTROL_FLOW,
     cost_model_guard,
     min_graph_size_guard,
     strict_mode_guard,
+    with_control_flow_guard,
 )
 from .exceptions import (  # noqa: F401
     BreakGraphError,
@@ -50,6 +52,7 @@
     current_tmp_name_records,
     execute_time,
     flatten_extend,
+    flatten,
     get_unbound_method,
     hashable,
     in_paddle_module,
@@ -69,3 +72,4 @@
     no_eval_frame,
     tmp_name_guard,
 )
+from .call_ast_utils import get_static_function, try_ast_func
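Note: the merge order in `set_symbol_meta_map` above is the subtle part: the incoming map is updated with the existing entries, so the first-recorded meta of a symbol survives later inplace changes. Illustrated on plain dicts:

    def merge_symbol_meta(existing, incoming):
        # Mirrors StatementIR.set_symbol_meta_map: existing entries win.
        incoming.update(existing)
        return incoming

    existing = {"var_0": "float32[2, 3]"}                    # meta at first use
    incoming = {"var_0": "float32[6]", "var_1": "int64[4]"}  # var_0 reshaped inplace
    print(merge_symbol_meta(existing, incoming))
    # {'var_0': 'float32[2, 3]', 'var_1': 'int64[4]'}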
diff --git a/python/paddle/jit/sot/utils/call_ast_utils.py b/python/paddle/jit/sot/utils/call_ast_utils.py
new file mode 100644
index 00000000000000..612334287b0a56
--- /dev/null
+++ b/python/paddle/jit/sot/utils/call_ast_utils.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import types
+
+import paddle
+
+from .envs import ENV_SOT_WITH_CONTROL_FLOW
+from .exceptions import InnerError
+from .utils import Singleton
+
+try_ast_codes = set()
+
+
+def try_ast_func(func):
+    def _is_wrapped(f):
+        return hasattr(f, '__wrapped__')
+
+    unwrapped_f = func
+    if hasattr(unwrapped_f, "__code__"):
+        try_ast_codes.add(func.__code__)
+
+    while _is_wrapped(unwrapped_f):
+        unwrapped_f = unwrapped_f.__wrapped__
+        if hasattr(unwrapped_f, "__code__"):
+            try_ast_codes.add(unwrapped_f.__code__)
+
+    return func
+
+
+@Singleton
+class StaticFunctionManager:
+    def __init__(self):
+        self.code_map = {}
+
+    def ast_transform_with_frame(self, frame):
+        code = frame.f_code
+        if code not in try_ast_codes:
+            return None
+        if code not in self.code_map:
+            if code.co_name.startswith("#") or code.co_name.startswith("$"):
+                self.code_map[code] = None
+            elif len(code.co_cellvars) + len(code.co_freevars) != 0:
+                self.code_map[code] = None
+            else:
+                function = types.FunctionType(
+                    code,
+                    frame.f_globals,
+                    code.co_name,
+                    (),
+                    (),
+                )
+                function = paddle.jit.to_static(function, full_graph=True)
+                self.code_map[code] = function
+
+        return self.code_map[code]
+
+    def ast_transform_with_callable(self, fn):
+        if not inspect.isfunction(fn) or not hasattr(fn, "__code__"):
+            return None
+
+        code = fn.__code__
+        if code not in try_ast_codes:
+            return None
+        if code not in self.code_map:
+            if code.co_name.startswith("#") or code.co_name.startswith("$"):
+                self.code_map[code] = None
+            elif len(code.co_cellvars) + len(code.co_freevars) != 0:
+                self.code_map[code] = None
+            else:
+                self.code_map[code] = paddle.jit.to_static(fn, full_graph=True)
+
+        return self.code_map[code]
+
+
+def get_static_function(obj, type_):
+    if ENV_SOT_WITH_CONTROL_FLOW.get():
+        if type_ == "eval_frame":
+            return StaticFunctionManager().ast_transform_with_frame(obj)
+        elif type_ == "inline_call":
+            return StaticFunctionManager().ast_transform_with_callable(obj)
+        else:
+            raise InnerError(f"Can not get static function with type {type_}.")
+    return None
diff --git a/python/paddle/jit/sot/utils/envs.py b/python/paddle/jit/sot/utils/envs.py
index a7d8ceafb7f0cb..bc6879664890ef 100644
--- a/python/paddle/jit/sot/utils/envs.py
+++ b/python/paddle/jit/sot/utils/envs.py
@@ -29,6 +29,9 @@
 ENV_STRICT_MODE = BooleanEnvironmentVariable("STRICT_MODE", False)
 ENV_SHOW_TRACKERS = StringEnvironmentVariable("SHOW_TRACKERS", "")
 ENV_CLEAN_CODE = BooleanEnvironmentVariable("CLEAN_CODE", False)
+ENV_SOT_WITH_CONTROL_FLOW = BooleanEnvironmentVariable(
+    "SOT_WITH_CONTROL_FLOW", True
+)
 
 
 @contextmanager
@@ -47,3 +50,9 @@ def strict_mode_guard(value: bool):
 def min_graph_size_guard(value: int):
     with EnvironmentVariableGuard(ENV_MIN_GRAPH_SIZE, value):
         yield
+
+
+@contextmanager
+def with_control_flow_guard(value: bool):
+    with EnvironmentVariableGuard(ENV_SOT_WITH_CONTROL_FLOW, value):
+        yield
diff --git a/test/custom_runtime/test_custom_cpu_to_static.py b/test/custom_runtime/test_custom_cpu_to_static.py
index 60ba27004afbdd..9de01f378d71a7 100644
--- a/test/custom_runtime/test_custom_cpu_to_static.py
+++ b/test/custom_runtime/test_custom_cpu_to_static.py
@@ -164,7 +164,9 @@ def forward(self, x):
 
     # convert to static model
     build_strategy = paddle.static.BuildStrategy()
-    mnist = paddle.jit.to_static(model, build_strategy=build_strategy)
+    mnist = paddle.jit.to_static(
+        model, build_strategy=build_strategy, full_graph=True
+    )
 
     # data loader
     transform = paddle.vision.transforms.Compose(
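Note: `with_control_flow_guard` follows the same pattern as the existing `strict_mode_guard`; since it is a `@contextmanager`, it works both as a `with` block and (as the tests below use it) as a decorator. Expected usage, assuming the export added above:

    from paddle.jit.sot.utils import with_control_flow_guard

    with with_control_flow_guard(False):
        # SOT_WITH_CONTROL_FLOW is False here: the AST fast path is disabled
        # and tensor `if`s take the break-graph route instead.
        ...
    # the previous value is restored on exit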
diff --git a/test/sot/test_03_tuple.py b/test/sot/test_03_tuple.py
index 797d54384714d0..d0db1d100a42ce 100644
--- a/test/sot/test_03_tuple.py
+++ b/test/sot/test_03_tuple.py
@@ -24,6 +24,7 @@
 
 import paddle
 from paddle.jit.sot.psdb import check_no_breakgraph
+from paddle.jit.sot.utils import with_control_flow_guard
 
 
 @check_no_breakgraph
@@ -80,6 +81,7 @@ def test_tuple_methods_int(self):
         self.assert_results(tuple_count_int, 1, paddle.to_tensor(2))
         self.assert_results(tuple_index_int, 1, paddle.to_tensor(2))
 
+    @with_control_flow_guard(False)
     def test_tuple_methods_tensor(self):
         a = paddle.to_tensor(1)
         b = paddle.to_tensor(2)
diff --git a/test/sot/test_call_ast.py b/test/sot/test_call_ast.py
new file mode 100644
index 00000000000000..e893af485e4f1a
--- /dev/null
+++ b/test/sot/test_call_ast.py
@@ -0,0 +1,61 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from test_case_base import TestCaseBase
+
+import paddle
+from paddle.jit.sot.utils import try_ast_func, with_control_flow_guard
+
+
+@try_ast_func
+def calc(x, y, z):
+    if x < 5:
+        a = x + y
+        b = y - z
+        c = a * b
+        return c
+    else:
+        a = x - y
+        b = y + z
+        c = a * b
+        return c
+
+
+def inline_call_ast(x, y):
+    a = x - y + 3
+    b = x + y
+    c = x * y
+    z = calc(a, b, c)
+    return z + a
+
+
+class TestCallAst(TestCaseBase):
+    @with_control_flow_guard(True)
+    def test_full_graph_ast(self):
+        x = paddle.to_tensor([2])
+        y = paddle.to_tensor([3])
+        z = paddle.to_tensor([4])
+        self.assert_results(calc, x, y, z)
+
+    @with_control_flow_guard(True)
+    def test_inline_ast(self):
+        x = paddle.to_tensor([2])
+        y = paddle.to_tensor([3])
+        self.assert_results(inline_call_ast, x, y)
+
+
+if __name__ == "__main__":
+    unittest.main()
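Note: the new test relies on dy2static turning the tensor-dependent `if` in `calc` into a single static graph. A minimal standalone sketch of that behavior (assuming only that Paddle is installed):

    import paddle

    def branchy(x):
        if x < 5:  # data-dependent control flow
            return x + 1
        return x - 1

    static_branchy = paddle.jit.to_static(branchy, full_graph=True)
    print(static_branchy(paddle.to_tensor([2])))  # Tensor with value [3]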