Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions python/paddle/jit/sot/opcode_translator/executor/guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,16 @@ class StringifyExpression:
def __init__(self, str_expr, sub_exprs, free_vars):
expr = str_expr.format(*[arg.expr for arg in sub_exprs])
self.expr = current_tmp_name_records().add_tmp_var(expr)
self.debug_expr = str_expr.format(
*[arg.debug_expr for arg in sub_exprs]
self.inlined_expr = str_expr.format(
*[arg.inlined_expr for arg in sub_exprs]
)
self.free_vars = free_vars

def __hash__(self):
if self.free_vars:
return hash((self.debug_expr, id(self)))
return hash((self.inlined_expr, id(self)))
else:
return hash(self.debug_expr)
return hash(self.inlined_expr)


def union_free_vars(*free_vars: dict[str, Any]):
Expand Down Expand Up @@ -90,7 +90,7 @@ def analyse_expressions(stringify_exprs, tmp_names):
func_result = ""
for str_expr in stringify_exprs:
func_result += str_expr.expr + " and "
lambda_string += str_expr.debug_expr + " and "
lambda_string += str_expr.inlined_expr + " and "
free_vars = union_free_vars(free_vars, str_expr.free_vars)

func_string += f" return {func_result[:-5]}"
Expand Down
17 changes: 10 additions & 7 deletions python/paddle/jit/sot/opcode_translator/executor/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,22 @@ def trace_value_from_frame(self) -> StringifyExpression:

def match_expr(self, expr: str) -> bool:
"""
TODO(zrr1999)
Match the expression with the tracked variables.

Args:
expr (str): The expression to be matched.

Returns:
bool: True if the expression matches the tracked variables, False otherwise.
"""
raise NotImplementedError()
return self.trace_value_from_frame().inlined_expr == expr

def is_traceable(self) -> bool:
"""
Determine if all the tracked variables can be traced from the frame.

Returns:
bool, True if all tracked variables are traceable, False otherwise.
bool: True if all tracked variables are traceable, False otherwise.
"""
if self.changed:
return False
Expand Down Expand Up @@ -171,9 +177,6 @@ def gen_instructions(self, codegen: PyCodeGen) -> None:
def trace_value_from_frame(self) -> StringifyExpression:
return StringifyExpression(f"frame.f_locals['{self.name}']", [], {})

def match_expr(self, expr: str) -> bool:
return expr == f"frame.f_locals['{self.name}']"

def __repr__(self) -> str:
return f"LocalTracker(name={self.name})"

Expand Down Expand Up @@ -253,7 +256,7 @@ def gen_instructions(self, codegen: PyCodeGen):
def trace_value_from_frame(self):
value_str, value_free_vars = stringify_pyobject(self.value)
return StringifyExpression(
f"{value_str}", [], union_free_vars(value_free_vars)
value_str, [], union_free_vars(value_free_vars)
)

def __repr__(self) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
GetAttrTracker,
GetIterTracker,
GlobalTracker,
LocalTracker,
Tracker,
)
from .base import VariableBase, VariableFactory
Expand Down Expand Up @@ -180,16 +179,13 @@ def make_stringify_guard(self) -> list[StringifyExpression]:
ENV_SOT_ALLOW_DYNAMIC_SHAPE.get()
and isinstance(self.value, int)
and self.tracker.need_guard()
and isinstance(
self.tracker, LocalTracker
) # TODO(zrr1999): now only support local tracker
):
from ..executor_cache import OpcodeExecutorCache

frame_value_tracer = self.tracker.trace_value_from_frame()
symbolic_inputs = OpcodeExecutorCache().symbolic_inputs
symbolic_inputs.setdefault(frame_value_tracer.debug_expr, {})
symbolic_input = symbolic_inputs[frame_value_tracer.debug_expr]
symbolic_inputs.setdefault(frame_value_tracer.inlined_expr, {})
symbolic_input = symbolic_inputs[frame_value_tracer.inlined_expr]
symbolic_input.setdefault(self.value, 0)
symbolic_input[self.value] += 1

Expand Down Expand Up @@ -643,16 +639,16 @@ def _reconstruct(self, codegen: PyCodeGen):

@check_guard
def make_stringify_guard(self) -> list[StringifyExpression]:
assert ENV_SOT_ALLOW_DYNAMIC_SHAPE.get()
from ..executor_cache import OpcodeExecutorCache

frame_value_tracer = self.tracker.trace_value_from_frame()
symbolic_inputs = OpcodeExecutorCache().symbolic_inputs

assert frame_value_tracer.debug_expr in symbolic_inputs
assert ENV_SOT_ALLOW_DYNAMIC_SHAPE.get()
assert frame_value_tracer.inlined_expr in symbolic_inputs

