diff --git a/python/paddle/jit/sot/opcode_translator/executor/guard.py b/python/paddle/jit/sot/opcode_translator/executor/guard.py index 7f2436675ed749..89891df02f1c09 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/guard.py +++ b/python/paddle/jit/sot/opcode_translator/executor/guard.py @@ -47,16 +47,16 @@ class StringifyExpression: def __init__(self, str_expr, sub_exprs, free_vars): expr = str_expr.format(*[arg.expr for arg in sub_exprs]) self.expr = current_tmp_name_records().add_tmp_var(expr) - self.debug_expr = str_expr.format( - *[arg.debug_expr for arg in sub_exprs] + self.inlined_expr = str_expr.format( + *[arg.inlined_expr for arg in sub_exprs] ) self.free_vars = free_vars def __hash__(self): if self.free_vars: - return hash((self.debug_expr, id(self))) + return hash((self.inlined_expr, id(self))) else: - return hash(self.debug_expr) + return hash(self.inlined_expr) def union_free_vars(*free_vars: dict[str, Any]): @@ -90,7 +90,7 @@ def analyse_expressions(stringify_exprs, tmp_names): func_result = "" for str_expr in stringify_exprs: func_result += str_expr.expr + " and " - lambda_string += str_expr.debug_expr + " and " + lambda_string += str_expr.inlined_expr + " and " free_vars = union_free_vars(free_vars, str_expr.free_vars) func_string += f" return {func_result[:-5]}" diff --git a/python/paddle/jit/sot/opcode_translator/executor/tracker.py b/python/paddle/jit/sot/opcode_translator/executor/tracker.py index 22a5bc1ae34da1..1f9a378a4f5f2b 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/tracker.py +++ b/python/paddle/jit/sot/opcode_translator/executor/tracker.py @@ -71,16 +71,22 @@ def trace_value_from_frame(self) -> StringifyExpression: def match_expr(self, expr: str) -> bool: """ - TODO(zrr1999) + Match the expression with the tracked variables. + + Args: + expr (str): The expression to be matched. + + Returns: + bool: True if the expression matches the tracked variables, False otherwise. """ - raise NotImplementedError() + return self.trace_value_from_frame().inlined_expr == expr def is_traceable(self) -> bool: """ Determine if all the tracked variables can be traced from the frame. Returns: - bool, True if all tracked variables are traceable, False otherwise. + bool: True if all tracked variables are traceable, False otherwise. """ if self.changed: return False @@ -171,9 +177,6 @@ def gen_instructions(self, codegen: PyCodeGen) -> None: def trace_value_from_frame(self) -> StringifyExpression: return StringifyExpression(f"frame.f_locals['{self.name}']", [], {}) - def match_expr(self, expr: str) -> bool: - return expr == f"frame.f_locals['{self.name}']" - def __repr__(self) -> str: return f"LocalTracker(name={self.name})" @@ -253,7 +256,7 @@ def gen_instructions(self, codegen: PyCodeGen): def trace_value_from_frame(self): value_str, value_free_vars = stringify_pyobject(self.value) return StringifyExpression( - f"{value_str}", [], union_free_vars(value_free_vars) + value_str, [], union_free_vars(value_free_vars) ) def __repr__(self) -> str: diff --git a/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py b/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py index 00c6ba6dd05997..c122e0bfb2435b 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py +++ b/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py @@ -53,7 +53,6 @@ GetAttrTracker, GetIterTracker, GlobalTracker, - LocalTracker, Tracker, ) from .base import VariableBase, VariableFactory @@ -180,16 +179,13 @@ def make_stringify_guard(self) -> list[StringifyExpression]: ENV_SOT_ALLOW_DYNAMIC_SHAPE.get() and isinstance(self.value, int) and self.tracker.need_guard() - and isinstance( - self.tracker, LocalTracker - ) # TODO(zrr1999): now only support local tracker ): from ..executor_cache import OpcodeExecutorCache frame_value_tracer = self.tracker.trace_value_from_frame() symbolic_inputs = OpcodeExecutorCache().symbolic_inputs - symbolic_inputs.setdefault(frame_value_tracer.debug_expr, {}) - symbolic_input = symbolic_inputs[frame_value_tracer.debug_expr] + symbolic_inputs.setdefault(frame_value_tracer.inlined_expr, {}) + symbolic_input = symbolic_inputs[frame_value_tracer.inlined_expr] symbolic_input.setdefault(self.value, 0) symbolic_input[self.value] += 1 @@ -643,16 +639,16 @@ def _reconstruct(self, codegen: PyCodeGen): @check_guard def make_stringify_guard(self) -> list[StringifyExpression]: + assert ENV_SOT_ALLOW_DYNAMIC_SHAPE.get() from ..executor_cache import OpcodeExecutorCache frame_value_tracer = self.tracker.trace_value_from_frame() symbolic_inputs = OpcodeExecutorCache().symbolic_inputs - assert frame_value_tracer.debug_expr in symbolic_inputs - assert ENV_SOT_ALLOW_DYNAMIC_SHAPE.get() + assert frame_value_tracer.inlined_expr in symbolic_inputs # TODO(zrr1999): Once dynamic shape is used, there will be no new guards - symbolic_input = symbolic_inputs[frame_value_tracer.debug_expr] + symbolic_input = symbolic_inputs[frame_value_tracer.inlined_expr] symbolic_input.setdefault(self.value, 0) symbolic_input[self.value] += 1 @@ -672,13 +668,11 @@ def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): return if not tracker.need_guard(): return - if not isinstance(tracker, LocalTracker): - # TODO(zrr1999): now only support local tracker - return from ..executor_cache import OpcodeExecutorCache symbolic_inputs = OpcodeExecutorCache().symbolic_inputs + for tracker_expr, symbolic_input in symbolic_inputs.items(): if tracker.match_expr(tracker_expr): symbolic_input.setdefault(value, 0) diff --git a/python/paddle/jit/sot/opcode_translator/executor/variables/container.py b/python/paddle/jit/sot/opcode_translator/executor/variables/container.py index 6c642d0582258e..1d8429599a0e87 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/variables/container.py +++ b/python/paddle/jit/sot/opcode_translator/executor/variables/container.py @@ -58,7 +58,7 @@ def get_wrapped_items(self): "ContainerVariable.get_wrapped_items do not implement" ) - def __len__(self): + def __len__(self) -> int: raise FallbackError('ContainerVariable.__len__ do not implement') def len(self): diff --git a/python/paddle/jit/sot/utils/utils.py b/python/paddle/jit/sot/utils/utils.py index 383467744f3f62..b04175849aed23 100644 --- a/python/paddle/jit/sot/utils/utils.py +++ b/python/paddle/jit/sot/utils/utils.py @@ -71,9 +71,6 @@ def match_name(self, name: str) -> bool: return name.startswith(self.prefix) -_tmp_name_records = None - - class TmpNameRecords: def __init__(self): self.name_generator = NameGenerator(prefix="_sot_tmp_") @@ -91,6 +88,9 @@ def add_tmp_var(self, expr): return tmp_name +_tmp_name_records = TmpNameRecords() + + @contextmanager def tmp_name_guard(): global _tmp_name_records diff --git a/test/sot/test_dynamic_shape.py b/test/sot/test_sot_dynamic_shape.py similarity index 56% rename from test/sot/test_dynamic_shape.py rename to test/sot/test_sot_dynamic_shape.py index 5d58673140a8af..d0bb3623226c56 100644 --- a/test/sot/test_dynamic_shape.py +++ b/test/sot/test_sot_dynamic_shape.py @@ -30,36 +30,43 @@ def foo(x): return x + s -def dynamic_int_input_func(x, n): +def dynamic_int_input_func1(x, n): x = paddle.reshape(x, [n, -1]) - return (x + n) * 2 - 1, (n + 1) * 2 - 1 + return (x + n) * 2 - 1, (-n + 1) * 2 - 1 + + +def dynamic_int_input_func2(x, n): + return x + n[1] class TestOpcodeExecutorDynamicShapeCache(TestCaseBase): - def test_dynamic_int_input_cache_hit(self): + def test_dynamic_int_input_cache_hit_case1(self): with with_allow_dynamic_shape_guard( True ), test_instruction_translator_cache_context() as ctx: self.assert_results( - dynamic_int_input_func, paddle.randn([3, 4, 5]), 1 + dynamic_int_input_func1, paddle.randn([3, 4, 5]), 1 ) self.assertEqual(ctx.translate_count, 1) + for i in range(2, 6): + self.assert_results( + dynamic_int_input_func1, paddle.randn([3, 4, 5]), i + ) + self.assertEqual(ctx.translate_count, 2) + + def test_dynamic_int_input_cache_hit_case2(self): + with with_allow_dynamic_shape_guard( + True + ), test_instruction_translator_cache_context() as ctx: self.assert_results( - dynamic_int_input_func, paddle.randn([3, 4, 5]), 2 - ) - self.assertEqual(ctx.translate_count, 2) - self.assert_results( - dynamic_int_input_func, paddle.randn([3, 4, 5]), 3 - ) - self.assertEqual(ctx.translate_count, 2) - self.assert_results( - dynamic_int_input_func, paddle.randn([3, 4, 5]), 4 - ) - self.assertEqual(ctx.translate_count, 2) - self.assert_results( - dynamic_int_input_func, paddle.randn([3, 4, 5]), 5 + dynamic_int_input_func2, paddle.randn([3, 4, 5]), {1: 1} ) - self.assertEqual(ctx.translate_count, 2) + self.assertEqual(ctx.translate_count, 1) + for i in range(2, 6): + self.assert_results( + dynamic_int_input_func2, paddle.randn([3, 4, 5]), {1: i} + ) + self.assertEqual(ctx.translate_count, 2) if __name__ == '__main__':