Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions python/paddle/jit/sot/opcode_translator/executor/guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import weakref
from typing import TYPE_CHECKING, Any, Callable, TypeVar

import paddle

from ...profiler import EventGuard
from ...utils import current_tmp_name_records, log, log_do

Expand Down Expand Up @@ -171,3 +173,12 @@ def object_equal_stringify_guard(self) -> list[StringifyExpression]:
),
)
]


def stringify_pyobject(obj: object) -> tuple[str, dict[str, Any]]:
if isinstance(obj, paddle.core.VarDesc.VarType):
return f"paddle.core.VarDesc.VarType({obj.value})", {"paddle": paddle}
elif isinstance(obj, paddle.core.DataType):
return f"paddle.core.DataType({obj.value})", {"paddle": paddle}
# For builtin values
return f"{obj!r}", {}
12 changes: 8 additions & 4 deletions python/paddle/jit/sot/opcode_translator/executor/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from typing import TYPE_CHECKING

from ...utils import InnerError, NameGenerator
from .guard import StringifyExpression, union_free_vars
from .guard import StringifyExpression, stringify_pyobject, union_free_vars

if TYPE_CHECKING:
from typing import Sequence
Expand Down Expand Up @@ -242,7 +242,10 @@ def gen_instructions(self, codegen: PyCodeGen):
codegen.gen_load_const(self.value)

def trace_value_from_frame(self):
return StringifyExpression(f"{self.value!r}", [], {})
value_str, value_free_vars = stringify_pyobject(self.value)
return StringifyExpression(
f"{value_str}", [], union_free_vars(value_free_vars)
)

def __repr__(self) -> str:
return f"ConstTracker(value={self.value})"
Expand Down Expand Up @@ -365,10 +368,11 @@ def gen_instructions(self, codegen: PyCodeGen):

def trace_value_from_frame(self):
container_tracer = self.container.tracker.trace_value_from_frame()
key_string, key_free_vars = stringify_pyobject(self.key)
return StringifyExpression(
f"{{}}[{self.key!r}]",
f"{{}}[{key_string}]",
[container_tracer],
union_free_vars(container_tracer.free_vars),
union_free_vars(container_tracer.free_vars, key_free_vars),
)

def __repr__(self) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
StringifyExpression,
check_guard,
object_equal_stringify_guard,
stringify_pyobject,
union_free_vars,
)
from ..mutable_data import MutableDictLikeData
Expand Down Expand Up @@ -260,11 +261,15 @@ def make_stringify_guard(self) -> list[StringifyExpression]:
tensor_value_tracer = (
self.tracker.obj.tracker.trace_value_from_frame()
)
dtype_str, dtype_free_vars = stringify_pyobject(self.value)
return [
StringifyExpression(
f"str(MetaInfo.from_tensor({{}}).dtype) == '{str(self.value)}'",
f"MetaInfo.from_tensor({{}}).dtype == {dtype_str}",
[tensor_value_tracer],
{"MetaInfo": MetaInfo},
union_free_vars(
{"MetaInfo": MetaInfo},
dtype_free_vars,
),
)
]
else:
Expand Down
41 changes: 0 additions & 41 deletions test/sot/test_case_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import copy
import types
import unittest
from functools import wraps

import numpy as np

Expand Down Expand Up @@ -103,43 +102,3 @@ def copy_fn(fn):
sym_copied_fn.__globals__[key], paddle_fn.__globals__[key]
)
self.assert_nest_match(sym_output, paddle_output)


# Some decorators for PIR test
def to_pir_test(fn):
# NOTE(SigureMo): This function should sync with test/dygraph_to_static/dygraph_to_static_utils.py
@wraps(fn)
def impl(*args, **kwargs):
in_dygraph_mode = paddle.in_dynamic_mode()
with paddle.pir_utils.IrGuard():
if in_dygraph_mode:
paddle.disable_static()
ir_outs = fn(*args, **kwargs)
return ir_outs

return impl


def run_in_pir_mode(fn):
@wraps(fn)
def impl(*args, **kwargs):
OpcodeExecutorCache().clear()
pir_fn = to_pir_test(fn)
return pir_fn(*args, **kwargs)

return impl


def run_in_both_default_and_pir(fn):
@wraps(fn)
def impl(*args, **kwargs):
OpcodeExecutorCache().clear()
default_fn = fn
pir_fn = to_pir_test(fn)
default_outs = default_fn(*args, **kwargs)
OpcodeExecutorCache().clear()
# The out of test case should be None, which is not used.
_pir_outs = pir_fn(*args, **kwargs)
return default_outs

return impl
7 changes: 1 addition & 6 deletions test/sot/test_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@
import unittest
from dataclasses import dataclass

from test_case_base import (
TestCaseBase,
run_in_both_default_and_pir,
)
from test_case_base import TestCaseBase

import paddle
from paddle.jit.sot.utils import strict_mode_guard
Expand Down Expand Up @@ -47,13 +44,11 @@ def return_dataclass_with_post_init(x):

class TestDataclass(TestCaseBase):
@strict_mode_guard(False)
@run_in_both_default_and_pir
def test_dtype_reconstruct(self):
x = paddle.to_tensor(1)
self.assert_results(return_dataclass, x)

@strict_mode_guard(False)
@run_in_both_default_and_pir
def test_dtype_reconstruct_with_post_init(self):
x = paddle.to_tensor(1)
self.assert_results(return_dataclass_with_post_init, x)
Expand Down
16 changes: 12 additions & 4 deletions test/sot/test_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from test_case_base import (
TestCaseBase,
run_in_both_default_and_pir,
test_instruction_translator_cache_context,
)

Expand All @@ -42,16 +41,19 @@ def reconstruct_dtype():
return y


def dtype_guard(x, cast_map):
out = paddle.cast(x, cast_map[x.dtype])
return out, out.dtype


class TestTensorAstype(TestCaseBase):
@run_in_both_default_and_pir
def test_tensor_astype(self):
x = paddle.ones([2, 3], dtype="float32")
y = paddle.ones([2, 3], dtype="int32")
self.assert_results(tensor_astype, x, y)


class TestTensorDtypeGuard(TestCaseBase):
@run_in_both_default_and_pir
def test_tensor_dtype_guard(self):
x = paddle.ones([2, 3], dtype="float32")
y = paddle.ones([2, 3], dtype="int32")
Expand All @@ -65,10 +67,16 @@ def test_tensor_dtype_guard(self):


class TestDtypeReconstruct(TestCaseBase):
@run_in_both_default_and_pir
def test_dtype_reconstruct(self):
self.assert_results(reconstruct_dtype)


class TestDtypeGuard(TestCaseBase):
def test_dtype_guard(self):
dtype_map = {paddle.float32: paddle.float64}
x = paddle.ones([2, 3], dtype="float32")
self.assert_results(dtype_guard, x, dtype_map)


if __name__ == "__main__":
unittest.main()