# TODO(zrr1999): Once dynamic shape is used, there will be no new guards
symbolic_input = symbolic_inputs[frame_value_tracer.debug_expr]
symbolic_input = symbolic_inputs[frame_value_tracer.inlined_expr]
symbolic_input.setdefault(self.value, 0)
symbolic_input[self.value] += 1

Expand All @@ -672,13 +668,11 @@ def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
return
if not tracker.need_guard():
return
if not isinstance(tracker, LocalTracker):
# TODO(zrr1999): now only support local tracker
return

from ..executor_cache import OpcodeExecutorCache

symbolic_inputs = OpcodeExecutorCache().symbolic_inputs

for tracker_expr, symbolic_input in symbolic_inputs.items():
if tracker.match_expr(tracker_expr):
symbolic_input.setdefault(value, 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def get_wrapped_items(self):
"ContainerVariable.get_wrapped_items do not implement"
)

def __len__(self):
def __len__(self) -> int:
raise FallbackError('ContainerVariable.__len__ do not implement')

def len(self):
Expand Down
6 changes: 3 additions & 3 deletions python/paddle/jit/sot/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,6 @@ def match_name(self, name: str) -> bool:
return name.startswith(self.prefix)


_tmp_name_records = None


class TmpNameRecords:
def __init__(self):
self.name_generator = NameGenerator(prefix="_sot_tmp_")
Expand All @@ -91,6 +88,9 @@ def add_tmp_var(self, expr):
return tmp_name


_tmp_name_records = TmpNameRecords()


@contextmanager
def tmp_name_guard():
global _tmp_name_records
Expand Down
45 changes: 37 additions & 8 deletions test/sot/test_dynamic_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,34 +30,63 @@ def foo(x):
return x + s


def dynamic_int_input_func(x, n):
def dynamic_int_input_func1(x, n):
x = paddle.reshape(x, [n, -1])
return (x + n) * 2 - 1, (n + 1) * 2 - 1
return (x + n) * 2 - 1, (-n + 1) * 2 - 1


def dynamic_int_input_func2(x, n):
return x + n[1]


class TestOpcodeExecutorDynamicShapeCache(TestCaseBase):
def test_dynamic_int_input_cache_hit(self):
def test_dynamic_int_input_cache_hit_case1(self):
with with_allow_dynamic_shape_guard(
True
), test_instruction_translator_cache_context() as ctx:
self.assert_results(
dynamic_int_input_func1, paddle.randn([3, 4, 5]), 1
)
self.assertEqual(ctx.translate_count, 1)
self.assert_results(
dynamic_int_input_func1, paddle.randn([3, 4, 5]), 2
)
self.assertEqual(ctx.translate_count, 2)
self.assert_results(
dynamic_int_input_func1, paddle.randn([3, 4, 5]), 3
)
self.assertEqual(ctx.translate_count, 2)
self.assert_results(
dynamic_int_input_func1, paddle.randn([3, 4, 5]), 4
)
self.assertEqual(ctx.translate_count, 2)
self.assert_results(
dynamic_int_input_func1, paddle.randn([3, 4, 5]), 5
)
self.assertEqual(ctx.translate_count, 2)

def test_dynamic_int_input_cache_hit_case2(self):
with with_allow_dynamic_shape_guard(
True
), test_instruction_translator_cache_context() as ctx:
self.assert_results(
dynamic_int_input_func, paddle.randn([3, 4, 5]), 1
dynamic_int_input_func2, paddle.randn([3, 4, 5]), {1: 1}
)
self.assertEqual(ctx.translate_count, 1)
self.assert_results(
dynamic_int_input_func, paddle.randn([3, 4, 5]), 2
dynamic_int_input_func2, paddle.randn([3, 4, 5]), {1: 2}
)
self.assertEqual(ctx.translate_count, 2)
self.assert_results(
dynamic_int_input_func, paddle.randn([3, 4, 5]), 3
dynamic_int_input_func2, paddle.randn([3, 4, 5]), {1: 3}
)
self.assertEqual(ctx.translate_count, 2)
self.assert_results(
dynamic_int_input_func, paddle.randn([3, 4, 5]), 4
dynamic_int_input_func2, paddle.randn([3, 4, 5]), {1: 4}
)
self.assertEqual(ctx.translate_count, 2)
self.assert_results(
dynamic_int_input_func, paddle.randn([3, 4, 5]), 5
dynamic_int_input_func2, paddle.randn([3, 4, 5]), {1: 5}
)
self.assertEqual(ctx.translate_count, 2)

Expand Down