Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions python/paddle/jit/sot/opcode_translator/executor/variables/callable.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,16 @@ def call_function(self, /, *args, **kwargs):
)
assert isinstance(fn_var, VariableBase)
return fn_var(*args)
# If __bool__ and __len__ method is absent, inline bool calls return True.
# See https://github.com/python/cpython/blob/3.11/Objects/typeobject.c#L7463
elif magic_method.name == "__bool__" and not hasattr(
arg_type, "__len__"
):
return VariableFactory.from_value(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image

咦?按我理解,这行一定会被执行才对呀,为什么覆盖率显示没有跑到这行呢?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

奇怪,调试是有的呢

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

明天上班我测试下,如果没问题的话可以找人豁免~

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我试了下,确实没跑到,但单独跑一个 case 是可以跑到的,说明是 cache 导致的,因此在每个 case 都清理了 cache 就可以了

True,
self.graph,
DummyTracker([self] + list(args) + list(kwargs.values())),
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

就目前来看 bool_flag 不需要单独作为一个变量,直接内联在使用处即可

不可以使用 ConstTracker,这个 True 是与 args 相关的,要用 DummyTracker


# Break graph if neither of the above conditions is met
arg_types = ", ".join([type(arg).__name__ for arg in args])
Expand Down
21 changes: 1 addition & 20 deletions test/sot/skip_files_py312
100644 → 100755
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

以及 test/sot/skip_files_py312python/paddle/jit/sot/opcode_translator/executor/variables/callable.py 怎么都 chmod 了?能不修改么?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是我本地检查pre-commit代码,没有权限才改的,我调试下改回来

Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
./test_01_basic.py
./test_02_store_inplace.py
./test_03_tuple.py
./test_04_list.py
./test_05_dict.py
./test_06_call_function.py
./test_07_unpack.py
./test_08_rot.py
./test_09_f_string.py
./test_10_build_unpack.py
./test_11_jumps.py
./test_12_for_loop.py
./test_13_make_function.py
./test_14_operators.py
./test_15_slice.py
./test_16_paddle_api.py
./test_17_paddle_layer.py
./test_18_tensor_method.py
./test_19_closure.py
Expand All @@ -26,26 +21,14 @@
./test_builtin_map.py
./test_builtin_range.py
./test_builtin_zip.py
./test_call_ast.py
./test_call_object.py
./test_case_base.py
./test_constant_graph.py
./test_delete_fast.py
./test_dtype.py
./test_dup_top.py
./test_enumerate.py
./test_execution_base.py
./test_guard_outputs.py
./test_guard_user_defined_fn.py
./test_inplace_api.py
./test_instruction_translator_cache.py
./test_min_graph_size.py
./test_model_switch_training.py
./test_multiple_args.py
./test_numpy.py
./test_numpy_var_if.py
./test_output_restoration.py
./test_segment_linear.py
./test_side_effects.py
./test_simulate_initialize.py
./test_sir_rollback.py
Expand All @@ -57,6 +40,4 @@
./test_specialization.py
./test_str_format.py
./test_tensor_dtype_in_guard.py
./test_tensor_slice.py
./test_trace_list_arg.py
./test_unsupport_function.py
./test_builtin_bool.py
148 changes: 148 additions & 0 deletions test/sot/test_builtin_bool.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个单测在 Python 3.12 能跑过么?3.12 在 #61305 才初步支持,如果跑不过,在 test/sot/skip_files_py312 skip 一下

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

单测也加一些和 __len__ 组合的情况,允许 breakgraph(就是不加 check_no_breakgraph),但结果需要是对的

Copy link
Contributor Author

@diadestiny diadestiny Feb 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

请问跑CI允许 breakgraph,是不是需要除了不加check_no_breakgraph, 但是要加上@strict_mode_guard(False);我本地调试如果设置STRICT_MODE=True,也不加@strict_mode_guard(False),单测还是会报错:FallbackError

Copy link
Member

@SigureMo SigureMo Feb 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

默认的 STRICT_MODE 是不允许 Fallback,按理说是不会 fallback 的,这里为什么会发生 fallback 呢?

@strict_mode_guard(False);

如非特殊 case,不允许加 strict_mode_guard(False)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image
例如这个case, 如果按照原来的代码(不加上特判)来运行的话,
会在这里触发BreakGraphError,再导致这里触发FallbackError

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

所以我本地单侧过不去的case是:如果一个只实现了__len__方法的类obj,在遇到if obj

class TestObjectWithLen:
    def __init__(self,list):
        self.list = list
    def __len__(self):
        return len(self.list)

也是不满足这个magic_method.name == "__bool__" and not hasattr(arg_type, "__len__")特判的,还是会导致和上述一样的BreakGraphError触发,接着触发FallbackError

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

喔喔,这里有一个 BreakGraph -> Fallback 的转换,但没关系,这只影响 if obj 的单测 case,这样的 case 允许加 strict_mode_guard(False),但 bool(obj)operator.truth(obj) 应该不会 fallback,只会 breakgraph

另外,上面说的情况是存在 __len__ 的情况,不存在的 __len__ 的情况按理说不会 fallback 也不会 breakgraph

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

所以我本地单侧过不去的case是:如果一个只实现了__len__方法的类obj,在遇到if obj

没问题,我在发消息的时候还没有这条回复

Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import operator
import unittest

from test_case_base import TestCaseBase

import paddle
from paddle.jit.sot.psdb import check_no_breakgraph
from paddle.jit.sot.utils import strict_mode_guard


class TestObject:
pass


class TestObjectWithBool:
def __bool__(self):
return False


class TestObjectWithLen:
def __init__(self, list):
self.list = list

def __len__(self):
return len(self.list)


class TestObjectWithBoolAndLen:
def __init__(self, list):
self.list = list

def __bool__(self):
return False

def __len__(self):
return len(self.list)


@check_no_breakgraph
def object_bool(obj):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

函数需要用 check_no_breakgraph 装饰确保没有打断,可参考其他单测的写法

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

再增加 operator.truthbool 显式调用的 case

if obj:
return True
else:
return False


@strict_mode_guard(False)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

strict_mode_guard 不是这样用的,可以参考其它单测,虽然这样可能生效

def object_bool_allow_breakgraph(obj):
if obj:
return True
else:
return False


@check_no_breakgraph
def test_bool(obj):
return bool(obj)


@check_no_breakgraph
def test_operator_truth(obj):
return operator.truth(obj)


def test_bool_allow_breakgraph(obj):
return bool(obj)


def test_operator_truth_allow_breakgraph(obj):
return operator.truth(obj)


class TestBuiltinBool(TestCaseBase):
def test_object(self):
object = TestObject()
self.assert_results(object_bool, object)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不要用这种容易迷惑的名称,左边的 object_bool 明明是一个函数,可以改名 call_bool_in_cond,右边的 object 与 builtin 的 object 名字冲突,建议改为 obj

self.assert_results(object_bool, bool(object))
self.assert_results(object_bool, operator.truth(object))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这俩是在测啥?

self.assert_results(test_bool, object)
self.assert_results(test_operator_truth, object)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同样,这里改成同风格的名字 call_bool_by_boolcall_bool_by_operator_truth


def test_object_with_bool(self):
object = TestObjectWithBool()
self.assert_results(object_bool, object)
self.assert_results(object_bool, bool(object))
self.assert_results(object_bool, operator.truth(object))
self.assert_results(test_bool, object)
self.assert_results(test_operator_truth, object)

def test_object_with_len(self):
object = TestObjectWithLen([1, 2, 3])
self.assert_results(object_bool_allow_breakgraph, object)
self.assert_results(object_bool_allow_breakgraph, bool(object))
self.assert_results(
object_bool_allow_breakgraph, operator.truth(object)
)
self.assert_results(test_bool_allow_breakgraph, object)
self.assert_results(test_operator_truth_allow_breakgraph, object)

object = TestObjectWithLen([])
self.assert_results(object_bool_allow_breakgraph, object)
self.assert_results(object_bool_allow_breakgraph, bool(object))
self.assert_results(
object_bool_allow_breakgraph, operator.truth(object)
)
self.assert_results(test_bool_allow_breakgraph, object)
self.assert_results(test_operator_truth_allow_breakgraph, object)

def test_object_with_bool_and_len(self):
object = TestObjectWithBoolAndLen([1, 2, 3])
self.assert_results(object_bool, object)
self.assert_results(object_bool, bool(object))
self.assert_results(object_bool, operator.truth(object))
self.assert_results(test_bool, object)
self.assert_results(test_operator_truth, object)

object = TestObjectWithBoolAndLen([])
self.assert_results(object_bool, object)
self.assert_results(object_bool, bool(object))
self.assert_results(object_bool, operator.truth(object))
self.assert_results(test_bool, object)
self.assert_results(test_operator_truth, object)

def test_layer(self):
layer = paddle.nn.Linear(10, 1)
self.assert_results(object_bool, layer)
self.assert_results(object_bool, bool(layer))
self.assert_results(object_bool, operator.truth(layer))
self.assert_results(test_bool, layer)
self.assert_results(test_operator_truth, layer)


if __name__ == "__main__":
unittest.main()