Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions python/paddle/jit/sot/infer_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,14 +261,39 @@ def infer_meta_for_layer(layer, *args, **kwargs):
) = layer.forward.get_concrete_program(*args_, **kwargs_)

out = partial_program_layer._restore_out(
paddle.utils.flatten(
convert_variable_to_meta_info(concrete_program.outputs)
)
[
x
for x in paddle.utils.flatten(
convert_variable_to_meta_info(concrete_program.outputs)
)
if isinstance(x, MetaInfo)
]
)
layer.forward.rollback()
return out


def ast_infer_meta(static_function, *args, **kwargs):
args_, kwargs_ = convert_meta_to_input_spec((args, kwargs))

(
concrete_program,
partial_program_layer,
) = static_function.get_concrete_program(*args_, **kwargs_)

out = partial_program_layer._restore_out(
[
x
for x in paddle.utils.flatten(
convert_variable_to_meta_info(concrete_program.outputs)
)
if isinstance(x, MetaInfo)
]
)

return out


@Singleton
class SpecialInferMeta:
"""
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/jit/sot/opcode_translator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@
# limitations under the License.

from .skip_files import setup_skip_files
from .transform import eval_frame_callback # noqa: F401
from .eval_frame_callback import eval_frame_callback # noqa: F401

setup_skip_files()
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,23 @@ def eval_frame_callback(frame, **kwargs) -> CustomCode:
)
log_do(4, partial(print_locals, frame))

log_format(3, "[transform] OriginCode: {}\n", frame.f_code.co_name)
log_format(
3, "[eval_frame_callback] OriginCode: {}\n", frame.f_code.co_name
)
log_do(3, lambda: dis.dis(frame.f_code))

custom_code = OpcodeExecutorCache()(frame, **kwargs)

if custom_code.code is None:
log_format(
3,
"[transform] NewCode (same as origin code): {}\n",
"[eval_frame_callback] NewCode (same as origin code): {}\n",
frame.f_code.co_name,
)
else:
log_format(
3,
"[transform] NewCode: {}\n",
"[eval_frame_callback] NewCode: {}\n",
custom_code.code.co_name,
)
log_do(3, lambda: dis.dis(custom_code.code))
Expand Down
143 changes: 106 additions & 37 deletions python/paddle/jit/sot/opcode_translator/executor/function_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@
from functools import cached_property
from typing import Any, Callable

from ...infer_meta import InferMetaCache, LayerInferMetaCache, MetaInfo
from ...infer_meta import (
InferMetaCache,
LayerInferMetaCache,
MetaInfo,
ast_infer_meta,
)
from ...profiler import EventGuard, event_register
from ...symbolic.statement_ir import Reference, Symbol
from ...symbolic.symbolic_context import SymbolicTraceContext
Expand Down Expand Up @@ -56,6 +61,7 @@
)
from .tracker import BuiltinTracker, DummyTracker
from .variables import (
ConstantVariable,
DictVariable,
GlobalVariable,
ListVariable,
Expand Down Expand Up @@ -99,6 +105,18 @@ def func(x):
return map_variables(func, inputs)


def get_symbol_meta_map(inputs):
output = {}

def func(x):
if isinstance(x, TensorVariable):
output[x.get_symbol()] = x.meta
return x

map_variables(func, inputs)
return output


class FunctionGraph:
"""
A Graph representation corresponding to each FunctionFrame
Expand Down Expand Up @@ -129,7 +147,6 @@ def __init__(self, frame, **kwargs):
self._global_guarded_variables: OrderedSet[VariableBase] = OrderedSet()
self._print_variables = []
self._inplace_tensors = OrderedSet()
self.build_strategy = kwargs.get('build_strategy', None)
self._kwargs = kwargs

@cached_property
Expand Down Expand Up @@ -291,7 +308,7 @@ def load(self, var):

return VariableLoader(store_var_info, self.pycode_gen)

def _build_compile_fn_with_name_store(self, ret_vars, to_store_vars):
def _build_compile_fn_with_name_store(self, to_store_vars):
class VariableLoader:
def __init__(self, index_for_load, pycode_gen):
self._index_for_load = index_for_load
Expand All @@ -308,7 +325,7 @@ def load(self, var, allow_push_null=True):
to_store_vars = list(
filter(lambda x: not isinstance(x, NullVariable), to_store_vars)
)
self.start_compile(*(ret_vars + to_store_vars))
self.start_compile(*to_store_vars)
name_gen = NameGenerator("__start_compile_saved_")

for var in to_store_vars[::-1]:
Expand All @@ -326,6 +343,22 @@ def _log_fn():

return VariableLoader(index_for_load, self.pycode_gen)

def get_compiled_fn(self, *ret_vars):
ret_items = [
ret_item
for ret_var in ret_vars
for ret_item in ret_var.flatten_items()
]

