-
Notifications
You must be signed in to change notification settings - Fork 5.9k
[Hackathon No.18] 为 Paddle 新增 frexp API #46401
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 13 commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
370cdc1
之前的pr合并了大量错误代码,重新提交一份
Zheng-Bicheng 5b0a5c9
之前的pr合并了大量错误代码,重新提交一份
Zheng-Bicheng c0af897
修正格式问题
Zheng-Bicheng 813fc26
改回原来的格式
Zheng-Bicheng 0780713
按照要求修改
Zheng-Bicheng 537419b
按照要求修改格式
Zheng-Bicheng 290f822
修复注释的问题
Zheng-Bicheng dfad468
更新格式
Zheng-Bicheng 7560c32
Merge remote-tracking branch 'upstream/develop' into develop
Zheng-Bicheng 81a2676
测试自动格式化
Zheng-Bicheng cdb6e48
修正英文注释
Zheng-Bicheng 5ec4a9c
fix docs build error
Ligoml e79e932
pre-commit
Zheng-Bicheng e4e0f94
for docs build
Ligoml d1c0935
for docs build
Ligoml 42b9b29
修复mantissa计算错误的bug
Zheng-Bicheng 2de0092
Merge branch 'develop' of https://github.com/Zheng-Bicheng/Paddle int…
Zheng-Bicheng 6f9922e
修复误判exponent可能存在负数,导致计算量增加的情况
Zheng-Bicheng File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,94 @@ | ||
| # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| import unittest | ||
| import numpy as np | ||
| import paddle | ||
| import paddle.fluid | ||
|
|
||
|
|
||
| class TestFrexpAPI(unittest.TestCase): | ||
|
|
||
| def setUp(self): | ||
| np.random.seed(1024) | ||
| self.rtol = 1e-5 | ||
| self.atol = 1e-8 | ||
| self.place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \ | ||
| else paddle.CPUPlace() | ||
| self.set_input() | ||
|
|
||
| def set_input(self): | ||
| self.x_np = np.random.uniform(-3, 3, [10, 12]).astype('float32') | ||
|
|
||
| # 静态图单测 | ||
| def test_static_api(self): | ||
Zheng-Bicheng marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| # 开启静态图模式 | ||
| paddle.enable_static() | ||
| with paddle.static.program_guard(paddle.static.Program()): | ||
| input_data = paddle.fluid.data('X', self.x_np.shape, | ||
| self.x_np.dtype) | ||
| out = paddle.frexp(input_data) | ||
| # 计算静态图结果 | ||
| exe = paddle.static.Executor(self.place) | ||
| res = exe.run(feed={'X': self.x_np}, fetch_list=[out]) | ||
|
|
||
| out_ref = np.frexp(self.x_np) | ||
| # 对比静态图与 numpy 实现函数计算结果是否相同 | ||
| for n, p in zip(out_ref, res): | ||
| np.testing.assert_allclose(n, p, rtol=self.rtol, atol=self.atol) | ||
|
|
||
| # 动态图单测 | ||
| def test_dygraph_api(self): | ||
| # 关闭静态图模式 | ||
| paddle.disable_static(self.place) | ||
| input_num = paddle.to_tensor(self.x_np) | ||
| # 测试动态图 tensor.frexp 和 paddle.tensor.math.frexp 计算结果 | ||
| out1 = np.frexp(self.x_np) | ||
| out2 = paddle.frexp(input_num) | ||
| np.testing.assert_allclose(out1, out2, rtol=1e-05) | ||
|
|
||
| out1 = np.frexp(self.x_np) | ||
| out2 = input_num.frexp() | ||
| np.testing.assert_allclose(out1, out2, rtol=1e-05) | ||
| paddle.enable_static() | ||
|
|
||
|
|
||
| class TestSplitsFloat32Case1(TestFrexpAPI): | ||
| """ | ||
| Test num_or_sections which is an integer and data type is float32. | ||
| """ | ||
|
|
||
| def set_input(self): | ||
| self.x_np = np.random.uniform(-1, 1, [4, 5, 2]).astype('float32') | ||
|
|
||
|
|
||
| class TestSplitsFloat64Case1(TestFrexpAPI): | ||
| """ | ||
| Test num_or_sections which is an integer and data type is float64. | ||
| """ | ||
|
|
||
| def set_input(self): | ||
| self.x_np = np.random.uniform(-3, 3, [10, 12]).astype('float64') | ||
|
|
||
|
|
||
| class TestSplitsFloat64Case2(TestFrexpAPI): | ||
| """ | ||
| Test num_or_sections which is an integer and data type is float64. | ||
| """ | ||
|
|
||
| def set_input(self): | ||
| self.x_np = np.random.uniform(-1, 1, [4, 5, 2]).astype('float64') | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.