diff --git a/python/paddle/jit/sot/opcode_translator/executor/executor_cache.py b/python/paddle/jit/sot/opcode_translator/executor/executor_cache.py index 4efbcdc0472b83..2ed73db3767636 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/executor_cache.py +++ b/python/paddle/jit/sot/opcode_translator/executor/executor_cache.py @@ -15,7 +15,6 @@ from __future__ import annotations import gc -import sys import traceback import types from typing import List, Tuple @@ -190,13 +189,6 @@ def start_translate(frame: types.FrameType, **kwargs) -> GuardedFunction: Returns: GuardedFunction | None: The translated code object and its guard function, or None if translation fails. """ - if sys.version_info >= (3, 11): - for const in frame.f_code.co_consts: - if isinstance(const, types.CodeType) and const.co_name.startswith( - "<" - ): - log(2, f"Found code object {const.co_name}, skip it\n") - return CustomCode(None, False), dummy_guard simulator = OpcodeExecutor(frame, **kwargs) try: simulator.check_code_simulatable() diff --git a/python/paddle/jit/sot/opcode_translator/executor/function_graph.py b/python/paddle/jit/sot/opcode_translator/executor/function_graph.py index dfb5c4c5c9bdad..a2fb85734c7be4 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/function_graph.py +++ b/python/paddle/jit/sot/opcode_translator/executor/function_graph.py @@ -260,9 +260,13 @@ def load(self, var): restore_instr_names = [ instr.opname for instr in restore_instrs[:instr_idx] ] - # NOTE(SigureMo): Trailing KW_NAMES + PRECALL is no need to restore in Python 3.11+ - if restore_instr_names[-2:] == ["KW_NAMES", "PRECALL"]: - restore_instrs = restore_instrs[:-2] + # NOTE(SigureMo): Trailing KW_NAMES or PRECALL is no need to restore in Python 3.11+ + if restore_instr_names[-1:] == ["PRECALL"]: + restore_instrs = restore_instrs[:-1] + restore_instr_names = restore_instr_names[:-1] + if restore_instr_names[-1:] == ["KW_NAMES"]: + restore_instrs = restore_instrs[:-1] + restore_instr_names = restore_instr_names[:-1] self.pycode_gen.extend_instrs(restore_instrs) nop = self.pycode_gen._add_instr("NOP") diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py index 302fa7a9e620f4..d14d7ef67dfacf 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py +++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py @@ -1535,10 +1535,7 @@ def gen_compute_in_break_with_name_store(self, restore_names, instr_idx): instr_idx: the index for branch 1 to find the boundary and copy origin opcode """ - if ( - self._graph.sir_ctx.TOS.graph_size() < ENV_MIN_GRAPH_SIZE.get() - and sys.version_info < (3, 11) - ): + if self._graph.sir_ctx.TOS.graph_size() < ENV_MIN_GRAPH_SIZE.get(): store_var_info = {} for name in restore_names: _var = self.get_var(name) diff --git a/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_pass.py b/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_pass.py index 62cb8ba2a75d70..43fc5fa3606fa2 100644 --- a/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_pass.py +++ b/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_pass.py @@ -14,6 +14,7 @@ from paddle.jit.sot.utils import log, log_do +from ...utils import InnerError from .instruction_utils import instrs_info from .stack_analyse import StackAnalyser @@ -21,7 +22,11 @@ def apply_instr_pass(instrs, code_options): log(3, f"[Opcode Pass]: Original New Code {code_options['co_name']}:\n") log_do(3, lambda: print(instrs_info(instrs))) - supported_passes = (remove_load_store_pass,) + supported_passes = ( + remove_load_store_pass, + remove_duplicate_resume, + check_precall_followed_by_call, + ) for instr_pass in supported_passes: instr_pass(instrs, code_options) @@ -254,3 +259,14 @@ def remove_duplicate_resume(instrs, code_options): return for resume in resumes[1:]: instrs.remove(resume) + + +def check_precall_followed_by_call(instrs, code_options): + """ + PRECALL should be followed by CALL, otherwise it will cause a segmentation fault + """ + for instr, next_instr in zip(instrs[:-1], instrs[1:]): + if instr.opname == "PRECALL" and next_instr.opname != "CALL": + raise InnerError( + f"PRECALL is not followed by CALL in {code_options['co_name']}" + ) diff --git a/test/sot/test_05_dict.py b/test/sot/test_05_dict.py index 3ece5595cfe5ec..7014a717467984 100644 --- a/test/sot/test_05_dict.py +++ b/test/sot/test_05_dict.py @@ -16,14 +16,12 @@ # BUILD_MAP (new) # BUILD_CONST_KEY_MAP (new) -import sys import unittest from test_case_base import TestCaseBase import paddle from paddle.jit.sot.psdb import check_no_breakgraph -from paddle.jit.sot.utils.envs import strict_mode_guard @check_no_breakgraph @@ -244,10 +242,7 @@ def test_construct(self): self.assert_results(dict_construct_from_dict) self.assert_results(dict_construct_from_list) self.assert_results(dict_construct_from_tuple) - # Temporarily fallback for comprehension in python3.11 - use_strict_mode = sys.version_info < (3, 11) - with strict_mode_guard(use_strict_mode): - self.assert_results(dict_construct_from_comprehension) + self.assert_results(dict_construct_from_comprehension) def test_dict_noargs(self): self.assert_results(dict_no_arguments) diff --git a/test/sot/test_12_for_loop.py b/test/sot/test_12_for_loop.py index ff7f9b5128b935..3d3b59043504eb 100644 --- a/test/sot/test_12_for_loop.py +++ b/test/sot/test_12_for_loop.py @@ -17,7 +17,6 @@ from __future__ import annotations -import sys import unittest from test_case_base import TestCaseBase @@ -241,10 +240,7 @@ def run_list_comp(x): class TestListComp(TestCaseBase): def test_list_comp(self): x = [paddle.randn([1, 4]), paddle.randn([1, 4])] - # Temporarily fallback for comprehension in python3.11 - use_strict_mode = sys.version_info < (3, 11) - with strict_mode_guard(use_strict_mode): - self.assert_results(run_list_comp, x) + self.assert_results(run_list_comp, x) def for_enumerate_cache(func_list, x): diff --git a/test/sot/test_builtin_map.py b/test/sot/test_builtin_map.py index d56b607997b17b..f005ec10cdbe4b 100644 --- a/test/sot/test_builtin_map.py +++ b/test/sot/test_builtin_map.py @@ -14,7 +14,6 @@ from __future__ import annotations -import sys import unittest from typing import Iterable @@ -104,15 +103,12 @@ def test_map(self): self.assert_results(test_map_dict, {"a": 1, "b": 2, "c": 3}) def test_map_comprehension(self): - # Temporarily fallback for comprehension in python3.11 - use_strict_mode = sys.version_info < (3, 11) - with strict_mode_guard(use_strict_mode): - self.assert_results(test_map_list_comprehension, [1, 2, 3, 4]) - self.assert_results(test_map_tuple_comprehension, (1, 2, 3, 4)) - self.assert_results(test_map_range_comprehension, range(5)) - self.assert_results( - test_map_dict_comprehension, {"a": 1, "b": 2, "c": 3} - ) + self.assert_results(test_map_list_comprehension, [1, 2, 3, 4]) + self.assert_results(test_map_tuple_comprehension, (1, 2, 3, 4)) + self.assert_results(test_map_range_comprehension, range(5)) + self.assert_results( + test_map_dict_comprehension, {"a": 1, "b": 2, "c": 3} + ) def test_map_with_breakgraph(self): with strict_mode_guard(False): diff --git a/test/sot/test_listcomp.py b/test/sot/test_specialization.py similarity index 61% rename from test/sot/test_listcomp.py rename to test/sot/test_specialization.py index f1723e3a9864ee..260ff8e70daa0c 100644 --- a/test/sot/test_listcomp.py +++ b/test/sot/test_specialization.py @@ -17,39 +17,28 @@ from test_case_base import TestCaseBase import paddle -from paddle.jit.sot.utils.envs import min_graph_size_guard, strict_mode_guard +from paddle.jit.sot.opcode_translator.executor.dispatcher import Dispatcher +from paddle.jit.sot.utils.envs import min_graph_size_guard # 8 will trigger the warmup in RESUME instruction and cause a segmentation fault # RUN_N_TIMES should be larger than 8 RUN_N_TIMES = 20 +builtin_fn = str.split +# Remove builtin_fn from Dispatcher to ensure that trigger a BreakGraph Error +if builtin_fn in Dispatcher.handlers: + del Dispatcher.handlers[builtin_fn] -def listcomp_fn(): - print(1) - x = [i for i in range(10)] # noqa: C416 - return x +def builtin_fn_with_breakgraph(): + str.split("1,2,3,4,5", ",") -def genexpr_fn(): - print(1) - x = (i for i in range(10)) - return x - -class TestListComp(TestCaseBase): - @strict_mode_guard(False) - @min_graph_size_guard(10) - def test_listcomp(self): - for _ in range(RUN_N_TIMES): - paddle.jit.to_static(listcomp_fn)() - - -class TestGenExpr(TestCaseBase): - @strict_mode_guard(False) +class TestSpecialization(TestCaseBase): @min_graph_size_guard(10) - def test_genexpr(self): + def test_specialization(self): for _ in range(RUN_N_TIMES): - paddle.jit.to_static(genexpr_fn)() + paddle.jit.to_static(builtin_fn_with_breakgraph)() if __name__ == "__main__":