-
Notifications
You must be signed in to change notification settings - Fork 5.9k
[PIR] support normal and fix TestNoBackwardAPIStatic.test_normal UT
#62864
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,7 @@ | |
| import numpy as np | ||
|
|
||
| import paddle | ||
| from paddle.pir_utils import test_with_pir_api | ||
|
|
||
| np.random.seed(10) | ||
| paddle.seed(10) | ||
|
|
@@ -62,10 +63,11 @@ def static_api(self): | |
| ret_all_shape = copy.deepcopy(shape) | ||
| ret_all_shape.insert(0, self.repeat_num) | ||
| ret_all = np.zeros(ret_all_shape, self.dtype) | ||
| main_program = paddle.static.Program() | ||
| if isinstance(self.mean, np.ndarray) and isinstance( | ||
| self.std, np.ndarray | ||
| ): | ||
| with paddle.static.program_guard(paddle.static.Program()): | ||
| with paddle.static.program_guard(main_program): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里为啥改成同一个 Program 了啊
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里只可能跑到一种
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 喔,不过人家原来也没错啊,为啥要改
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 我记得这玩意之前有点问题来着,不会切换还是啥的
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 喔 |
||
| mean = paddle.static.data( | ||
| 'Mean', self.mean.shape, self.mean.dtype | ||
| ) | ||
|
|
@@ -84,7 +86,7 @@ def static_api(self): | |
| ret_all[i] = ret[0] | ||
| return ret_all | ||
| elif isinstance(self.mean, np.ndarray): | ||
| with paddle.static.program_guard(paddle.static.Program()): | ||
| with paddle.static.program_guard(main_program): | ||
| mean = paddle.static.data( | ||
| 'Mean', self.mean.shape, self.mean.dtype | ||
| ) | ||
|
|
@@ -96,7 +98,7 @@ def static_api(self): | |
| ret_all[i] = ret[0] | ||
| return ret_all | ||
| elif isinstance(self.std, np.ndarray): | ||
| with paddle.static.program_guard(paddle.static.Program()): | ||
| with paddle.static.program_guard(main_program): | ||
| std = paddle.static.data('Std', self.std.shape, self.std.dtype) | ||
| out = paddle.normal(self.mean, std, self.shape) | ||
|
|
||
|
|
@@ -106,7 +108,7 @@ def static_api(self): | |
| ret_all[i] = ret[0] | ||
| return ret_all | ||
| else: | ||
| with paddle.static.program_guard(paddle.static.Program()): | ||
| with paddle.static.program_guard(main_program): | ||
| out = paddle.normal(self.mean, self.std, self.shape) | ||
|
|
||
| exe = paddle.static.Executor(self.place) | ||
|
|
@@ -138,6 +140,7 @@ def dygraph_api(self): | |
| paddle.enable_static() | ||
| return ret_all | ||
|
|
||
| @test_with_pir_api | ||
| def test_api(self): | ||
| ret_static = self.static_api() | ||
| ret_dygraph = self.dygraph_api() | ||
|
|
@@ -185,6 +188,7 @@ def set_attrs(self): | |
|
|
||
|
|
||
| class TestNormalAlias(unittest.TestCase): | ||
| @test_with_pir_api | ||
| def test_alias(self): | ||
| paddle.disable_static() | ||
| shape = [1, 2, 3] | ||
|
|
@@ -195,8 +199,10 @@ def test_alias(self): | |
|
|
||
|
|
||
| class TestNormalErrors(unittest.TestCase): | ||
| @test_with_pir_api | ||
| def test_errors(self): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里现在是不能加嘛?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 补上了 |
||
| with paddle.static.program_guard(paddle.static.Program()): | ||
| main_program = paddle.static.Program() | ||
| with paddle.static.program_guard(main_program): | ||
| mean = [1, 2, 3] | ||
| self.assertRaises(TypeError, paddle.normal, mean) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
744 行改回
if not in_dynamic_mode():可以么?这样 PIR 下少了很多检查,当然 Variable 相关位置需要加一下 Value