diff --git a/python/paddle/jit/sot/opcode_translator/executor/guard.py b/python/paddle/jit/sot/opcode_translator/executor/guard.py index cc9791f0214b38..7f2436675ed749 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/guard.py +++ b/python/paddle/jit/sot/opcode_translator/executor/guard.py @@ -18,6 +18,8 @@ import weakref from typing import TYPE_CHECKING, Any, Callable, TypeVar +import paddle + from ...profiler import EventGuard from ...utils import current_tmp_name_records, log, log_do @@ -171,3 +173,12 @@ def object_equal_stringify_guard(self) -> list[StringifyExpression]: ), ) ] + + +def stringify_pyobject(obj: object) -> tuple[str, dict[str, Any]]: + if isinstance(obj, paddle.core.VarDesc.VarType): + return f"paddle.core.VarDesc.VarType({obj.value})", {"paddle": paddle} + elif isinstance(obj, paddle.core.DataType): + return f"paddle.core.DataType({obj.value})", {"paddle": paddle} + # For builtin values + return f"{obj!r}", {} diff --git a/python/paddle/jit/sot/opcode_translator/executor/tracker.py b/python/paddle/jit/sot/opcode_translator/executor/tracker.py index d8c8c54b884804..7ec4c51f6598ca 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/tracker.py +++ b/python/paddle/jit/sot/opcode_translator/executor/tracker.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING from ...utils import InnerError, NameGenerator -from .guard import StringifyExpression, union_free_vars +from .guard import StringifyExpression, stringify_pyobject, union_free_vars if TYPE_CHECKING: from typing import Sequence @@ -242,7 +242,10 @@ def gen_instructions(self, codegen: PyCodeGen): codegen.gen_load_const(self.value) def trace_value_from_frame(self): - return StringifyExpression(f"{self.value!r}", [], {}) + value_str, value_free_vars = stringify_pyobject(self.value) + return StringifyExpression( + f"{value_str}", [], union_free_vars(value_free_vars) + ) def __repr__(self) -> str: return f"ConstTracker(value={self.value})" @@ -365,10 +368,11 @@ def gen_instructions(self, codegen: PyCodeGen): def trace_value_from_frame(self): container_tracer = self.container.tracker.trace_value_from_frame() + key_string, key_free_vars = stringify_pyobject(self.key) return StringifyExpression( - f"{{}}[{self.key!r}]", + f"{{}}[{key_string}]", [container_tracer], - union_free_vars(container_tracer.free_vars), + union_free_vars(container_tracer.free_vars, key_free_vars), ) def __repr__(self) -> str: diff --git a/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py b/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py index 1f06d61cf8dc30..ff982fd6d219ed 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py +++ b/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py @@ -40,6 +40,7 @@ StringifyExpression, check_guard, object_equal_stringify_guard, + stringify_pyobject, union_free_vars, ) from ..mutable_data import MutableDictLikeData @@ -260,11 +261,15 @@ def make_stringify_guard(self) -> list[StringifyExpression]: tensor_value_tracer = ( self.tracker.obj.tracker.trace_value_from_frame() ) + dtype_str, dtype_free_vars = stringify_pyobject(self.value) return [ StringifyExpression( - f"str(MetaInfo.from_tensor({{}}).dtype) == '{str(self.value)}'", + f"MetaInfo.from_tensor({{}}).dtype == {dtype_str}", [tensor_value_tracer], - {"MetaInfo": MetaInfo}, + union_free_vars( + {"MetaInfo": MetaInfo}, + dtype_free_vars, + ), ) ] else: diff --git a/test/sot/test_case_base.py b/test/sot/test_case_base.py index bb6a13b6709b2d..bffdbe7ba83335 100644 --- a/test/sot/test_case_base.py +++ b/test/sot/test_case_base.py @@ -18,7 +18,6 @@ import copy import types import unittest -from functools import wraps import numpy as np @@ -103,43 +102,3 @@ def copy_fn(fn): sym_copied_fn.__globals__[key], paddle_fn.__globals__[key] ) self.assert_nest_match(sym_output, paddle_output) - - -# Some decorators for PIR test -def to_pir_test(fn): - # NOTE(SigureMo): This function should sync with test/dygraph_to_static/dygraph_to_static_utils.py - @wraps(fn) - def impl(*args, **kwargs): - in_dygraph_mode = paddle.in_dynamic_mode() - with paddle.pir_utils.IrGuard(): - if in_dygraph_mode: - paddle.disable_static() - ir_outs = fn(*args, **kwargs) - return ir_outs - - return impl - - -def run_in_pir_mode(fn): - @wraps(fn) - def impl(*args, **kwargs): - OpcodeExecutorCache().clear() - pir_fn = to_pir_test(fn) - return pir_fn(*args, **kwargs) - - return impl - - -def run_in_both_default_and_pir(fn): - @wraps(fn) - def impl(*args, **kwargs): - OpcodeExecutorCache().clear() - default_fn = fn - pir_fn = to_pir_test(fn) - default_outs = default_fn(*args, **kwargs) - OpcodeExecutorCache().clear() - # The out of test case should be None, which is not used. - _pir_outs = pir_fn(*args, **kwargs) - return default_outs - - return impl diff --git a/test/sot/test_dataclass.py b/test/sot/test_dataclass.py index a0e885ec99a263..a61d5511c09f18 100644 --- a/test/sot/test_dataclass.py +++ b/test/sot/test_dataclass.py @@ -15,10 +15,7 @@ import unittest from dataclasses import dataclass -from test_case_base import ( - TestCaseBase, - run_in_both_default_and_pir, -) +from test_case_base import TestCaseBase import paddle from paddle.jit.sot.utils import strict_mode_guard @@ -47,13 +44,11 @@ def return_dataclass_with_post_init(x): class TestDataclass(TestCaseBase): @strict_mode_guard(False) - @run_in_both_default_and_pir def test_dtype_reconstruct(self): x = paddle.to_tensor(1) self.assert_results(return_dataclass, x) @strict_mode_guard(False) - @run_in_both_default_and_pir def test_dtype_reconstruct_with_post_init(self): x = paddle.to_tensor(1) self.assert_results(return_dataclass_with_post_init, x) diff --git a/test/sot/test_dtype.py b/test/sot/test_dtype.py index d7dc252b9f6bf3..af4b44be4ca461 100644 --- a/test/sot/test_dtype.py +++ b/test/sot/test_dtype.py @@ -16,7 +16,6 @@ from test_case_base import ( TestCaseBase, - run_in_both_default_and_pir, test_instruction_translator_cache_context, ) @@ -42,8 +41,12 @@ def reconstruct_dtype(): return y +def dtype_guard(x, cast_map): + out = paddle.cast(x, cast_map[x.dtype]) + return out, out.dtype + + class TestTensorAstype(TestCaseBase): - @run_in_both_default_and_pir def test_tensor_astype(self): x = paddle.ones([2, 3], dtype="float32") y = paddle.ones([2, 3], dtype="int32") @@ -51,7 +54,6 @@ def test_tensor_astype(self): class TestTensorDtypeGuard(TestCaseBase): - @run_in_both_default_and_pir def test_tensor_dtype_guard(self): x = paddle.ones([2, 3], dtype="float32") y = paddle.ones([2, 3], dtype="int32") @@ -65,10 +67,16 @@ def test_tensor_dtype_guard(self): class TestDtypeReconstruct(TestCaseBase): - @run_in_both_default_and_pir def test_dtype_reconstruct(self): self.assert_results(reconstruct_dtype) +class TestDtypeGuard(TestCaseBase): + def test_dtype_guard(self): + dtype_map = {paddle.float32: paddle.float64} + x = paddle.ones([2, 3], dtype="float32") + self.assert_results(dtype_guard, x, dtype_map) + + if __name__ == "__main__": unittest.main()