-
Notifications
You must be signed in to change notification settings - Fork 5.9k
[SOT] support inline call bool magic function #61790
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
7166134
56de455
9bb7f7a
b05385a
560c7c7
d359c71
7e9251a
7f6ab2f
599712d
68107eb
97e3069
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -646,6 +646,16 @@ def call_function(self, /, *args, **kwargs): | |
| ) | ||
| assert isinstance(fn_var, VariableBase) | ||
| return fn_var(*args) | ||
| # If __bool__ and __len__ method are absent, inline bool calls return True. | ||
| # See https://github.com/python/cpython/blob/3.11/Objects/typeobject.c#L7463 | ||
| elif magic_method.name == "__bool__" and not hasattr( | ||
| arg_type, "__len__" | ||
| ): | ||
| return VariableFactory.from_value( | ||
| True, | ||
| self.graph, | ||
| DummyTracker([self] + list(args) + list(kwargs.values())), | ||
| ) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 就目前来看 不可以使用 ConstTracker,这个 True 是与 args 相关的,要用 DummyTracker |
||
|
|
||
| # Break graph if neither of the above conditions is met | ||
| arg_types = ", ".join([type(arg).__name__ for arg in args]) | ||
|
|
||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 以及
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 是我本地检查pre-commit代码,没有权限才改的,我调试下改回来 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -40,3 +40,4 @@ | |
| ./test_specialization.py | ||
| ./test_str_format.py | ||
| ./test_tensor_dtype_in_guard.py | ||
| ./test_builtin_bool.py | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个单测在 Python 3.12 能跑过么?3.12 在 #61305 才初步支持,如果跑不过,在 test/sot/skip_files_py312 skip 一下
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 单测也加一些和
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 请问跑CI允许 breakgraph,是不是需要除了不加check_no_breakgraph, 但是要加上@strict_mode_guard(False);我本地调试如果设置STRICT_MODE=True,也不加@strict_mode_guard(False),单测还是会报错:FallbackError
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 默认的 STRICT_MODE 是不允许 Fallback,按理说是不会 fallback 的,这里为什么会发生 fallback 呢?
如非特殊 case,不允许加
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 所以我本地单侧过不去的case是:如果一个只实现了__len__方法的类obj,在遇到 也是不满足这个
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 喔喔,这里有一个 BreakGraph -> Fallback 的转换,但没关系,这只影响 另外,上面说的情况是存在
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
没问题,我在发消息的时候还没有这条回复 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,117 @@ | ||
| # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| import operator | ||
| import unittest | ||
|
|
||
| from test_case_base import TestCaseBase | ||
|
|
||
| import paddle | ||
| from paddle.jit.sot.psdb import check_no_breakgraph | ||
| from paddle.jit.sot.utils import strict_mode_guard | ||
|
|
||
|
|
||
| class TestObject: | ||
| pass | ||
|
|
||
|
|
||
| class TestObjectWithBool: | ||
| def __bool__(self): | ||
| return False | ||
|
|
||
|
|
||
| class TestObjectWithLen: | ||
| def __init__(self, list): | ||
| self.list = list | ||
|
|
||
| def __len__(self): | ||
| return len(self.list) | ||
|
|
||
|
|
||
| class TestObjectWithBoolAndLen: | ||
| def __init__(self, list): | ||
| self.list = list | ||
|
|
||
| def __bool__(self): | ||
| return False | ||
|
|
||
| def __len__(self): | ||
| return len(self.list) | ||
|
|
||
|
|
||
| def call_bool_in_cond(obj): | ||
| if obj: | ||
| return True | ||
| else: | ||
| return False | ||
|
|
||
|
|
||
| def call_bool_by_bool(obj): | ||
| return bool(obj) | ||
|
|
||
|
|
||
| def call_bool_by_operator_truth(obj): | ||
| return operator.truth(obj) | ||
|
|
||
|
|
||
| class TestBuiltinBool(TestCaseBase): | ||
| def test_object_disallow_breakgraph(self): # disallow breakgraph | ||
| call_bool_in_cond_no_breakgraph = check_no_breakgraph(call_bool_in_cond) | ||
| call_bool_by_bool_no_breakgraph = check_no_breakgraph(call_bool_by_bool) | ||
| call_bool_by_operator_truth_no_breakgraph = check_no_breakgraph( | ||
| call_bool_by_operator_truth | ||
| ) | ||
|
|
||
| obj = TestObject() | ||
| self.assert_results(call_bool_in_cond_no_breakgraph, obj) | ||
| self.assert_results(call_bool_by_bool_no_breakgraph, obj) | ||
| self.assert_results(call_bool_by_operator_truth_no_breakgraph, obj) | ||
|
|
||
| obj = TestObjectWithBool() | ||
| self.assert_results(call_bool_in_cond_no_breakgraph, obj) | ||
| self.assert_results(call_bool_by_bool_no_breakgraph, obj) | ||
| self.assert_results(call_bool_by_operator_truth_no_breakgraph, obj) | ||
|
|
||
| obj = TestObjectWithBoolAndLen([1, 2, 3]) | ||
| self.assert_results(call_bool_in_cond_no_breakgraph, obj) | ||
| self.assert_results(call_bool_by_bool_no_breakgraph, obj) | ||
| self.assert_results(call_bool_by_operator_truth_no_breakgraph, obj) | ||
|
|
||
| obj = TestObjectWithBoolAndLen([]) | ||
| self.assert_results(call_bool_in_cond_no_breakgraph, obj) | ||
| self.assert_results(call_bool_by_bool_no_breakgraph, obj) | ||
| self.assert_results(call_bool_by_operator_truth_no_breakgraph, obj) | ||
|
|
||
| layer = paddle.nn.Linear(10, 1) | ||
| self.assert_results(call_bool_in_cond_no_breakgraph, layer) | ||
| self.assert_results(call_bool_by_bool_no_breakgraph, layer) | ||
| self.assert_results(call_bool_by_operator_truth_no_breakgraph, layer) | ||
|
|
||
| def test_object_allow_breakgraph(self): # allow breakgraph | ||
| obj = TestObjectWithLen([1, 2, 3]) | ||
| with strict_mode_guard(False): | ||
| self.assert_results(call_bool_in_cond, obj) | ||
|
|
||
| self.assert_results(call_bool_by_bool, obj) | ||
| self.assert_results(call_bool_by_operator_truth, obj) | ||
|
|
||
| obj = TestObjectWithLen([]) | ||
| with strict_mode_guard(False): | ||
| self.assert_results(call_bool_in_cond, obj) | ||
|
|
||
| self.assert_results(call_bool_by_bool, obj) | ||
| self.assert_results(call_bool_by_operator_truth, obj) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() |

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
咦?按我理解,这行一定会被执行才对呀,为什么覆盖率显示没有跑到这行呢?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
奇怪,调试是有的呢
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
明天上班我测试下,如果没问题的话可以找人豁免~
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我试了下,确实没跑到,但单独跑一个 case 是可以跑到的,说明是 cache 导致的,因此在每个 case 都清理了 cache 就可以了