Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 33 additions & 38 deletions python/paddle/jit/sot/opcode_translator/executor/function_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from ...utils import (
ENV_SHOW_TRACKERS,
NameGenerator,
SotUndefinedVar,
inner_error_default_handler,
is_inplace_api,
is_paddle_api,
Expand Down Expand Up @@ -140,6 +141,20 @@ def get_params_and_non_param_symbol(*args, **kwargs):
return params, non_params


class VariableLoader:
def __init__(self, store_var_info, pycode_gen):
self._store_var_info = store_var_info
self._pycode_gen: PyCodeGen = pycode_gen

def load(self, var):
if var is SotUndefinedVar():
self._pycode_gen.gen_load_const(SotUndefinedVar())
elif isinstance(var, NullVariable):
var.reconstruct(self._pycode_gen)
else:
self._pycode_gen.gen_load(self._store_var_info[var.id])


class FunctionGraph:
"""
A Graph representation corresponding to each FunctionFrame
Expand Down Expand Up @@ -281,17 +296,6 @@ def guard_fn(self) -> Guard:
return make_guard(guards)

def _restore_origin_opcode(self, stack_vars, store_var_info, instr_idx):
class VariableLoader:
def __init__(self, store_var_info, pycode_gen):
self._store_var_info = store_var_info
self._pycode_gen: PyCodeGen = pycode_gen

def load(self, var):
if isinstance(var, NullVariable):
var.reconstruct(self._pycode_gen)
return
self._pycode_gen.gen_load(self._store_var_info[var.id])

origin_instrs = get_instructions(self.pycode_gen._origin_code)
is_precall = origin_instrs[instr_idx].opname == "PRECALL"
current_idx = instr_idx
Expand All @@ -308,7 +312,7 @@ def load(self, var):
restore_instr_names = restore_instr_names[:-1]

self.pycode_gen.extend_instrs(restore_instrs)
nop = self.pycode_gen._add_instr("NOP")
nop = self.pycode_gen.add_instr("NOP")

for instr in origin_instrs:
if instr.jump_to == origin_instrs[current_idx]:
Expand All @@ -324,46 +328,37 @@ def load(self, var):

name_gen = NameGenerator("__start_compile_saved_orig_")

# here is not update changed values, it just give names to stack vars
# and want keep same interface as _build_compile_fn_with_name_store
for var in stack_vars[::-1]:
store_var_info[var.id] = name_gen.next()
self.pycode_gen.gen_store_fast(store_var_info[var.id])
if store_var_info[var.id] is None:
store_var_info[var.id] = name_gen.next()
self.pycode_gen.gen_store_fast(store_var_info[var.id])
else:
self.pycode_gen.gen_store(
store_var_info[var.id], self.pycode_gen._origin_code
)

return VariableLoader(store_var_info, self.pycode_gen)

def _build_compile_fn_with_name_store(self, to_store_vars):
class VariableLoader:
def __init__(self, index_for_load, pycode_gen):
self._index_for_load = index_for_load
self._pycode_gen: PyCodeGen = pycode_gen

def load(self, var, allow_push_null=True):
if isinstance(var, NullVariable):
var.reconstruct(self._pycode_gen)
return
self._pycode_gen.gen_load(self._index_for_load[var.id])

def _build_compile_fn_with_name_store(self, to_store_vars, store_var_info):
# var_id -> local_name mapping
index_for_load = {}
to_store_vars = list(
filter(lambda x: not isinstance(x, NullVariable), to_store_vars)
)
self.start_compile(*to_store_vars)
name_gen = NameGenerator("__start_compile_saved_")

for var in to_store_vars[::-1]:
index_for_load[var.id] = name_gen.next()

def _log_fn():
print(
f"[StartCompile] saved var: {index_for_load[var.id]} = ",
var,
if store_var_info[var.id] is None:
store_var_info[var.id] = name_gen.next()
self.pycode_gen.gen_store_fast(store_var_info[var.id])
else:
self.pycode_gen.gen_store(
store_var_info[var.id], self.pycode_gen._origin_code
)

log_do(4, _log_fn)

self.pycode_gen.gen_store_fast(index_for_load[var.id])

return VariableLoader(index_for_load, self.pycode_gen)
return VariableLoader(store_var_info, self.pycode_gen)

def get_compiled_fn(self, *ret_vars):
ret_items = [
Expand Down
Loading