diff --git a/python/paddle/jit/sot/opcode_translator/executor/function_graph.py b/python/paddle/jit/sot/opcode_translator/executor/function_graph.py index 8f87e19cd4d288..dc57b252e00c2d 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/function_graph.py +++ b/python/paddle/jit/sot/opcode_translator/executor/function_graph.py @@ -39,6 +39,7 @@ from ...utils import ( ENV_SHOW_TRACKERS, NameGenerator, + SotUndefinedVar, inner_error_default_handler, is_inplace_api, is_paddle_api, @@ -140,6 +141,20 @@ def get_params_and_non_param_symbol(*args, **kwargs): return params, non_params +class VariableLoader: + def __init__(self, store_var_info, pycode_gen): + self._store_var_info = store_var_info + self._pycode_gen: PyCodeGen = pycode_gen + + def load(self, var): + if var is SotUndefinedVar(): + self._pycode_gen.gen_load_const(SotUndefinedVar()) + elif isinstance(var, NullVariable): + var.reconstruct(self._pycode_gen) + else: + self._pycode_gen.gen_load(self._store_var_info[var.id]) + + class FunctionGraph: """ A Graph representation corresponding to each FunctionFrame @@ -281,17 +296,6 @@ def guard_fn(self) -> Guard: return make_guard(guards) def _restore_origin_opcode(self, stack_vars, store_var_info, instr_idx): - class VariableLoader: - def __init__(self, store_var_info, pycode_gen): - self._store_var_info = store_var_info - self._pycode_gen: PyCodeGen = pycode_gen - - def load(self, var): - if isinstance(var, NullVariable): - var.reconstruct(self._pycode_gen) - return - self._pycode_gen.gen_load(self._store_var_info[var.id]) - origin_instrs = get_instructions(self.pycode_gen._origin_code) is_precall = origin_instrs[instr_idx].opname == "PRECALL" current_idx = instr_idx @@ -308,7 +312,7 @@ def load(self, var): restore_instr_names = restore_instr_names[:-1] self.pycode_gen.extend_instrs(restore_instrs) - nop = self.pycode_gen._add_instr("NOP") + nop = self.pycode_gen.add_instr("NOP") for instr in origin_instrs: if instr.jump_to == origin_instrs[current_idx]: @@ -324,26 +328,21 @@ def load(self, var): name_gen = NameGenerator("__start_compile_saved_orig_") + # here is not update changed values, it just give names to stack vars + # and want keep same interface as _build_compile_fn_with_name_store for var in stack_vars[::-1]: - store_var_info[var.id] = name_gen.next() - self.pycode_gen.gen_store_fast(store_var_info[var.id]) + if store_var_info[var.id] is None: + store_var_info[var.id] = name_gen.next() + self.pycode_gen.gen_store_fast(store_var_info[var.id]) + else: + self.pycode_gen.gen_store( + store_var_info[var.id], self.pycode_gen._origin_code + ) return VariableLoader(store_var_info, self.pycode_gen) - def _build_compile_fn_with_name_store(self, to_store_vars): - class VariableLoader: - def __init__(self, index_for_load, pycode_gen): - self._index_for_load = index_for_load - self._pycode_gen: PyCodeGen = pycode_gen - - def load(self, var, allow_push_null=True): - if isinstance(var, NullVariable): - var.reconstruct(self._pycode_gen) - return - self._pycode_gen.gen_load(self._index_for_load[var.id]) - + def _build_compile_fn_with_name_store(self, to_store_vars, store_var_info): # var_id -> local_name mapping - index_for_load = {} to_store_vars = list( filter(lambda x: not isinstance(x, NullVariable), to_store_vars) ) @@ -351,19 +350,15 @@ def load(self, var, allow_push_null=True): name_gen = NameGenerator("__start_compile_saved_") for var in to_store_vars[::-1]: - index_for_load[var.id] = name_gen.next() - - def _log_fn(): - print( - f"[StartCompile] saved var: {index_for_load[var.id]} = ", - var, + if store_var_info[var.id] is None: + store_var_info[var.id] = name_gen.next() + self.pycode_gen.gen_store_fast(store_var_info[var.id]) + else: + self.pycode_gen.gen_store( + store_var_info[var.id], self.pycode_gen._origin_code ) - log_do(4, _log_fn) - - self.pycode_gen.gen_store_fast(index_for_load[var.id]) - - return VariableLoader(index_for_load, self.pycode_gen) + return VariableLoader(store_var_info, self.pycode_gen) def get_compiled_fn(self, *ret_vars): ret_items = [ diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py index e9a985e5b728c7..e0ada6a9b74fa9 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py +++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py @@ -27,8 +27,6 @@ import opcode -from paddle.jit.utils import OrderedSet - from ...profiler import EventGuard, event_register from ...psdb import NO_BREAKGRAPH_CODES from ...utils import ( @@ -45,8 +43,7 @@ from ..instruction_utils import ( Instruction, Space, - analysis_inputs, - analysis_used_names_with_space, + analysis_used_names, calc_stack_effect, get_instructions, ) @@ -416,7 +413,34 @@ def transform(self): """ raise NotImplementedError() - def get_var(self, name: str): + def find_space_of_var_name(self, name): + code = self._graph.pycode_gen._origin_code + if name in (code.co_freevars + code.co_cellvars): + return Space.cells + elif name in code.co_varnames: + return Space.locals + elif name in code.co_names: + return Space.globals + else: + return Space.not_found + + def has_var(self, name: str): + space = self.find_space_of_var_name(name) + + if space == Space.locals: + return name in self._locals + elif space == Space.cells: + return name in self._cells + elif space == Space.globals: + return name in set( + chain( + self._globals.keys(), + self._builtins.keys(), + ) + ) + return False + + def get_var(self, name: str, allow_undefined=False): """ Gets the variable with the given name. @@ -438,31 +462,27 @@ def get_var(self, name: str): return self._globals.get(name) elif name in self._builtins.keys(): return self._builtins[name] + elif allow_undefined: + return SotUndefinedVar() else: raise InnerError(f'Can not get var: {name}') - def has_var(self, name: str, space: str = "any"): - if space == "any": - return name in set( - chain( - self._locals.keys(), - self._cells.keys(), - self._globals.keys(), - self._builtins.keys(), - ) - ) - elif space == Space.locals: - return name in self._locals + def set_var(self, name: str, value: VariableBase): + space = self.find_space_of_var_name(name) + + # if name is new created, we always place it to locals + if space in (Space.locals, Space.not_found): + self._locals[name] = value elif space == Space.cells: - return name in self._cells + self._cells[name].set_value(value) elif space == Space.globals: - return name in set( - chain( - self._globals.keys(), - self._builtins.keys(), - ) - ) - return False + self._globals[name] = value + + def _find_names_in_space(self, names, space): + target_names = [ + name for name in names if self.find_space_of_var_name(name) in space + ] + return target_names def pop_call_stack_until_self(self): """ @@ -1511,6 +1531,31 @@ def __init__(self, frame: types.FrameType, **kwargs): super().__init__(frame.f_code, graph) Dispatcher.graph = graph + def transform(self): + static_function = get_static_function(self._frame, "eval_frame") + if static_function is not None: + code = self._frame.f_code + inputs = [] + for i in range(code.co_argcount): + arg_name = code.co_varnames[i] + value = self._locals[arg_name] + inputs.append(value) + output = self._graph.call_ast(static_function, *inputs) + if output is not None: + self.stack.push(output) + self.RETURN_VALUE(None) + return ( + CustomCode(self.new_code, self.new_code is None), + self.guard_fn, + ) + self.run() + if self.new_code is self.empty_code: + raise InnerError("OpExecutor return a empty new_code.") + return ( + CustomCode(self.new_code, self.new_code is None), + self.guard_fn, + ) + def cleanup(self): self._graph.pycode_gen = None Dispatcher.graph = None @@ -1560,56 +1605,99 @@ def _prepare_virtual_env(self): ) ) - def gen_compute_in_break_with_name_store(self, restore_names, instr_idx): + def FOR_ITER(self, instr): + iterator = self.stack.pop() + backup_iter_idx = None + + start = self.indexof(instr) + end = self.indexof(instr.jump_to) + for i in range(start, end): + if self._instructions[i].opname == "RETURN_VALUE": + raise FallbackError("Found RETURN_VALUE in for loop body.") + + self._graph.add_global_guarded_variable(iterator) + + try: + if not isinstance(iterator, SequenceIterVariable): + raise BreakGraphError( + f"Can not simulate iterator of {type(iterator)}." + ) + + backup_iter_idx = iterator.idx + + self._inline_call_for_loop(iterator, instr) + self._lasti = self.indexof(instr.jump_to) + next_instr = self._instructions[self._lasti] + self._lasti += int(next_instr.opname == 'END_FOR') + except BreakGraphError as e: + log(3, f"[BreakGraph] FOR_ITER sim for loop failed for: {e}\n") + if backup_iter_idx: + iterator.idx = backup_iter_idx + self._graph.remove_global_guarded_variable(iterator) + self.stack.push(iterator) + self._break_graph_when_for_loop(iterator, instr) + return Stop(state="BreakGraph") + + def RETURN_VALUE(self, instr: Instruction): + assert ( + len(self.stack) == 1 + ), f"Stack must have one element, but get {len(self.stack)} elements." + ret_val = self.stack.pop() + return self.compile_return(ret_val) + + def RETURN_CONST(self, instr: Instruction): + ret_const = self._co_consts[instr.arg] + return self.compile_return(ret_const) + + def compile_return(self, ret_val): + compile_fn = self._graph.get_compiled_fn(ret_val) + if compile_fn.graph_size() < ENV_MIN_GRAPH_SIZE.get(): + self.new_code = None + else: + self._graph.start_compile(ret_val) + self._graph.pycode_gen.gen_return() + self.new_code = self._graph.pycode_gen.gen_pycode() + self.guard_fn = self._graph.guard_fn + return Stop(state="Return") + + def get_compute_fn_and_update_changed_vars( + self, restore_names, stack, end_idx + ): """ - branch 1: if the graph size is too small, just run in dygraph - branch 2: if the graph is big enough, create compiled_fn - - This api will generator opcodes in different situation, the generated codes - will do the same thing as origin code. - - restore_names: - the names used in resume functions, branch 2 will restore these values, - branch 1 also need these names for generating opcode, but they are not - needed to be restored - instr_idx: - the index for branch 1 to find the boundary and copy origin opcode + this function will: + 1. add opcodes to self._graph.pycode_gen, which do the same thing as origin code. + 2. update the value of whom would be changed in generated codes + + This api will generator opcodes in different situation, + branch 1: if the graph size is too small, just run in dygraph. + branch 2: if the graph is big enough, create compiled_fn. + + Params: + restore_names: the names used in resume functions. + end_idx: instruction index where simulation get break. + stack: current stack """ - # if we want get compiled fn, and do not do ast twice, - # we must give retval to get_compiled_fn which strictly same as start_compile - store_vars = list(self.stack) - store_var_info = {} + store_vars = list(stack) + store_var_info = {var.id: None for var in stack} for name in restore_names: - _var = self.get_var(name) - if _var not in self.stack: + _var = self.get_var(name, allow_undefined=True) + if _var is SotUndefinedVar(): + continue + if _var not in stack: store_vars.append(_var) - store_var_info[_var.id] = name + store_var_info[_var.id] = name compile_fn = self._graph.get_compiled_fn(*store_vars) if compile_fn.graph_size() < ENV_MIN_GRAPH_SIZE.get(): return self._graph._restore_origin_opcode( - list(self.stack), store_var_info, instr_idx + list(stack), store_var_info, end_idx ) else: - return self._graph._build_compile_fn_with_name_store(store_vars) - - def _create_resume_fn(self, index, stack_size): - """ - Create a resume function and its inputs at the specified index. - - Args: - index: The index at which the resume function is created. - stack_size: The size of the stack. - - Returns: - The resume function and its inputs. - - """ - pycode_gen = PyCodeGen(self._frame) - fn, inputs = pycode_gen.gen_resume_fn_at(index, stack_size) - return fn, inputs + return self._graph._build_compile_fn_with_name_store( + store_vars, store_var_info + ) @fallback_when_occur_error def _break_graph_when_if(self, result: TensorVariable, instr: Instruction): @@ -1622,66 +1710,105 @@ def _break_graph_when_if(self, result: TensorVariable, instr: Instruction): """ self._graph.add_global_guarded_variable(result) - # minus the bool value - stack_size = len(self.stack) - 1 - # gen call static fn opcode - if_fn, if_inputs = self._create_resume_fn( - self.indexof(instr) + 1, stack_size + # 1. analyse info + cur_index = self.indexof(instr) + true_fn_start_index = cur_index + 1 + false_fn_start_index = self.indexof(instr.jump_to) + stack_size_after_if = len(self.stack) - 1 + + # 2. create true_fn and false_fn + def create_if_branch_fn(start_idx, input_var_names): + if self._instructions[start_idx].opname == "RETURN_VALUE": + return None + pycode_gen = PyCodeGen(self._frame) + origin_instrs = get_instructions(pycode_gen._origin_code) + pycode_gen.set_function_inputs( + input_var_names, stack_size=stack_size_after_if + ) + pycode_gen.extend_instrs(origin_instrs[start_idx:]) + # the resume_fn contains return code, so we don't need set output here + # global vars are updated correctly, and need local vars will return + resume_fn = pycode_gen.create_function() + return resume_fn + + true_fn_read_names, _ = analysis_used_names( + self._instructions, self.indexof(instr) + 1 + ) + true_fn_input_var_names = self._find_names_in_space( + true_fn_read_names, (Space.locals, Space.cells) + ) + + true_fn = create_if_branch_fn( + start_idx=true_fn_start_index, + input_var_names=true_fn_input_var_names, + ) + + false_fn_read_names, _ = analysis_used_names( + self._instructions, self.indexof(instr.jump_to) + ) + false_fn_input_var_names = self._find_names_in_space( + false_fn_read_names, (Space.locals, Space.cells) ) - else_fn, else_inputs = self._create_resume_fn( - self.indexof(instr.jump_to), stack_size + + false_fn = create_if_branch_fn( + start_idx=false_fn_start_index, + input_var_names=false_fn_input_var_names, ) - inputs_names = if_inputs | else_inputs + # 4. setup vars which is created in loop as Undefind + for name in true_fn_input_var_names[:-1]: + if not self.has_var(name): + self._graph.pycode_gen.gen_load_const(SotUndefinedVar()) + self._graph.pycode_gen.gen_store(name, self._code) + for name in false_fn_input_var_names: + if not self.has_var(name): + self._graph.pycode_gen.gen_load_const(SotUndefinedVar()) + self._graph.pycode_gen.gen_store(name, self._code) - var_loader = self.gen_compute_in_break_with_name_store( - inputs_names, self.indexof(instr) + # 4. compile codes before if + update_var_names = list(true_fn_read_names | false_fn_read_names) + var_loader = self.get_compute_fn_and_update_changed_vars( + update_var_names, self.stack, cur_index ) + # 5. create if sturcture and call true_fn and false_fn var_loader.load(result) - # the result is used by if opcode, and should not be input of resume_fn - self.stack.pop() + if_code = self._graph.pycode_gen.add_instr(instr.opname) - # gen call if/else resume fn opcode - if if_fn is not None: - self._graph.pycode_gen.gen_load_object( - if_fn, if_fn.__code__.co_name - ) - insert_index = len(self._graph.pycode_gen._instructions) - 1 - for i, stack_arg in enumerate(self.stack): - var_loader.load(stack_arg) - for name in if_inputs: - var_loader.load(self.get_var(name)) - self._graph.pycode_gen.gen_call_function( - argc=if_fn.__code__.co_argcount, - ) - self._graph.pycode_gen.gen_return() - else: - insert_index = len(self._graph.pycode_gen._instructions) - 1 - self._graph.pycode_gen.gen_return() + assert true_fn is not None - if else_fn is not None: - self._graph.pycode_gen.gen_load_object( - else_fn, else_fn.__code__.co_name + self._graph.pycode_gen.gen_load_object( + true_fn, true_fn.__code__.co_name + ) + for stack_arg in list(self.stack)[:-1]: + var_loader.load(stack_arg) + + for name in true_fn_input_var_names: + var_loader.load(self.get_var(name, allow_undefined=True)) + + self._graph.pycode_gen.gen_call_function( + argc=true_fn.__code__.co_argcount, + ) + self._graph.pycode_gen.gen_return() + + if false_fn is not None: + false_start_code = self._graph.pycode_gen.gen_load_object( + false_fn, false_fn.__code__.co_name ) - jump_to = self._graph.pycode_gen._instructions[-1] - for i, stack_arg in enumerate(self.stack): + for stack_arg in list(self.stack)[:-1]: var_loader.load(stack_arg) - for name in else_inputs: - var_loader.load(self.get_var(name)) + for name in false_fn_input_var_names: + var_loader.load(self.get_var(name, allow_undefined=True)) + self._graph.pycode_gen.gen_call_function( - argc=else_fn.__code__.co_argcount, + argc=false_fn.__code__.co_argcount, ) self._graph.pycode_gen.gen_return() else: - self._graph.pycode_gen.gen_return() - jump_to = self._graph.pycode_gen._instructions[-1] + false_start_code = self._graph.pycode_gen.gen_return() - # gen jump opcode - self._graph.pycode_gen._insert_instr( - insert_index, instr.opname, jump_to=jump_to - ) + if_code.jump_to = false_start_code self.new_code = self._graph.pycode_gen.gen_pycode() self.guard_fn = self._graph.guard_fn @@ -1702,41 +1829,60 @@ def _break_graph_when_call( push_n: The number of elements to be pushed onto the stack. """ + self.stack = origin_stack + + # 1. collect infomations push_n = push_n(instr.arg) if callable(push_n) else push_n is_precall = instr.opname == "PRECALL" - index = self.indexof(instr) + cur_index = self.indexof(instr) # Use CALL instead of PRECALL to calculate the real stack effect - call_instr = self._instructions[index + int(is_precall)] + call_instr = self._instructions[cur_index + int(is_precall)] # skip CALL if current instr is PRECALL - next_index = index + 1 + int(is_precall) - self.stack = origin_stack - - # gen call static fn opcode + next_index = cur_index + 1 + int(is_precall) + stack_effect = calc_stack_effect(call_instr) + pop_n = push_n - stack_effect + stack_size_after_call = len(self.stack) - pop_n + push_n - resume_input_name = analysis_inputs(self._instructions, next_index) + # 2. create resume function + read_names, _ = analysis_used_names(self._instructions, next_index) - var_loader = self.gen_compute_in_break_with_name_store( - resume_input_name, index + input_var_names = self._find_names_in_space( + read_names, (Space.locals, Space.cells) ) - # gen graph break call fn opcode - stack_effect = calc_stack_effect(call_instr) - pop_n = push_n - stack_effect + def create_resume_fn(): + if self._instructions[next_index].opname == "RETURN_VALUE": + return None + pycode_gen = PyCodeGen(self._frame) + origin_instrs = get_instructions(pycode_gen._origin_code) + pycode_gen.set_function_inputs( + input_var_names, stack_size=stack_size_after_call + ) + pycode_gen.extend_instrs(origin_instrs[next_index:]) + # the resume_fn contains return code, so we don't need set output here + # global vars are updated correctly, and need local vars will return + resume_fn = pycode_gen.create_function() + return resume_fn - for i, stack_arg in enumerate(self.stack): + resume_fn = create_resume_fn() + + # 3. compile sub graph before call + var_loader = self.get_compute_fn_and_update_changed_vars( + read_names, self.stack, cur_index + ) + + # 4. recover stack + for stack_arg in self.stack: var_loader.load(stack_arg) - # gen call resume fn opcode + # 5. run the break CALL with origin python # NOTE(SigureMo): In Python 3.11,we need generate KW_NAMES if the call shape is not None. self._graph.pycode_gen.gen_kw_names(self._call_shape) self._graph.pycode_gen.extend_instrs( - self._instructions[index:next_index] + self._instructions[cur_index:next_index] ) - self.stack.pop_n(pop_n) - stack_size = len(self.stack) + push_n - - resume_fn, _ = self._create_resume_fn(next_index, stack_size) + # 6. run resume fn if resume_fn: self._graph.pycode_gen.gen_load_object( resume_fn, resume_fn.__code__.co_name @@ -1744,9 +1890,11 @@ def _break_graph_when_call( # NOTE(zrr1999): We need to shift the resume_fn under its arguments. # In Python 3.11+, NULL + resume_fn should be shifted together. shift_n = 2 if sys.version_info >= (3, 11) else 1 - self._graph.pycode_gen.gen_shift_n(shift_n, stack_size + shift_n) - for name in resume_input_name: - var_loader.load(self.get_var(name)) + self._graph.pycode_gen.gen_shift_n( + shift_n, stack_size_after_call + shift_n + ) + for name in input_var_names: + var_loader.load(self.get_var(name, allow_undefined=True)) self._graph.pycode_gen.gen_call_function( argc=resume_fn.__code__.co_argcount, ) @@ -1757,112 +1905,14 @@ def _break_graph_when_call( self.new_code = self._graph.pycode_gen.gen_pycode() self.guard_fn = self._graph.guard_fn - def transform(self): - static_function = get_static_function(self._frame, "eval_frame") - if static_function is not None: - code = self._frame.f_code - inputs = [] - for i in range(code.co_argcount): - arg_name = code.co_varnames[i] - value = self._locals[arg_name] - inputs.append(value) - output = self._graph.call_ast(static_function, *inputs) - if output is not None: - self.stack.push(output) - self.RETURN_VALUE(None) - return ( - CustomCode(self.new_code, self.new_code is None), - self.guard_fn, - ) - self.run() - if self.new_code is self.empty_code: - raise InnerError("OpExecutor return a empty new_code.") - return ( - CustomCode(self.new_code, self.new_code is None), - self.guard_fn, - ) - - def _gen_loop_body_between( - self, inputs: list, for_iter_idx: int, start: int, end: int - ) -> types.FunctionType: - """ - Generates the loop body between the specified indices in the instruction list. - - Args: - inputs: function inputs infos - for_iter_idx (int): For find the for_iter opcode - start (int): The start index of the loop body. - end (int): The end index of the loop body. - - Returns: - tuple: The generated loop body function object and its inputs. - - """ - pycode_gen = PyCodeGen(self._frame) - origin_instrs = get_instructions(pycode_gen._origin_code) - - for_iter = origin_instrs[for_iter_idx] - - # for balance the stack (the loop body will pop iter first before break or return) - # this None is used for replace the iterator obj in stack top - pycode_gen.gen_load_const(None) - - # extend loop body main logic - pycode_gen.extend_instrs(origin_instrs[start:end]) - - # break should jump to this nop - nop_for_break = pycode_gen._add_instr("NOP") - - # need do additional operates when break - pycode_gen.gen_load_const(False) - pycode_gen.gen_store_fast(inputs[-1]) - pycode_gen.gen_load_const(None) # keep stack balance - - # continue should jump to this nop - nop_for_continue = pycode_gen._add_instr("NOP") - pycode_gen.gen_pop_top() - - # relocate jump - out_loop = for_iter.jump_to - for instr in pycode_gen._instructions: - if instr.jump_to == for_iter: - instr.jump_to = nop_for_continue - if instr.jump_to == out_loop: - instr.jump_to = nop_for_break - - # outputs is the same as inputs - pycode_gen.gen_outputs_and_return(inputs) - return pycode_gen.create_fn_with_inputs(inputs) - @fallback_when_occur_error def _break_graph_when_for_loop( self, iterator: VariableBase, for_iter: Instruction ): - ''' - for_iter: the FOR_ITER opcode - - need find out opcodes which unpack value from FOR_ITER, by analysing stack - - case 1: - for i in iter: - - FOR_ITER - STORE_FAST i - - case 2: - for i,j in iter: - - FOR_ITER - UNPACK_SEQUENCE 2 - STORE_FAST i - STORE_FAST j - - TODO: check var is in globals or builtins, only locals considered now - ''' - # 0. prepare sub functions - # 0.1 find the range of loop body + # 1. find the range of loop body assert for_iter.jump_to is not None - loop_body_start_idx = self.indexof(for_iter) + 1 + for_iter_idx = self.indexof(for_iter) + loop_body_start_idx = for_iter_idx + 1 loop_body_end_idx = self.indexof(for_iter.jump_to) curent_stack = 1 @@ -1877,122 +1927,170 @@ def _break_graph_when_for_loop( if curent_stack == 0: break - # 0.2 create loop body function - all_used_vars = analysis_used_names_with_space( + # 2. create loop body function + loop_body_read_names, loop_body_write_names = analysis_used_names( self._instructions, loop_body_start_idx, loop_body_end_idx ) - loop_body_inputs = [ - k - for k, v in all_used_vars.items() - if v in (Space.locals, Space.cells) - ] + ["_break_flag"] - - loop_body_fn = self._gen_loop_body_between( - loop_body_inputs, - self.indexof(for_iter), - loop_body_start_idx, - loop_body_end_idx, - ) + loop_body_inputs = self._find_names_in_space( + loop_body_read_names | loop_body_write_names, + (Space.locals, Space.cells), + ) + ["_break_flag"] + loop_body_outputs = list(loop_body_write_names) + ["_break_flag"] - log(3, "[Resumed Function]: break graph in loop create loop body as\n") - log_do(3, lambda: dis.dis(loop_body_fn)) + def create_loop_body(): + pycode_gen = PyCodeGen(self._frame) - # 0.3 create after loop part function, minus 1 for iterator - after_loop_fn, fn_inputs = self._create_resume_fn( - loop_body_end_idx, len(self.stack) - 1 - ) + pycode_gen.set_function_inputs(loop_body_inputs, stack_size=0) - total_inputs = OrderedSet(list(fn_inputs) + list(loop_body_inputs[:-1])) + origin_instrs = get_instructions(pycode_gen._origin_code) + for_iter = origin_instrs[for_iter_idx] - # 1. part before for-loop, start compile - ret_names = [ - name - for name in total_inputs - if name in chain(self._locals, self._cells) - ] + # for balance the stack (the loop body will pop iter first before break or return) + # this None is used for replace the iterator obj in stack top + pycode_gen.gen_load_const(None) + + # extend loop body main logic + pycode_gen.extend_instrs( + origin_instrs[loop_body_start_idx:loop_body_end_idx] + ) + + # break should jump to this nop + nop_for_break = pycode_gen.add_instr("NOP") + + # need do additional operates when break + pycode_gen.gen_load_const(False) + pycode_gen.gen_store_fast(loop_body_inputs[-1]) + pycode_gen.gen_load_const(None) # keep stack balance + + # continue should jump to this nop + nop_for_continue = pycode_gen.add_instr("NOP") + pycode_gen.gen_pop_top() + + # relocate jump + out_loop = for_iter.jump_to + for instr in pycode_gen._instructions: + if instr.jump_to == for_iter: + instr.jump_to = nop_for_continue + if instr.jump_to == out_loop: + instr.jump_to = nop_for_break + + # outputs is the same as inputs + pycode_gen.set_function_outputs(loop_body_outputs) + loop_body_fn = pycode_gen.create_function() + + log( + 3, + "[Resumed Function]: break graph in loop create loop body as\n", + ) + log_do(3, lambda: dis.dis(loop_body_fn)) - var_loader = self.gen_compute_in_break_with_name_store( - ret_names, self.indexof(for_iter) + return loop_body_fn + + loop_body_fn = create_loop_body() + + # 3. create after loop part function, stack size minus 1 for iterator + after_loop_read_names, _ = analysis_used_names( + self._instructions, loop_body_end_idx, len(self._instructions) + ) + after_loop_fn_inputs = self._find_names_in_space( + after_loop_read_names, (Space.locals, Space.cells) ) - # 2. restore vars with origin name - for name in ret_names: - var_loader.load(self.get_var(name)) - self._graph.pycode_gen.gen_store(name, self._code) + def create_after_loop_fn(): + if self._instructions[loop_body_end_idx].opname == "RETURN_VALUE": + return None + pycode_gen = PyCodeGen(self._frame) + origin_instrs = get_instructions(pycode_gen._origin_code) + pycode_gen.set_function_inputs( + after_loop_fn_inputs, stack_size=len(self.stack) - 1 + ) + pycode_gen.extend_instrs(origin_instrs[loop_body_end_idx:]) + # the resume_fn contains return code, so we don't need set output here + # global vars are updated correctly, and need local vars will return + after_loop_fn = pycode_gen.create_function() + return after_loop_fn - # 3. setup vars which is created in loop as Undefind - undefined_names = set() + after_loop_fn = create_after_loop_fn() + + # 4. setup vars which is created in loop as Undefind for name in loop_body_inputs[:-1]: - if not self.has_var(name, all_used_vars[name]): - undefined_names.add(name) + if not self.has_var(name): + self._graph.pycode_gen.gen_load_const(SotUndefinedVar()) + self._graph.pycode_gen.gen_store(name, self._code) + for name in after_loop_fn_inputs: + if not self.has_var(name): self._graph.pycode_gen.gen_load_const(SotUndefinedVar()) self._graph.pycode_gen.gen_store(name, self._code) - # 4.1 load iterator + # 5. compile sub graph before for-loop + update_names = list(loop_body_read_names | after_loop_read_names) + var_loader = self.get_compute_fn_and_update_changed_vars( + update_names, self.stack, self.indexof(for_iter) + ) + + # 6. prepare a new loop and call loop body + # 6.1. load iterator, it is in stack, so we can load it with var_loader var_loader.load(iterator) self.stack.pop() - # 4.2 gen FOR_ITER and unpack data + # 6.2. copy FOR_ITER and unpack logic self._graph.pycode_gen.extend_instrs( - self._instructions[self.indexof(for_iter) : loop_body_start_idx] + self._instructions[for_iter_idx:loop_body_start_idx] ) - # 5. call loop body - # 5.1 load loop body + # 6.3 load loop body, prepare inputs and call self._graph.pycode_gen.gen_load_object( loop_body_fn, loop_body_fn.__code__.co_name ) - # 5.2 load loop body inputs for name in loop_body_inputs[:-1]: self._graph.pycode_gen.gen_load(name) - # 5.3 load break flag + # this is the _break_flag self._graph.pycode_gen.gen_load_const(True) - # 5.4 call loop body self._graph.pycode_gen.gen_call_function( argc=loop_body_fn.__code__.co_argcount ) - # 5.5 unpack and store retval, keep break_flag in stack - self._graph.pycode_gen.gen_unpack_sequence(len(loop_body_inputs)) + # 7. unpack and update changed vars, keep break_flag in stack + self._graph.pycode_gen.gen_unpack_sequence(len(loop_body_outputs)) - for name in loop_body_inputs[:-1]: + for name in loop_body_outputs[:-1]: self._graph.pycode_gen.gen_store(name, self._code) - # 6. add jump if break + # 8. create the tail of a for loop, jump back to FOR_ITER + # and process case if break jump_if_break = self._graph.pycode_gen.gen_pop_jump( direction=JumpDirection.FORWARD, suffix=PopJumpCond.FALSE ) - # 7. jump back to FOR_ITER self._graph.pycode_gen.gen_jump( for_iter, direction=JumpDirection.BACKWARD ) - nop = self._graph.pycode_gen._add_instr("NOP") + nop = self._graph.pycode_gen.add_instr("NOP") for_iter.jump_to = nop jump_if_break.jump_to = nop - # 8. call after_loop_fn - self._graph.pycode_gen.gen_load_object( - after_loop_fn, after_loop_fn.__code__.co_name - ) + # 9. prepare inputs and call after_loop_fn + if after_loop_fn is not None: + self._graph.pycode_gen.gen_load_object( + after_loop_fn, after_loop_fn.__code__.co_name + ) - for stack_arg in self.stack: - var_loader.load(stack_arg) - for name in fn_inputs: - if not self.has_var(name) and name not in undefined_names: - undefined_names.add(name) - self._graph.pycode_gen.gen_load_const(SotUndefinedVar()) - self._graph.pycode_gen.gen_store(name, self._code) - self._graph.pycode_gen.gen_load(name) + for stack_arg in self.stack: + var_loader.load(stack_arg) - self._graph.pycode_gen.gen_call_function( - argc=after_loop_fn.__code__.co_argcount - ) + for name in after_loop_fn_inputs: + self._graph.pycode_gen.gen_load(name) + + self._graph.pycode_gen.gen_call_function( + argc=after_loop_fn.__code__.co_argcount + ) + # return what after_loop_fn return self._graph.pycode_gen.gen_return() + self.new_code = self._graph.pycode_gen.gen_pycode() self.guard_fn = self._graph.guard_fn @@ -2000,135 +2098,95 @@ def _inline_call_for_loop( self, iterator: VariableBase, for_iter: Instruction ): assert for_iter.jump_to is not None - pycode_gen = PyCodeGen(self._frame) - origin_instrs = get_instructions(pycode_gen._origin_code) + # 1. analyse input and output start_idx = self.indexof(for_iter) end_idx = self.indexof(for_iter.jump_to) - all_used_vars = analysis_used_names_with_space( - origin_instrs, start_idx, end_idx + read_names, write_names = analysis_used_names( + self._instructions, start_idx, end_idx ) - inputs = [ - k - for k, v in all_used_vars.items() - if v in (Space.locals, Space.cells) - ] + [iterator.id] + # why add write_names as input? check case in test/sot/test_12_for_loop.py + # test_for_without_zero_iter + input_var_names = self._find_names_in_space( + read_names | write_names, (Space.locals, Space.cells) + ) + [iterator.id] + output_var_names = list(write_names) + [iterator.id] - # 1. load iter - pycode_gen.gen_load_fast(iterator.id) + # 2. create inline call loop fn + def create_inline_call_fn(): + pycode_gen = PyCodeGen(self._frame) + origin_instrs = get_instructions(pycode_gen._origin_code) - # 2. copy main logic - pycode_gen.extend_instrs(origin_instrs[start_idx:end_idx]) + pycode_gen.set_function_inputs(input_var_names, stack_size=0) - # 3. add break, continue marker and relocate jump - for_iter_instr = origin_instrs[start_idx] - assert for_iter_instr.jump_to is not None - out_loop_instr = for_iter_instr.jump_to + # 2.1. load iter, it is a input of loop fn + pycode_gen.gen_load_fast(iterator.id) - pycode_gen.gen_jump(out_loop_instr, direction=JumpDirection.FORWARD) - nop_for_continue = pycode_gen._add_instr("NOP") + # 2.2. copy main logic + pycode_gen.extend_instrs(origin_instrs[start_idx:end_idx]) - jump = pycode_gen.gen_jump( - for_iter_instr, direction=JumpDirection.BACKWARD - ) + # 2.3. add break, continue marker and relocate jump + for_iter_instr = origin_instrs[start_idx] + assert for_iter_instr.jump_to is not None + out_loop_instr = for_iter_instr.jump_to - nop_for_break = pycode_gen._add_instr("NOP") + pycode_gen.gen_jump(out_loop_instr, direction=JumpDirection.FORWARD) + nop_for_continue = pycode_gen.add_instr("NOP") - for instr in pycode_gen._instructions: - if instr.jump_to == for_iter_instr: - instr.jump_to = nop_for_continue + jump = pycode_gen.gen_jump( + for_iter_instr, direction=JumpDirection.BACKWARD + ) - if ( - instr.jump_to in origin_instrs - and origin_instrs.index(instr.jump_to) >= end_idx - ): - instr.jump_to = nop_for_break + nop_for_break = pycode_gen.add_instr("NOP") - jump.jump_to = for_iter_instr - pycode_gen.gen_outputs_and_return(inputs) - inline_call_fn = pycode_gen.create_fn_with_inputs(inputs) + # 2.4. relocate jumps + for instr in pycode_gen._instructions: + if instr.jump_to == for_iter_instr: + instr.jump_to = nop_for_continue - log( - 3, - f"[Resumed Function]: Inline call for loop function {inline_call_fn.__code__.co_name}\n", - ) - log_do(3, lambda: dis.dis(inline_call_fn)) + if ( + instr.jump_to in origin_instrs + and origin_instrs.index(instr.jump_to) >= end_idx + ): + instr.jump_to = nop_for_break + + jump.jump_to = for_iter_instr + + pycode_gen.set_function_outputs(output_var_names) + inline_call_fn = pycode_gen.create_function() - # TODO: update globals builtins + log( + 3, + f"[Resumed Function]: Inline call for loop function {inline_call_fn.__code__.co_name}\n", + ) + log_do(3, lambda: dis.dis(inline_call_fn)) + + return inline_call_fn + + inline_call_fn = create_inline_call_fn() + + # 3. create function variable fn = UserDefinedFunctionVariable( inline_call_fn, self._graph, DanglingTracker(), ) + # 4. prepare input datas and call input_vars = [ - self.get_var(name) - if self.has_var(name, all_used_vars[name]) - else SotUndefinedVar() - for name in inputs[:-1] + self.get_var(name, allow_undefined=True) + for name in input_var_names[:-1] ] + [iterator] + ret = fn(*input_vars) - # slice_variable is [:-1] + + # 5. update changed vars slice_const = slice(None, -1, None) slice_variable = SliceVariable( slice_const, self._graph, ConstTracker(slice_const) ) - for name, val in zip(inputs[:-1], ret[slice_variable]): - self._locals[name] = val - - def FOR_ITER(self, instr): - iterator = self.stack.pop() - backup_iter_idx = None - - start = self.indexof(instr) - end = self.indexof(instr.jump_to) - for i in range(start, end): - if self._instructions[i].opname == "RETURN_VALUE": - raise FallbackError("Found RETURN_VALUE in for loop body.") - - self._graph.add_global_guarded_variable(iterator) - - try: - if not isinstance(iterator, SequenceIterVariable): - raise BreakGraphError( - f"Can not simulate iterator of {type(iterator)}." - ) - - backup_iter_idx = iterator.idx - - self._inline_call_for_loop(iterator, instr) - self._lasti = self.indexof(instr.jump_to) - next_instr = self._instructions[self._lasti] - self._lasti += int(next_instr.opname == 'END_FOR') - except BreakGraphError as e: - log(3, f"[BreakGraph] FOR_ITER sim for loop failed for: {e}\n") - if backup_iter_idx: - iterator.idx = backup_iter_idx - self._graph.remove_global_guarded_variable(iterator) - self.stack.push(iterator) - self._break_graph_when_for_loop(iterator, instr) - return Stop(state="BreakGraph") - - def RETURN_VALUE(self, instr: Instruction): - assert ( - len(self.stack) == 1 - ), f"Stack must have one element, but get {len(self.stack)} elements." - ret_val = self.stack.pop() - return self.compile_return(ret_val) - - def RETURN_CONST(self, instr: Instruction): - ret_const = self._co_consts[instr.arg] - return self.compile_return(ret_const) - def compile_return(self, ret_val): - compile_fn = self._graph.get_compiled_fn(ret_val) - if compile_fn.graph_size() < ENV_MIN_GRAPH_SIZE.get(): - self.new_code = None - else: - self._graph.start_compile(ret_val) - self._graph.pycode_gen.gen_return() - self.new_code = self._graph.pycode_gen.gen_pycode() - self.guard_fn = self._graph.guard_fn - return Stop(state="Return") + for name, var in zip(output_var_names[:-1], ret[slice_variable]): + self.set_var(name, var) diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py index 3832d05f044487..306166aa7d872c 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py +++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py @@ -299,16 +299,6 @@ def _break_graph_when_if(self, result, instr: Instruction): "OpcodeInlineExecutor want break graph when simulate `if`." ) - def _create_resume_fn(self, index: int, stack_size: int = 0): - """ - Helper method to create a resume function for the executor. - - Args: - index (int): The index of the instruction to resume execution from. - stack_size (int, optional): The size of the stack. Defaults to 0. - """ - raise BreakGraphError("_create_resume_fn.") - def FOR_ITER(self, instr: Instruction): iterator = self.stack.top assert isinstance(iterator, IterVariable) diff --git a/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py b/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py index 69e174818d6627..2ada3f7228f114 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py +++ b/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py @@ -39,11 +39,9 @@ no_eval_frame, ) from ..instruction_utils import ( - analysis_inputs, apply_instr_pass, calc_stack_effect, gen_instr, - get_instructions, instrs_info, modify_instrs, modify_vars, @@ -437,6 +435,42 @@ def __init__( self.hooks = [] if self.disable_eval_frame: self.gen_disable_eval_frame() + self.fn_name = ResumeFnNameFactory().next() + + def set_function_inputs(self, inputs: list[str], stack_size: int): + stack_arg_str = self.fn_name + '_stack_{}' + + self._code_options['co_argcount'] = len(inputs) + stack_size + self._code_options['co_varnames'] = list( + [stack_arg_str.format(i) for i in range(stack_size)] + + inputs + + [ + var_name + for var_name in self._origin_code.co_varnames + if var_name not in inputs + ] + ) + + self._instructions.extend( + [ + gen_instr('LOAD_FAST', argval=stack_arg_str.format(i)) + for i in range(stack_size) + ] + ) + + def set_function_outputs(self, outputs: list[str]): + for name in outputs: + self.gen_load(name) + self.gen_build_tuple(len(outputs)) + self.gen_return() + + def create_function(self) -> types.FunctionType: + self.update_code_name(self.fn_name, is_resumed_fn=True) + new_code = self.gen_pycode() + if len(new_code.co_freevars) + len(new_code.co_cellvars) > 0: + raise FallbackError("Break graph in closure is not support.") + fn = types.FunctionType(new_code, self._f_globals, new_code.co_name) + return fn def insert_prefix_instructions(self): """ @@ -509,58 +543,6 @@ def gen_pycode(self) -> types.CodeType: return new_code - def gen_resume_fn_at( - self, index: int, stack_size: int - ) -> tuple[None | types.FunctionType, OrderedSet[str]]: - """ - Generates a resume function at the specified index in the instruction list. - - Args: - index (int): The index in the instruction list to generate the resume function. - stack_size (int): The size of the stack. Defaults to 0. - - Returns: - tuple: The resume function object and the inputs to the function. - - """ - - self._instructions = get_instructions(self._origin_code) - # TODO(dev): could give an example code here? - if self._instructions[index].opname == 'RETURN_VALUE': - return None, OrderedSet() - inputs = analysis_inputs(self._instructions, index) - fn_name = ResumeFnNameFactory().next() - stack_arg_str = fn_name + '_stack_{}' - - self._instructions = ( - [ - gen_instr('LOAD_FAST', argval=stack_arg_str.format(i)) - for i in range(stack_size) - ] - + [gen_instr('JUMP_FORWARD', jump_to=self._instructions[index])] - + self._instructions - ) - - self._code_options['co_argcount'] = len(inputs) + stack_size - # inputs should be at the front of the co_varnames - self._code_options['co_varnames'] = list( - [stack_arg_str.format(i) for i in range(stack_size)] - + list(inputs) - + [ - var_name - for var_name in self._code_options['co_varnames'] - if var_name not in inputs - ] - ) - - self.update_code_name(fn_name, is_resumed_fn=True) - new_code = self.gen_pycode() - if len(new_code.co_freevars) + len(new_code.co_cellvars) > 0: - raise FallbackError("Break graph in closure is not support.") - fn = types.FunctionType(new_code, self._f_globals, new_code.co_name) - - return fn, inputs - @cached_property def global_null_variable(self): from .variables.basic import NullVariable @@ -593,39 +575,6 @@ def gen_enable_eval_frame(self): self.gen_call_function(1) self.gen_pop_top() - def gen_outputs_and_return(self, outputs): - for name in outputs: - self.gen_load(name) - self.gen_build_tuple(len(outputs)) - self.gen_return() - - def create_fn_with_inputs(self, inputs: list) -> types.FunctionType: - """ - Creates a function with specific input and output variables. - - Args: - inputs (list): The input variables. - - Returns: - function: The created function object. - """ - self._code_options['co_argcount'] = len(inputs) - self._code_options['co_varnames'] = list( - list(inputs) - + [ - var_name - for var_name in self._origin_code.co_varnames - if var_name not in inputs - ] - ) - fn_name = ResumeFnNameFactory().next() - self.update_code_name(fn_name, is_resumed_fn=True) - new_code = self.gen_pycode() - if len(new_code.co_freevars) + len(new_code.co_cellvars) > 0: - raise FallbackError("Break graph in closure is not support.") - fn = types.FunctionType(new_code, self._f_globals, new_code.co_name) - return fn - def gen_load_const(self, value: Any): """ Generates instructions to load a constant value. @@ -636,7 +585,7 @@ def gen_load_const(self, value: Any): if not list_contain_by_id(self._code_options["co_consts"], value): self._code_options["co_consts"].append(value) idx = list_find_index_by_id(self._code_options["co_consts"], value) - self._add_instr("LOAD_CONST", arg=idx, argval=value) + return self.add_instr("LOAD_CONST", arg=idx, argval=value) def gen_print_log(self, message): """print a log""" @@ -745,7 +694,7 @@ def gen_load_global(self, name, push_null=False): idx <<= 1 if push_null: idx |= 1 - self._add_instr("LOAD_GLOBAL", arg=idx, argval=name) + return self.add_instr("LOAD_GLOBAL", arg=idx, argval=name) def gen_load_object(self, obj, obj_name: str, push_null: bool = True): """ @@ -758,14 +707,14 @@ def gen_load_object(self, obj, obj_name: str, push_null: bool = True): if obj_name not in self._f_globals: self._f_globals[obj_name] = obj - self.gen_load_global(obj_name, push_null=push_null) + return self.gen_load_global(obj_name, push_null=push_null) def gen_load_null_variable(self): """ Generate the bytecode for loading a null variable. """ null_var = self.global_null_variable - self.gen_load_object(null_var, "___null_var", push_null=False) + return self.gen_load_object(null_var, "___null_var", push_null=False) def gen_load_fast(self, name): """ @@ -777,7 +726,7 @@ def gen_load_fast(self, name): if name not in self._code_options["co_varnames"]: self._code_options["co_varnames"].append(name) idx = self._code_options["co_varnames"].index(name) - self._add_instr("LOAD_FAST", arg=idx, argval=name) + return self.add_instr("LOAD_FAST", arg=idx, argval=name) def gen_load_deref(self, name): if name not in self.cell_free_storage: @@ -791,7 +740,7 @@ def gen_load_deref(self, name): ).index(name) else: idx = self.cell_free_storage.index(name) - self._add_instr("LOAD_DEREF", arg=idx, argval=name) + return self.add_instr("LOAD_DEREF", arg=idx, argval=name) def gen_load_attr(self, name: str): if name not in self._code_options["co_names"]: @@ -799,49 +748,49 @@ def gen_load_attr(self, name: str): idx = self._code_options["co_names"].index(name) if sys.version_info >= (3, 12): idx <<= 1 - self._add_instr("LOAD_ATTR", arg=idx, argval=name) + return self.add_instr("LOAD_ATTR", arg=idx, argval=name) def gen_store_attr(self, name: str): if name not in self._code_options["co_names"]: self._code_options["co_names"].append(name) idx = self._code_options["co_names"].index(name) - self._add_instr("STORE_ATTR", arg=idx, argval=name) + return self.add_instr("STORE_ATTR", arg=idx, argval=name) def gen_delete_attr(self, name: str): if name not in self._code_options["co_names"]: self._code_options["co_names"].append(name) idx = self._code_options["co_names"].index(name) - self._add_instr("DELETE_ATTR", arg=idx, argval=name) + return self.add_instr("DELETE_ATTR", arg=idx, argval=name) def gen_load_method(self, name: str): if name not in self._code_options["co_names"]: self._code_options["co_names"].append(name) idx = self._code_options["co_names"].index(name) - self._add_instr("LOAD_METHOD", arg=idx, argval=name) + return self.add_instr("LOAD_METHOD", arg=idx, argval=name) def gen_delete_global(self, name: str): if name not in self._code_options["co_names"]: self._code_options["co_names"].append(name) idx = self._code_options["co_names"].index(name) - self._add_instr("DELETE_GLOBAL", arg=idx, argval=name) + return self.add_instr("DELETE_GLOBAL", arg=idx, argval=name) def gen_import_name(self, name: str): if name not in self._code_options["co_names"]: self._code_options["co_names"].append(name) idx = self._code_options["co_names"].index(name) - self._add_instr("IMPORT_NAME", arg=idx, argval=name) + return self.add_instr("IMPORT_NAME", arg=idx, argval=name) def gen_store_fast(self, name): if name not in self._code_options["co_varnames"]: self._code_options["co_varnames"].append(name) idx = self._code_options["co_varnames"].index(name) - self._add_instr("STORE_FAST", arg=idx, argval=name) + return self.add_instr("STORE_FAST", arg=idx, argval=name) def gen_store_global(self, name): if name not in self._code_options["co_names"]: self._code_options["co_names"].append(name) idx = self._code_options["co_names"].index(name) - self._add_instr("STORE_GLOBAL", arg=idx, argval=name) + return self.add_instr("STORE_GLOBAL", arg=idx, argval=name) def gen_store_deref(self, name): if name not in self.cell_free_storage: @@ -855,50 +804,50 @@ def gen_store_deref(self, name): ).index(name) else: idx = self.cell_free_storage.index(name) - self._add_instr("STORE_DEREF", arg=idx, argval=name) + return self.add_instr("STORE_DEREF", arg=idx, argval=name) def gen_store_subscr(self): - self._add_instr("STORE_SUBSCR") + return self.add_instr("STORE_SUBSCR") def gen_subscribe(self): - self._add_instr("BINARY_SUBSCR") + return self.add_instr("BINARY_SUBSCR") def gen_build_tuple(self, count): - self._add_instr("BUILD_TUPLE", arg=count, argval=count) + return self.add_instr("BUILD_TUPLE", arg=count, argval=count) def gen_build_list(self, count): - self._add_instr("BUILD_LIST", arg=count, argval=count) + return self.add_instr("BUILD_LIST", arg=count, argval=count) def gen_build_map(self, count): - self._add_instr("BUILD_MAP", arg=count, argval=count) + return self.add_instr("BUILD_MAP", arg=count, argval=count) def gen_build_slice(self, argc): - self._add_instr("BUILD_SLICE", arg=argc, argval=argc) + return self.add_instr("BUILD_SLICE", arg=argc, argval=argc) def gen_unpack_sequence(self, count): - self._add_instr("UNPACK_SEQUENCE", arg=count, argval=count) + return self.add_instr("UNPACK_SEQUENCE", arg=count, argval=count) def gen_call_function(self, argc=0): if sys.version_info >= (3, 11): if sys.version_info < (3, 12): - self._add_instr("PRECALL", arg=argc, argval=argc) - self._add_instr("CALL", arg=argc, argval=argc) + self.add_instr("PRECALL", arg=argc, argval=argc) + self.add_instr("CALL", arg=argc, argval=argc) else: - self._add_instr("CALL_FUNCTION", arg=argc, argval=argc) + self.add_instr("CALL_FUNCTION", arg=argc, argval=argc) def gen_call_function_ex(self, has_kwargs): flag = 0 if has_kwargs: flag |= CALL_FUNCTION_EX_FLAG.CFE_HAS_KWARGS - self._add_instr("CALL_FUNCTION_EX", arg=flag, argval=flag) + self.add_instr("CALL_FUNCTION_EX", arg=flag, argval=flag) def gen_call_method(self, argc=0): if sys.version_info >= (3, 11): if sys.version_info < (3, 12): - self._add_instr("PRECALL", arg=argc, argval=argc) - self._add_instr("CALL", arg=argc, argval=argc) + self.add_instr("PRECALL", arg=argc, argval=argc) + self.add_instr("CALL", arg=argc, argval=argc) else: - self._add_instr("CALL_METHOD", arg=argc, argval=argc) + self.add_instr("CALL_METHOD", arg=argc, argval=argc) def gen_kw_names(self, kw_names: tuple[str, ...] | None): if kw_names is None: @@ -908,22 +857,22 @@ def gen_kw_names(self, kw_names: tuple[str, ...] | None): if kw_names not in self._code_options["co_consts"]: self._code_options["co_consts"].append(kw_names) idx = self._code_options["co_consts"].index(kw_names) - self._add_instr("KW_NAMES", arg=idx, argval=kw_names) + self.add_instr("KW_NAMES", arg=idx, argval=kw_names) def gen_pop_top(self): - self._add_instr("POP_TOP") + return self.add_instr("POP_TOP") def gen_rot_n(self, n): if n <= 1: return if sys.version_info >= (3, 11): for i in range(n, 1, -1): - self._add_instr("SWAP", arg=i) + self.add_instr("SWAP", arg=i) elif sys.version_info >= (3, 10): - self._add_instr("ROT_N", arg=n) + self.add_instr("ROT_N", arg=n) else: if n <= 4: - self._add_instr("ROT_" + ["TWO", "THREE", "FOUR"][n - 2]) + self.add_instr("ROT_" + ["TWO", "THREE", "FOUR"][n - 2]) else: def rot_n_fn(n): @@ -937,7 +886,7 @@ def rot_n_fn(n): self.gen_build_tuple(n) self.gen_load_const(rot_n_fn(n)) self.gen_rot_n(2) - self._add_instr("CALL_FUNCTION_EX", arg=0) + self.add_instr("CALL_FUNCTION_EX", arg=0) self.gen_unpack_sequence(n) def gen_shift_n(self, s: int, n: int): @@ -970,7 +919,7 @@ def gen_shift_n(self, s: int, n: int): # NOTE: s=-1, n=3 [1,2,3,4,5] -> [1,2,4,5,3] if s == -1: for i in range(2, n + 1): - self._add_instr("SWAP", arg=i) + self.add_instr("SWAP", arg=i) else: self.gen_shift_n(-1, n) self.gen_shift_n(s + 1, n) @@ -981,7 +930,7 @@ def gen_shift_n(self, s: int, n: int): def gen_swap(self, n): if sys.version_info >= (3, 11): - self._add_instr("SWAP", arg=n) + self.add_instr("SWAP", arg=n) else: raise NotImplementedError("swap is not supported before python3.11") @@ -992,9 +941,9 @@ def gen_jump( direction: JumpDirection = JumpDirection.FORWARD, ) -> Instruction: if sys.version_info >= (3, 11): - return self._add_instr(f"JUMP_{direction.value}", jump_to=jump_to) + return self.add_instr(f"JUMP_{direction.value}", jump_to=jump_to) else: - return self._add_instr("JUMP_ABSOLUTE", jump_to=jump_to) + return self.add_instr("JUMP_ABSOLUTE", jump_to=jump_to) def gen_pop_jump( self, @@ -1004,33 +953,33 @@ def gen_pop_jump( suffix: PopJumpCond = PopJumpCond.NONE, ) -> Instruction: if sys.version_info >= (3, 11): - return self._add_instr( + return self.add_instr( f"POP_JUMP_{direction.value}_IF_{suffix.value}", jump_to=jump_to ) else: - return self._add_instr( + return self.add_instr( f"POP_JUMP_IF_{suffix.value}", jump_to=jump_to ) def gen_return(self): - self._add_instr("RETURN_VALUE") + return self.add_instr("RETURN_VALUE") def gen_get_iter(self): - self._add_instr("GET_ITER") + return self.add_instr("GET_ITER") def gen_operator_only(self, op_name): """ only generator operator instruction, do nothing for operands. """ - self._add_instr(op_name) + return self.add_instr(op_name) def gen_operator(self, op_name): """ only generator operator instruction, do nothing for operands. """ - self._add_instr(op_name) + return self.add_instr(op_name) def gen_compare(self, cmp_op): """ @@ -1039,9 +988,9 @@ def gen_compare(self, cmp_op): """ if sys.version_info >= (3, 12): cmp_op <<= 4 - self._add_instr("COMPARE_OP", cmp_op) + return self.add_instr("COMPARE_OP", cmp_op) - def _add_instr(self, *args, **kwargs): + def add_instr(self, *args, **kwargs): instr = gen_instr(*args, **kwargs) self._instructions.append(instr) return instr diff --git a/python/paddle/jit/sot/opcode_translator/executor/tracker.py b/python/paddle/jit/sot/opcode_translator/executor/tracker.py index fd7168f4e5957f..51d21a5572129f 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/tracker.py +++ b/python/paddle/jit/sot/opcode_translator/executor/tracker.py @@ -393,7 +393,7 @@ def __init__(self, iter_source: VariableBase): def gen_instructions(self, codegen: PyCodeGen): self.iter_source.tracker.gen_instructions(codegen) - codegen._add_instr("GET_ITER") + codegen.add_instr("GET_ITER") def trace_value_from_frame(self): iter_source_tracer = self.iter_source.tracker.trace_value_from_frame() diff --git a/python/paddle/jit/sot/opcode_translator/instruction_utils/__init__.py b/python/paddle/jit/sot/opcode_translator/instruction_utils/__init__.py index 0b9429e078ec71..833fd3c207e883 100644 --- a/python/paddle/jit/sot/opcode_translator/instruction_utils/__init__.py +++ b/python/paddle/jit/sot/opcode_translator/instruction_utils/__init__.py @@ -15,6 +15,7 @@ from .instruction_pass import apply_instr_pass # noqa: F401 from .instruction_utils import ( # noqa: F401 Instruction, + Space, calc_offset_from_bytecode_offset, calc_stack_effect, convert_instruction, @@ -29,7 +30,5 @@ reset_offset, ) from .opcode_analysis import ( # noqa: F401 - Space, - analysis_inputs, - analysis_used_names_with_space, + analysis_used_names, ) diff --git a/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_pass.py b/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_pass.py index 8725aa55c32138..5b0cc17fc808f2 100644 --- a/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_pass.py +++ b/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_pass.py @@ -90,6 +90,8 @@ def find_related_local_opcodes(instrs, code_options): if len(stack) > 0 and stack[-1] is not None: opcode_pairs.append((stack[-1], instr)) stack.pop() + elif "ROT" in instr.opname: + return [] else: try: pop_n, push_n = StackAnalyser().stack_effect(instr) diff --git a/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py b/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py index 05e6dcfc91e7db..2965c8e6bc056e 100644 --- a/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py +++ b/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py @@ -17,6 +17,7 @@ import dataclasses import dis import sys +from enum import Enum from typing import TYPE_CHECKING, Any from ...utils import InnerError @@ -410,3 +411,10 @@ def calc_stack_effect(instr: Instruction, *, jump: bool | None = None) -> int: assert instr.arg is not None return -instr.arg - 1 return dis.stack_effect(instr.opcode, instr.arg, jump=jump) + + +class Space(Enum): + locals = 1 + globals = 2 + cells = 3 + not_found = 4 diff --git a/python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_analysis.py b/python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_analysis.py index f0211167f44498..2e8ded5d2ac5e4 100644 --- a/python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_analysis.py +++ b/python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_analysis.py @@ -15,11 +15,9 @@ from __future__ import annotations import dataclasses -from enum import Enum from paddle.jit.utils import OrderedSet -from ...utils import InnerError from .instruction_utils import Instruction from .opcode_info import ALL_JUMP, HAS_FREE, HAS_LOCAL, UNCONDITIONAL_JUMP @@ -30,6 +28,11 @@ class State: writes: OrderedSet[str] visited: OrderedSet[int] + def __or__(self, other): + reads = self.reads | other.reads + writes = self.writes | other.writes + return State(reads, writes, OrderedSet()) + def is_read_opcode(opname): if opname in [ @@ -63,7 +66,7 @@ def is_write_opcode(opname): return False -def analysis_inputs( +def analysis_used_names( instructions: list[Instruction], current_instr_idx: int, stop_instr_idx: int | None = None, @@ -97,7 +100,7 @@ def walk(state: State, start: int) -> OrderedSet[str]: end = len(instructions) if stop_instr_idx is None else stop_instr_idx for i in range(start, end): if i in state.visited: - return state.reads + return state state.visited.add(i) instr = instructions[i] @@ -116,104 +119,12 @@ def walk(state: State, start: int) -> OrderedSet[str]: not_jump_branch = ( fork(state, i, False, target_idx) if instr.opname not in UNCONDITIONAL_JUMP - else OrderedSet() - ) - return jump_branch | not_jump_branch - elif instr.opname == "RETURN_VALUE": - return state.reads - return state.reads - - return walk(root_state, current_instr_idx) - - -@dataclasses.dataclass -class SpaceState: - reads: dict[str, Space] - writes: dict[str, Space] - visited: OrderedSet[int] - - def __or__(self, other): - reads = {} - reads.update(other.reads) - reads.update(self.reads) - writes = {} - writes.update(other.writes) - writes.update(self.writes) - return SpaceState(reads, writes, OrderedSet()) - - -class Space(Enum): - locals = 1 - globals = 2 - cells = 3 - all = 4 - - -def get_space(opname: str): - if "FAST" in opname: - return Space.locals - elif "GLOBAL" in opname: - return Space.globals - elif "DEREF" in opname or "CLOSURE" in opname: - return Space.cells - elif "NAME" in opname: - return Space.all - else: - raise InnerError(f"Unknown space for {opname}") - - -def analysis_used_names_with_space( - instructions: list[Instruction], - start_instr_idx: int, - stop_instr_idx: int | None = None, -): - root_state = SpaceState({}, {}, OrderedSet()) - - def fork( - state: SpaceState, start: int, jump: bool, jump_target: int - ) -> SpaceState: - new_start = start + 1 if not jump else jump_target - new_state = SpaceState( - dict(state.reads), - dict(state.writes), - OrderedSet(state.visited), - ) - return walk(new_state, new_start) - - def walk(state: SpaceState, start: int) -> SpaceState: - end = len(instructions) if stop_instr_idx is None else stop_instr_idx - for i in range(start, end): - if i in state.visited: - return state - state.visited.add(i) - - instr = instructions[i] - if instr.opname in HAS_LOCAL | HAS_FREE: - if is_read_opcode(instr.opname) and instr.argval not in ( - state.writes - ): - space = get_space(instr.opname) - state.reads[instr.argval] = space - elif is_write_opcode(instr.opname): - space = get_space(instr.opname) - state.writes[instr.argval] = space - elif instr.opname in ALL_JUMP: - assert instr.jump_to is not None - target_idx = instructions.index(instr.jump_to) - # Fork to two branches, jump or not - jump_branch = fork(state, i, True, target_idx) - not_jump_branch = ( - fork(state, i, False, target_idx) - if instr.opname not in UNCONDITIONAL_JUMP - else SpaceState({}, {}, OrderedSet()) + else State(OrderedSet(), OrderedSet(), OrderedSet()) ) return jump_branch | not_jump_branch elif instr.opname == "RETURN_VALUE": return state return state - state = walk(root_state, start_instr_idx) - all_used_vars = {} - all_used_vars.update(state.writes) - all_used_vars.update(state.reads) - return all_used_vars + state = walk(root_state, current_instr_idx) + return state.reads, state.writes diff --git a/test/sot/test_11_jumps.py b/test/sot/test_11_jumps.py index 80fa1f4a4eb02b..6073766e8b60fb 100644 --- a/test/sot/test_11_jumps.py +++ b/test/sot/test_11_jumps.py @@ -114,5 +114,17 @@ def test_breakgraph(self): self.assert_results(pop_jump_if_not_none, true_tensor, a) +def new_var_in_if(): + x = paddle.to_tensor(1) + if x > 0: + y = 1 + return y + + +class TestCreateVarInIf(TestCaseBase): + def test_case(self): + self.assert_results(new_var_in_if) + + if __name__ == "__main__": unittest.main() diff --git a/test/sot/test_analysis_inputs.py b/test/sot/test_analysis_inputs.py index 20b32c2225324f..880de6060d4009 100644 --- a/test/sot/test_analysis_inputs.py +++ b/test/sot/test_analysis_inputs.py @@ -20,7 +20,7 @@ import paddle from paddle.jit.sot.opcode_translator.instruction_utils import ( - analysis_inputs, + analysis_used_names, calc_offset_from_bytecode_offset, get_instructions, ) @@ -36,12 +36,12 @@ def assert_inputs_equals(instruction_offset: int, expected_inputs: set[str]): current_instr_idx = calc_offset_from_bytecode_offset( test_frame.f_lasti + 2, instructions ) - actual_inputs = analysis_inputs( + reads, writes = analysis_used_names( instructions, current_instr_idx + instruction_offset ) assert ( - set(actual_inputs) == expected_inputs - ), f"actual_inputs: {actual_inputs}, expected_inputs: {expected_inputs}" + set(reads) == expected_inputs + ), f"actual_inputs: {reads}, expected_inputs: {expected_inputs}" def case1(x):