Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from __future__ import annotations

import gc
import sys
import traceback
import types
from typing import List, Tuple
Expand Down Expand Up @@ -190,13 +189,6 @@ def start_translate(frame: types.FrameType, **kwargs) -> GuardedFunction:
Returns:
GuardedFunction | None: The translated code object and its guard function, or None if translation fails.
"""
if sys.version_info >= (3, 11):
for const in frame.f_code.co_consts:
if isinstance(const, types.CodeType) and const.co_name.startswith(
"<"
):
log(2, f"Found code object {const.co_name}, skip it\n")
return CustomCode(None, False), dummy_guard
simulator = OpcodeExecutor(frame, **kwargs)
try:
simulator.check_code_simulatable()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,13 @@ def load(self, var):
restore_instr_names = [
instr.opname for instr in restore_instrs[:instr_idx]
]
# NOTE(SigureMo): Trailing KW_NAMES + PRECALL is no need to restore in Python 3.11+
if restore_instr_names[-2:] == ["KW_NAMES", "PRECALL"]:
restore_instrs = restore_instrs[:-2]
# NOTE(SigureMo): Trailing KW_NAMES or PRECALL is no need to restore in Python 3.11+
if restore_instr_names[-1:] == ["PRECALL"]:
restore_instrs = restore_instrs[:-1]
restore_instr_names = restore_instr_names[:-1]
if restore_instr_names[-1:] == ["KW_NAMES"]:
restore_instrs = restore_instrs[:-1]
restore_instr_names = restore_instr_names[:-1]

self.pycode_gen.extend_instrs(restore_instrs)
nop = self.pycode_gen._add_instr("NOP")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1535,10 +1535,7 @@ def gen_compute_in_break_with_name_store(self, restore_names, instr_idx):
instr_idx:
the index for branch 1 to find the boundary and copy origin opcode
"""
if (
self._graph.sir_ctx.TOS.graph_size() < ENV_MIN_GRAPH_SIZE.get()
and sys.version_info < (3, 11)
):
if self._graph.sir_ctx.TOS.graph_size() < ENV_MIN_GRAPH_SIZE.get():
store_var_info = {}
for name in restore_names:
_var = self.get_var(name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,19 @@

from paddle.jit.sot.utils import log, log_do

from ...utils import InnerError
from .instruction_utils import instrs_info
from .stack_analyse import StackAnalyser


def apply_instr_pass(instrs, code_options):
log(3, f"[Opcode Pass]: Original New Code {code_options['co_name']}:\n")
log_do(3, lambda: print(instrs_info(instrs)))
supported_passes = (remove_load_store_pass,)
supported_passes = (
remove_load_store_pass,
remove_duplicate_resume,
check_precall_followed_by_call,
)

for instr_pass in supported_passes:
instr_pass(instrs, code_options)
Expand Down Expand Up @@ -254,3 +259,14 @@ def remove_duplicate_resume(instrs, code_options):
return
for resume in resumes[1:]:
instrs.remove(resume)


def check_precall_followed_by_call(instrs, code_options):
"""
PRECALL should be followed by CALL, otherwise it will cause a segmentation fault
"""
for instr, next_instr in zip(instrs[:-1], instrs[1:]):
if instr.opname == "PRECALL" and next_instr.opname != "CALL":
raise InnerError(
f"PRECALL is not followed by CALL in {code_options['co_name']}"
)
7 changes: 1 addition & 6 deletions test/sot/test_05_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,12 @@
# BUILD_MAP (new)
# BUILD_CONST_KEY_MAP (new)

import sys
import unittest

from test_case_base import TestCaseBase

import paddle
from paddle.jit.sot.psdb import check_no_breakgraph
from paddle.jit.sot.utils.envs import strict_mode_guard


@check_no_breakgraph
Expand Down Expand Up @@ -244,10 +242,7 @@ def test_construct(self):
self.assert_results(dict_construct_from_dict)
self.assert_results(dict_construct_from_list)
self.assert_results(dict_construct_from_tuple)
# Temporarily fallback for comprehension in python3.11
use_strict_mode = sys.version_info < (3, 11)
with strict_mode_guard(use_strict_mode):
self.assert_results(dict_construct_from_comprehension)
self.assert_results(dict_construct_from_comprehension)

def test_dict_noargs(self):
self.assert_results(dict_no_arguments)
Expand Down
6 changes: 1 addition & 5 deletions test/sot/test_12_for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

from __future__ import annotations

import sys
import unittest

from test_case_base import TestCaseBase
Expand Down Expand Up @@ -241,10 +240,7 @@ def run_list_comp(x):
class TestListComp(TestCaseBase):
def test_list_comp(self):
x = [paddle.randn([1, 4]), paddle.randn([1, 4])]
# Temporarily fallback for comprehension in python3.11
use_strict_mode = sys.version_info < (3, 11)
with strict_mode_guard(use_strict_mode):
self.assert_results(run_list_comp, x)
self.assert_results(run_list_comp, x)


def for_enumerate_cache(func_list, x):
Expand Down
16 changes: 6 additions & 10 deletions test/sot/test_builtin_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from __future__ import annotations

import sys
import unittest
from typing import Iterable

Expand Down Expand Up @@ -104,15 +103,12 @@ def test_map(self):
self.assert_results(test_map_dict, {"a": 1, "b": 2, "c": 3})

def test_map_comprehension(self):
# Temporarily fallback for comprehension in python3.11
use_strict_mode = sys.version_info < (3, 11)
with strict_mode_guard(use_strict_mode):
self.assert_results(test_map_list_comprehension, [1, 2, 3, 4])
self.assert_results(test_map_tuple_comprehension, (1, 2, 3, 4))
self.assert_results(test_map_range_comprehension, range(5))
self.assert_results(
test_map_dict_comprehension, {"a": 1, "b": 2, "c": 3}
)
self.assert_results(test_map_list_comprehension, [1, 2, 3, 4])
self.assert_results(test_map_tuple_comprehension, (1, 2, 3, 4))
self.assert_results(test_map_range_comprehension, range(5))
self.assert_results(
test_map_dict_comprehension, {"a": 1, "b": 2, "c": 3}
)

def test_map_with_breakgraph(self):
with strict_mode_guard(False):
Expand Down
33 changes: 11 additions & 22 deletions test/sot/test_listcomp.py → test/sot/test_specialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,39 +17,28 @@
from test_case_base import TestCaseBase

import paddle
from paddle.jit.sot.utils.envs import min_graph_size_guard, strict_mode_guard
from paddle.jit.sot.opcode_translator.executor.dispatcher import Dispatcher
from paddle.jit.sot.utils.envs import min_graph_size_guard

# 8 will trigger the warmup in RESUME instruction and cause a segmentation fault
# RUN_N_TIMES should be larger than 8
RUN_N_TIMES = 20

builtin_fn = str.split
# Remove builtin_fn from Dispatcher to ensure that trigger a BreakGraph Error
if builtin_fn in Dispatcher.handlers:
del Dispatcher.handlers[builtin_fn]

def listcomp_fn():
print(1)
x = [i for i in range(10)] # noqa: C416
return x

def builtin_fn_with_breakgraph():
str.split("1,2,3,4,5", ",")

def genexpr_fn():
print(1)
x = (i for i in range(10))
return x


class TestListComp(TestCaseBase):
@strict_mode_guard(False)
@min_graph_size_guard(10)
def test_listcomp(self):
for _ in range(RUN_N_TIMES):
paddle.jit.to_static(listcomp_fn)()


class TestGenExpr(TestCaseBase):
@strict_mode_guard(False)
class TestSpecialization(TestCaseBase):
@min_graph_size_guard(10)
def test_genexpr(self):
def test_specialization(self):
for _ in range(RUN_N_TIMES):
paddle.jit.to_static(genexpr_fn)()
paddle.jit.to_static(builtin_fn_with_breakgraph)()


if __name__ == "__main__":
Expand Down