Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,16 @@ def call_function(self, /, *args, **kwargs):
)
assert isinstance(fn_var, VariableBase)
return fn_var(*args)
# If __bool__ and __len__ method are absent, inline bool calls return True.
# See https://github.com/python/cpython/blob/3.11/Objects/typeobject.c#L7463
elif magic_method.name == "__bool__" and not hasattr(
arg_type, "__len__"
):
return VariableFactory.from_value(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image

咦?按我理解,这行一定会被执行才对呀,为什么覆盖率显示没有跑到这行呢?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

奇怪,调试是有的呢

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

明天上班我测试下,如果没问题的话可以找人豁免~

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我试了下,确实没跑到,但单独跑一个 case 是可以跑到的,说明是 cache 导致的,因此在每个 case 都清理了 cache 就可以了

True,
self.graph,
DummyTracker([self] + list(args) + list(kwargs.values())),
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

就目前来看 bool_flag 不需要单独作为一个变量,直接内联在使用处即可

不可以使用 ConstTracker,这个 True 是与 args 相关的,要用 DummyTracker


# Break graph if neither of the above conditions is met
arg_types = ", ".join([type(arg).__name__ for arg in args])
Expand Down
1 change: 1 addition & 0 deletions test/sot/skip_files_py312
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

以及 test/sot/skip_files_py312python/paddle/jit/sot/opcode_translator/executor/variables/callable.py 怎么都 chmod 了?能不修改么?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是我本地检查pre-commit代码,没有权限才改的,我调试下改回来

Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,4 @@
./test_specialization.py
./test_str_format.py
./test_tensor_dtype_in_guard.py
./test_builtin_bool.py
117 changes: 117 additions & 0 deletions test/sot/test_builtin_bool.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个单测在 Python 3.12 能跑过么?3.12 在 #61305 才初步支持,如果跑不过,在 test/sot/skip_files_py312 skip 一下

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

单测也加一些和 __len__ 组合的情况,允许 breakgraph(就是不加 check_no_breakgraph),但结果需要是对的

Copy link
Contributor Author

@diadestiny diadestiny Feb 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

请问跑CI允许 breakgraph,是不是需要除了不加check_no_breakgraph, 但是要加上@strict_mode_guard(False);我本地调试如果设置STRICT_MODE=True,也不加@strict_mode_guard(False),单测还是会报错:FallbackError

Copy link
Member

@SigureMo SigureMo Feb 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

默认的 STRICT_MODE 是不允许 Fallback,按理说是不会 fallback 的,这里为什么会发生 fallback 呢?

@strict_mode_guard(False);

如非特殊 case,不允许加 strict_mode_guard(False)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image
例如这个case, 如果按照原来的代码(不加上特判)来运行的话,
会在这里触发BreakGraphError,再导致这里触发FallbackError

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

所以我本地单侧过不去的case是:如果一个只实现了__len__方法的类obj,在遇到if obj

class TestObjectWithLen:
    def __init__(self,list):
        self.list = list
    def __len__(self):
        return len(self.list)

也是不满足这个magic_method.name == "__bool__" and not hasattr(arg_type, "__len__")特判的,还是会导致和上述一样的BreakGraphError触发,接着触发FallbackError

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

喔喔,这里有一个 BreakGraph -> Fallback 的转换,但没关系,这只影响 if obj 的单测 case,这样的 case 允许加 strict_mode_guard(False),但 bool(obj)operator.truth(obj) 应该不会 fallback,只会 breakgraph

另外,上面说的情况是存在 __len__ 的情况,不存在的 __len__ 的情况按理说不会 fallback 也不会 breakgraph

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

所以我本地单侧过不去的case是:如果一个只实现了__len__方法的类obj,在遇到if obj

没问题,我在发消息的时候还没有这条回复

Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import operator
import unittest

from test_case_base import TestCaseBase

import paddle
from paddle.jit.sot.psdb import check_no_breakgraph
from paddle.jit.sot.utils import strict_mode_guard


class TestObject:
pass


class TestObjectWithBool:
def __bool__(self):
return False


class TestObjectWithLen:
def __init__(self, list):
self.list = list

def __len__(self):
return len(self.list)


class TestObjectWithBoolAndLen:
def __init__(self, list):
self.list = list

def __bool__(self):
return False

def __len__(self):
return len(self.list)


def call_bool_in_cond(obj):
if obj:
return True
else:
return False


def call_bool_by_bool(obj):
return bool(obj)


def call_bool_by_operator_truth(obj):
return operator.truth(obj)


class TestBuiltinBool(TestCaseBase):
def test_object_disallow_breakgraph(self): # disallow breakgraph
call_bool_in_cond_no_breakgraph = check_no_breakgraph(call_bool_in_cond)
call_bool_by_bool_no_breakgraph = check_no_breakgraph(call_bool_by_bool)
call_bool_by_operator_truth_no_breakgraph = check_no_breakgraph(
call_bool_by_operator_truth
)

obj = TestObject()
self.assert_results(call_bool_in_cond_no_breakgraph, obj)
self.assert_results(call_bool_by_bool_no_breakgraph, obj)
self.assert_results(call_bool_by_operator_truth_no_breakgraph, obj)

obj = TestObjectWithBool()
self.assert_results(call_bool_in_cond_no_breakgraph, obj)
self.assert_results(call_bool_by_bool_no_breakgraph, obj)
self.assert_results(call_bool_by_operator_truth_no_breakgraph, obj)

obj = TestObjectWithBoolAndLen([1, 2, 3])
self.assert_results(call_bool_in_cond_no_breakgraph, obj)
self.assert_results(call_bool_by_bool_no_breakgraph, obj)
self.assert_results(call_bool_by_operator_truth_no_breakgraph, obj)

obj = TestObjectWithBoolAndLen([])
self.assert_results(call_bool_in_cond_no_breakgraph, obj)
self.assert_results(call_bool_by_bool_no_breakgraph, obj)
self.assert_results(call_bool_by_operator_truth_no_breakgraph, obj)

layer = paddle.nn.Linear(10, 1)
self.assert_results(call_bool_in_cond_no_breakgraph, layer)
self.assert_results(call_bool_by_bool_no_breakgraph, layer)
self.assert_results(call_bool_by_operator_truth_no_breakgraph, layer)

def test_object_allow_breakgraph(self): # allow breakgraph
obj = TestObjectWithLen([1, 2, 3])
with strict_mode_guard(False):
self.assert_results(call_bool_in_cond, obj)

self.assert_results(call_bool_by_bool, obj)
self.assert_results(call_bool_by_operator_truth, obj)

obj = TestObjectWithLen([])
with strict_mode_guard(False):
self.assert_results(call_bool_in_cond, obj)

self.assert_results(call_bool_by_bool, obj)
self.assert_results(call_bool_by_operator_truth, obj)


if __name__ == "__main__":
unittest.main()