Skip to content

Commit 1e2af54

Browse files
[Hackathon No.18] 为 Paddle 新增 frexp API (#46401)
* 之前的pr合并了大量错误代码,重新提交一份 * 之前的pr合并了大量错误代码,重新提交一份 * 修正格式问题 * 改回原来的格式 * 按照要求修改 * 按照要求修改格式 * 修复注释的问题 * 更新格式 * 测试自动格式化 * 修正英文注释 * fix docs build error * pre-commit * for docs build * for docs build * 修复mantissa计算错误的bug * 修复误判exponent可能存在负数,导致计算量增加的情况 Co-authored-by: Ligoml <[email protected]>
1 parent 9a1855f commit 1e2af54

File tree

4 files changed

+147
-1
lines changed

4 files changed

+147
-1
lines changed

python/paddle/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,7 @@
286286
from .tensor.math import frac # noqa: F401
287287
from .tensor.math import sgn # noqa: F401
288288
from .tensor.math import take # noqa: F401
289+
from .tensor.math import frexp # noqa: F401
289290

290291
from .tensor.random import bernoulli # noqa: F401
291292
from .tensor.random import poisson # noqa: F401
@@ -386,7 +387,6 @@
386387
os.environ.setdefault('runtime_include_dir', runtime_include_dir)
387388

388389
disable_static()
389-
390390
__all__ = [ # noqa
391391
'iinfo',
392392
'dtype',
@@ -667,4 +667,5 @@
667667
'sgn',
668668
'triu_indices',
669669
'take',
670+
'frexp',
670671
]
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import unittest
15+
import numpy as np
16+
import paddle
17+
import paddle.fluid
18+
19+
20+
class TestFrexpAPI(unittest.TestCase):
21+
22+
def setUp(self):
23+
np.random.seed(1024)
24+
self.rtol = 1e-5
25+
self.atol = 1e-8
26+
self.place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
27+
else paddle.CPUPlace()
28+
self.set_input()
29+
30+
def set_input(self):
31+
self.x_np = np.random.uniform(-3, 3, [10, 12]).astype('float32')
32+
33+
# 静态图单测
34+
def test_static_api(self):
35+
# 开启静态图模式
36+
paddle.enable_static()
37+
with paddle.static.program_guard(paddle.static.Program()):
38+
input_data = paddle.fluid.data('X', self.x_np.shape,
39+
self.x_np.dtype)
40+
out = paddle.frexp(input_data)
41+
# 计算静态图结果
42+
exe = paddle.static.Executor(self.place)
43+
res = exe.run(feed={'X': self.x_np}, fetch_list=[out])
44+
45+
out_ref = np.frexp(self.x_np)
46+
# 对比静态图与 numpy 实现函数计算结果是否相同
47+
for n, p in zip(out_ref, res):
48+
np.testing.assert_allclose(n, p, rtol=self.rtol, atol=self.atol)
49+
50+
# 动态图单测
51+
def test_dygraph_api(self):
52+
# 关闭静态图模式
53+
paddle.disable_static(self.place)
54+
input_num = paddle.to_tensor(self.x_np)
55+
# 测试动态图 tensor.frexp 和 paddle.tensor.math.frexp 计算结果
56+
out1 = np.frexp(self.x_np)
57+
out2 = paddle.frexp(input_num)
58+
np.testing.assert_allclose(out1, out2, rtol=1e-05)
59+
60+
out1 = np.frexp(self.x_np)
61+
out2 = input_num.frexp()
62+
np.testing.assert_allclose(out1, out2, rtol=1e-05)
63+
paddle.enable_static()
64+
65+
66+
class TestSplitsFloat32Case1(TestFrexpAPI):
67+
"""
68+
Test num_or_sections which is an integer and data type is float32.
69+
"""
70+
71+
def set_input(self):
72+
self.x_np = np.random.uniform(-1, 1, [4, 5, 2]).astype('float32')
73+
74+
75+
class TestSplitsFloat64Case1(TestFrexpAPI):
76+
"""
77+
Test num_or_sections which is an integer and data type is float64.
78+
"""
79+
80+
def set_input(self):
81+
self.x_np = np.random.uniform(-3, 3, [10, 12]).astype('float64')
82+
83+
84+
class TestSplitsFloat64Case2(TestFrexpAPI):
85+
"""
86+
Test num_or_sections which is an integer and data type is float64.
87+
"""
88+
89+
def set_input(self):
90+
self.x_np = np.random.uniform(-1, 1, [4, 5, 2]).astype('float64')
91+
92+
93+
if __name__ == "__main__":
94+
unittest.main()

python/paddle/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@
239239
from .math import frac # noqa: F401
240240
from .math import sgn # noqa: F401
241241
from .math import take # noqa: F401
242+
from .math import frexp # noqa: F401
242243

243244
from .random import multinomial # noqa: F401
244245
from .random import standard_normal # noqa: F401
@@ -517,6 +518,7 @@
517518
'take',
518519
'bucketize',
519520
'sgn',
521+
'frexp',
520522
]
521523

522524
# this list used in math_op_patch.py for magic_method bind

python/paddle/tensor/math.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5108,3 +5108,52 @@ def take(x, index, mode='raise', name=None):
51085108
out = input_1d.index_select(index_1d).reshape(index.shape)
51095109

51105110
return out
5111+
5112+
5113+
def frexp(x, name=None):
5114+
"""
5115+
The function used to decompose a floating point number into mantissa and exponent.
5116+
5117+
Args:
5118+
x (Tensor): The input tensor, it's data type should be float32, float64.
5119+
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
5120+
Returns:
5121+
5122+
- mantissa (Tensor), A mantissa Tensor. The shape and data type of mantissa tensor and exponential tensor are
5123+
the same as those of input.
5124+
5125+
- exponent (Tensor), A exponent Tensor. The shape and data type of mantissa tensor and exponential tensor are
5126+
the same as those of input.
5127+
5128+
Examples:
5129+
.. code-block:: python
5130+
5131+
import paddle
5132+
5133+
x = paddle.to_tensor([[1, 2, 3, 4]], dtype="float32")
5134+
print(paddle.tensor.math.frexp(x))
5135+
# (Tensor(shape=[1, 4], dtype=float32, place=Place(cpu), stop_gradient=True,[[0.50000000, 0.50000000, 0.75000000, 0.50000000]]),
5136+
# Tensor(shape=[1, 4], dtype=float32, place=Place(cpu), stop_gradient=True,[[1., 2., 2., 3.]]))
5137+
"""
5138+
if x.dtype not in [paddle.float32, paddle.float64]:
5139+
raise TypeError(
5140+
"The data type of input must be one of ['float32', 'float64'], but got {}"
5141+
.format(x.dtype))
5142+
input_x = paddle.abs(x)
5143+
exponent = paddle.floor(paddle.log2(input_x))
5144+
exponent = paddle.where(paddle.isinf(exponent),
5145+
paddle.full_like(exponent, 0), exponent)
5146+
5147+
# 0填充
5148+
mantissa = paddle.divide(input_x, 2**exponent)
5149+
# 计算exponent
5150+
exponent = paddle.where((mantissa >= 1),
5151+
paddle.add(exponent, paddle.ones_like(exponent)),
5152+
exponent)
5153+
mantissa = paddle.where((mantissa >= 1),
5154+
paddle.divide(mantissa,
5155+
2**paddle.ones_like(exponent)),
5156+
mantissa)
5157+
5158+
mantissa = paddle.where((x < 0), mantissa * -1, mantissa)
5159+
return mantissa, exponent

0 commit comments

Comments
 (0)