tensor_items = self._find_tensor_outputs(ret_items)

compiled_fn, _ = self.sir_ctx.compile_fn(
[Symbol(tensor_var.var_name) for tensor_var in tensor_items],
**self._kwargs,
)

return compiled_fn

@event_register("start_compile", event_level=2)
def start_compile(self, *ret_vars: VariableBase):
"""
Expand Down Expand Up @@ -440,36 +473,6 @@ def message_handler(*args, **kwargs):
**kwargs,
)

@staticmethod
def get_opcode_executor_stack():
# NOTE: only for debug.
# dependent on OpcodeExecutor.
from .opcode_executor import OpcodeExecutorBase

if len(OpcodeExecutorBase.call_stack) == 0:
# In test case, we can meet this senario.
return []
current_executor = OpcodeExecutorBase.call_stack[-1]
current_line = current_executor._current_line
filename = current_executor._code.co_filename
source_lines, start_line = inspect.getsourcelines(
current_executor._code
)
# TODO(SigureMo): In 3.11, lineno maybe changed after multiple breakgraph,
# We need to find a way to fix this.
line_idx = min(current_line - start_line, len(source_lines) - 1)
code_line = source_lines[line_idx]
stack = []
stack.append(
' File "{}", line {}, in {}'.format(
filename,
current_line,
current_executor._code.co_name,
)
)
stack.append(f' {code_line}')
return stack

def call_layer(
self,
layer: PaddleLayerVariable,
Expand Down Expand Up @@ -503,14 +506,46 @@ def message_handler(*args, **kwargs):
infer_meta_fn, compute_fn, layer, *args, **kwargs
)

def call_ast(
self,
static_function: tuple,
*args: VariableBase,
**kwargs: VariableBase,
):
"""
call paddle layer, start symbolic trace.

Args:
layer: paddle layer
"""

def compute_fn(static_function, inputs, outputs, stacks):
self.sir_ctx.call_AST(
static_function,
inputs=inputs,
outputs=outputs,
stacks=stacks,
)

def message_handler(*args, **kwargs):
return "Call ast faild"

try:
return inner_error_default_handler(
self.symbolic_call, message_handler
)(ast_infer_meta, compute_fn, static_function, *args, **kwargs)
except Exception as e:
log(3, f"[call AST] {e}")
return None

def symbolic_call(self, infer_meta_fn, compute_fn, func, *args, **kwargs):
"""
Using infer_meta_fn and compute_fn convert func to symbolic function.

Args:
infer_meta_fn: function for infer meta, (func, metas, kwmetas) -> output_metas
compute_fn : function for sir compile, (func, input_symbols, outputs_symbols) -> None
func : symbolic function
compute_fn : function for add stmt to sir, (func, input_symbols, outputs_symbols, stacks) -> None
func : the logical function which will be represent as a stmt
"""
self.collect_input_variables(list(args))
self.collect_input_variables(list(kwargs.values()))
Expand All @@ -522,6 +557,10 @@ def symbolic_call(self, infer_meta_fn, compute_fn, func, *args, **kwargs):
convert_to_symbol(args),
convert_to_symbol(kwargs),
)

self.sir_ctx.TOS.set_symbol_meta_map(get_symbol_meta_map(args))
self.sir_ctx.TOS.set_symbol_meta_map(get_symbol_meta_map(kwargs))

log(3, f" inputs : {inputs_symbols}", "\n")

outputs = map_if(
Expand Down Expand Up @@ -564,7 +603,37 @@ def symbolic_call(self, infer_meta_fn, compute_fn, func, *args, **kwargs):
outputs, self, DummyTracker(list(args) + list(kwargs.values()))
)
else:
return None
return ConstantVariable.wrap_literal(None, self)

@staticmethod
def get_opcode_executor_stack():
# NOTE: only for debug.
# dependent on OpcodeExecutor.
from .opcode_executor import OpcodeExecutorBase

if len(OpcodeExecutorBase.call_stack) == 0:
# In test case, we can meet this senario.
return []
current_executor = OpcodeExecutorBase.call_stack[-1]
current_line = current_executor._current_line
filename = current_executor._code.co_filename
source_lines, start_line = inspect.getsourcelines(
current_executor._code
)
# TODO(SigureMo): In 3.11, lineno maybe changed after multiple breakgraph,
# We need to find a way to fix this.
line_idx = max(min(current_line - start_line, len(source_lines) - 1), 0)
code_line = source_lines[line_idx]
stack = []
stack.append(
' File "{}", line {}, in {}'.format(
filename,
current_line,
current_executor._code.co_name,
)
)
stack.append(f' {code_line}')
return stack

def _put_inner(self, vars: VariableBase):
"""
Expand Down
Loading