From 0b92112bf5b58b129fdd3f196166ae859bdd5b1f Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Mon, 15 Apr 2024 13:28:49 +0800 Subject: [PATCH 01/15] add hack10 part --- python/paddle/__init__.py | 6 + python/paddle/tensor/__init__.py | 6 + python/paddle/tensor/math.py | 141 ++++++++++++++++++++++++ test/legacy_test/test_isfinite_v2_op.py | 83 ++++++++++++++ test/legacy_test/test_isreal.py | 108 ++++++++++++++++++ 5 files changed, 344 insertions(+) create mode 100644 test/legacy_test/test_isreal.py diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index ccf9d97c008c10..109e1b2bc3798e 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -427,6 +427,9 @@ isfinite, isinf, isnan, + isneginf, + isposinf, + isreal, kron, lcm, lcm_, @@ -717,6 +720,9 @@ 'to_tensor', 'gather_nd', 'isinf', + 'isneginf', + 'isposinf', + 'isreal', 'uniform', 'floor_divide', 'floor_divide_', diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 936edb9c428fb9..2c3ce2764c9e2d 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -307,6 +307,9 @@ isfinite, isinf, isnan, + isneginf, + isposinf, + isreal, kron, lcm, lcm_, @@ -580,6 +583,9 @@ 'isfinite', 'isinf', 'isnan', + 'isneginf', + 'isposinf', + 'isreal', 'broadcast_shape', 'conj', 'neg', diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 9bde343f185fdc..6824e11146c1ea 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -7728,3 +7728,144 @@ def signbit(x, name=None): x = paddle.sign(neg_zero_x) out = paddle.cast(x < 0, dtype='bool') return out + + +def isposinf(x, name=None): + r""" + Tests if each element of input is positive infinity or not. + + Args: + x (Tensor): The input Tensor. Must be one of the following types: float32, float64, int32, int64. + name (str, optional): Name for the operation (optional, default is None).For more information, please refer to :ref:`api_guide_Name`. + + Returns: + out (Tensor): The output Tensor. Each element of output indicates whether the input element is positive infinity or not. + + Examples: + .. code-block:: python + + >>> import paddle + >>> paddle.set_device('cpu') + >>> x = paddle.to_tensor([-0., float('inf'), -2.1, -float('inf'), 2.5], dtype='float32') + >>> res = paddle.isposinf(x) + >>> print(res) + Tensor(shape=[5], dtype=bool, place=Place(cpu), stop_gradient=True, + [False, True, False, False, False]) + + """ + if not isinstance(x, (paddle.Tensor, Variable)): + raise TypeError(f"x must be tensor type, but got {type(x)}") + + check_variable_and_dtype( + x, + "x", + [ + 'float16', + 'float32', + 'float64', + 'int8', + 'int16', + 'int32', + 'int64', + 'uint8', + ], + "isposinf", + ) ## dtype is the intersection of dtypes supported by isinf and signbit + is_inf = paddle.isinf(x) + signbit = ~paddle.signbit(x) + return paddle.logical_and(is_inf, signbit) + + +def isneginf(x, name=None): + r""" + Tests if each element of input is negative infinity or not. + + Args: + x (Tensor): The input Tensor. Must be one of the following types: float32, float64, int32, int64. + name (str, optional): Name for the operation (optional, default is None).For more information, please refer to :ref:`api_guide_Name`. + + Returns: + out (Tensor): The output Tensor. Each element of output indicates whether the input element is negative infinity or not. + + Examples: + .. code-block:: python + + >>> import paddle + >>> paddle.set_device('cpu') + >>> x = paddle.to_tensor([-0., float('inf'), -2.1, -float('inf'), 2.5], dtype='float32') + >>> res = paddle.isneginf(x) + >>> print(res) + Tensor(shape=[5], dtype=bool, place=Place(cpu), stop_gradient=True, + [False, False, False, True, False]) + + """ + if not isinstance(x, (paddle.Tensor, Variable)): + raise TypeError(f"x must be tensor type, but got {type(x)}") + + check_variable_and_dtype( + x, + "x", + [ + 'float16', + 'float32', + 'float64', + 'int8', + 'int16', + 'int32', + 'int64', + 'uint8', + ], + "isneginf", + ) + is_inf = paddle.isinf(x) + signbit = paddle.signbit(x) + return paddle.logical_and(is_inf, signbit) + + +def isreal(x, name=None): + r""" + Tests if each element of input is a real number or not. + + Args: + x (Tensor): The input Tensor. + name (str, optional): Name for the operation (optional, default is None).For more information, please refer to :ref:`api_guide_Name`. + + Returns: + out (Tensor): The output Tensor. Each element of output indicates whether the input element is a real number or not. + + Examples: + .. code-block:: python + + >>> import paddle + >>> paddle.set_device('cpu') + >>> x = paddle.to_tensor([-0., -2.1, 2.5], dtype='float32') + >>> res = paddle.isreal(x) + >>> print(res) + Tensor(shape=[3], dtype=bool, place=Place(cpu), stop_gradient=True, + [True, True, True]) + + >>> x = paddle.to_tensor([(-0.+1j), (-2.1+0.2j), (2.5-3.1j)]) + >>> res = paddle.isreal(x) + >>> print(res) + Tensor(shape=[3], dtype=bool, place=Place(cpu), stop_gradient=True, + [False, False, False]) + + >>> x = paddle.to_tensor([(-0.+1j), (-2.1+0j), (2.5-0j)]) + >>> res = paddle.isreal(x) + >>> print(res) + Tensor(shape=[3], dtype=bool, place=Place(cpu), stop_gradient=True, + [False, True, True]) + """ + if not isinstance(x, (paddle.Tensor, Variable, paddle.pir.Value)): + raise TypeError(f"x must be tensor type, but got {type(x)}") + dtype = x.dtype + is_real_dtype = not ( + dtype == core.VarDesc.VarType.COMPLEX64 + or dtype == core.VarDesc.VarType.COMPLEX128 + or dtype == core.DataType.COMPLEX64 + or dtype == core.DataType.COMPLEX128 + ) + if is_real_dtype: + return paddle.ones_like(x, dtype='bool') + + return paddle.equal(paddle.imag(x), 0) diff --git a/test/legacy_test/test_isfinite_v2_op.py b/test/legacy_test/test_isfinite_v2_op.py index 36a6366e58b2c8..f11aaf24d4d4b7 100644 --- a/test/legacy_test/test_isfinite_v2_op.py +++ b/test/legacy_test/test_isfinite_v2_op.py @@ -135,6 +135,65 @@ def np_data_generator( }, ] +TEST_META_DATA2 = [ + { + 'low': 0.1, + 'high': 1, + 'np_shape': [8, 17, 5, 6, 7], + 'type': 'float16', + 'sv_list': [np.inf, -np.inf], + }, + { + 'low': 0.1, + 'high': 1, + 'np_shape': [11, 17], + 'type': 'float32', + 'sv_list': [-np.inf, np.inf], + }, + { + 'low': 0.1, + 'high': 1, + 'np_shape': [2, 3, 4, 5], + 'type': 'float64', + 'sv_list': [np.inf, -np.inf], + }, + { + 'low': 0, + 'high': 999, + 'np_shape': [132], + 'type': 'uint8', + 'sv_list': [-np.inf, np.inf], + }, + { + 'low': 0.1, + 'high': 1, + 'np_shape': [2, 3, 4, 5], + 'type': 'int8', + 'sv_list': [-np.inf, np.inf], + }, + { + 'low': 0, + 'high': 100, + 'np_shape': [11, 17, 10], + 'type': 'int16', + 'sv_list': [np.inf, -np.inf], + }, + { + 'low': 0, + 'high': 100, + 'np_shape': [11, 17, 10], + 'type': 'int32', + 'sv_list': [-np.inf, np.inf], + }, + { + 'low': 0, + 'high': 999, + 'np_shape': [132], + 'type': 'int64', + 'sv_list': [np.inf, -np.inf], + }, +] + def test(test_case, op_str, use_gpu=False, data_set=TEST_META_DATA): for meta_data in data_set: @@ -171,6 +230,12 @@ def test_finite(self): def test_inf_additional(self): test(self, 'isinf', data_set=TEST_META_DATA_ADDITIONAL) + def test_posinf(self): + test(self, 'isposinf', data_set=TEST_META_DATA2) + + def test_neginf(self): + test(self, 'isneginf', data_set=TEST_META_DATA2) + class TestCUDANormal(unittest.TestCase): def test_inf(self): @@ -185,6 +250,12 @@ def test_finite(self): def test_inf_additional(self): test(self, 'isinf', True, data_set=TEST_META_DATA_ADDITIONAL) + def test_posinf(self): + test(self, 'isposinf', True, data_set=TEST_META_DATA2) + + def test_neginf(self): + test(self, 'isneginf', True, data_set=TEST_META_DATA2) + class TestError(unittest.TestCase): @test_with_pir_api @@ -210,6 +281,18 @@ def test_isfinite_bad_x(): self.assertRaises(TypeError, test_isfinite_bad_x) + def test_isposinf_bad_x(): + x = [1, 2, 3] + result = paddle.isposinf(x) + + self.assertRaises(TypeError, test_isposinf_bad_x) + + def test_isneginf_bad_x(): + x = [1, 2, 3] + result = paddle.isneginf(x) + + self.assertRaises(TypeError, test_isneginf_bad_x) + if __name__ == '__main__': paddle.enable_static() diff --git a/test/legacy_test/test_isreal.py b/test/legacy_test/test_isreal.py new file mode 100644 index 00000000000000..19ff8e21d7d08e --- /dev/null +++ b/test/legacy_test/test_isreal.py @@ -0,0 +1,108 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +from paddle import base, static +from paddle.pir_utils import test_with_pir_api + +TEST_REAL_DATA = [ + np.array(1.0), + np.random.randint(-10, 10, (2, 3)), + np.random.randn(64, 32), +] +REAL_TYPE = [ + 'float16', + 'float32', + 'float64', + 'bool', + 'int16', + 'int32', + 'int64', + 'uint16', +] +TEST_COMPLEX_DATA = [ + np.array(1.0 + 2j), + np.array(1.0 + 0j), + np.array([[0.2 + 3j, 3 + 0j, -0.7 - 6j], [-0.4 + 0j, 3.5 - 10j, 2.5 + 0j]]), +] +COMPLEX_TYPE = ['complex64', 'complex128'] + + +def run_dygraph(data, type, use_gpu=False): + place = paddle.CPUPlace() + if use_gpu and base.core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + paddle.disable_static(place) + data = data.astype(type) + x = paddle.to_tensor(data) + return paddle.isreal(x) + + +def run_static(data, type, use_gpu=False): + paddle.enable_static() + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + place = paddle.CPUPlace() + if use_gpu and base.core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + exe = base.Executor(place) + with static.program_guard(main_program, startup_program): + data = data.astype(type) + x = paddle.static.data(name='x', shape=data.shape, dtype=type) + res = paddle.isreal(x) + static_result = exe.run(feed={'x': data}, fetch_list=[res]) + return static_result + + +def test(data_cases, type_cases, use_gpu=False): + for data in data_cases: + for type in type_cases: + dygraph_result = run_dygraph(data, type, use_gpu).numpy() + np_result = np.isreal(data.astype(type)) + np.testing.assert_equal(dygraph_result, np_result) + + @test_with_pir_api + def test_static_or_pir_mode(): + (static_result,) = run_static(data, type, use_gpu) + np.testing.assert_equal(static_result, np_result) + + test_static_or_pir_mode() + + +class TestIsRealError(unittest.TestCase): + def test_for_exception(self): + with self.assertRaises(TypeError): + paddle.isreal(np.array([1, 2])) + + +class TestIsReal(unittest.TestCase): + def test_for_real_tensor_without_gpu(self): + test(TEST_REAL_DATA, REAL_TYPE) + + def test_for_real_tensor_with_gpu(self): + test(TEST_REAL_DATA, REAL_TYPE, True) + + def test_for_complex_tensor_without_gpu(self): + test(TEST_COMPLEX_DATA, COMPLEX_TYPE) + + def test_for_complex_tensor_with_gpu(self): + test(TEST_COMPLEX_DATA, COMPLEX_TYPE, True) + + +if __name__ == '__main__': + unittest.main() From 2d1b227a30b6fcacbde44d12ac8d10cb8fd3e5fd Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Mon, 15 Apr 2024 15:40:40 +0800 Subject: [PATCH 02/15] update fp16 --- test/legacy_test/test_isfinite_v2_op.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/test/legacy_test/test_isfinite_v2_op.py b/test/legacy_test/test_isfinite_v2_op.py index f11aaf24d4d4b7..cc3e8a21bc479c 100644 --- a/test/legacy_test/test_isfinite_v2_op.py +++ b/test/legacy_test/test_isfinite_v2_op.py @@ -136,13 +136,6 @@ def np_data_generator( ] TEST_META_DATA2 = [ - { - 'low': 0.1, - 'high': 1, - 'np_shape': [8, 17, 5, 6, 7], - 'type': 'float16', - 'sv_list': [np.inf, -np.inf], - }, { 'low': 0.1, 'high': 1, @@ -194,6 +187,16 @@ def np_data_generator( }, ] +TEST_META_DATA3 = [ + { + 'low': 0.1, + 'high': 1, + 'np_shape': [8, 17, 5, 6, 7], + 'type': 'float16', + 'sv_list': [np.inf, -np.inf], + }, +] + def test(test_case, op_str, use_gpu=False, data_set=TEST_META_DATA): for meta_data in data_set: @@ -252,9 +255,11 @@ def test_inf_additional(self): def test_posinf(self): test(self, 'isposinf', True, data_set=TEST_META_DATA2) + test(self, 'isposinf', True, data_set=TEST_META_DATA3) def test_neginf(self): test(self, 'isneginf', True, data_set=TEST_META_DATA2) + test(self, 'isposinf', True, data_set=TEST_META_DATA3) class TestError(unittest.TestCase): From 9df7bf7ef9220af3dc57f0a6037bab3fa651d930 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Mon, 15 Apr 2024 16:43:33 +0800 Subject: [PATCH 03/15] update --- python/paddle/tensor/math.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 6824e11146c1ea..55dc8cedd8e80e 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -7753,7 +7753,7 @@ def isposinf(x, name=None): [False, True, False, False, False]) """ - if not isinstance(x, (paddle.Tensor, Variable)): + if not isinstance(x, (paddle.Tensor, Variable, paddle.pir.Value)): raise TypeError(f"x must be tensor type, but got {type(x)}") check_variable_and_dtype( @@ -7799,7 +7799,7 @@ def isneginf(x, name=None): [False, False, False, True, False]) """ - if not isinstance(x, (paddle.Tensor, Variable)): + if not isinstance(x, (paddle.Tensor, Variable, paddle.pir.Value)): raise TypeError(f"x must be tensor type, but got {type(x)}") check_variable_and_dtype( From 38a2580982af68f073c6244a623fee62a71e8844 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Mon, 15 Apr 2024 17:39:43 +0800 Subject: [PATCH 04/15] fix fp16 test --- test/legacy_test/test_isfinite_v2_op.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/test/legacy_test/test_isfinite_v2_op.py b/test/legacy_test/test_isfinite_v2_op.py index cc3e8a21bc479c..d98cd119608a9a 100644 --- a/test/legacy_test/test_isfinite_v2_op.py +++ b/test/legacy_test/test_isfinite_v2_op.py @@ -255,12 +255,23 @@ def test_inf_additional(self): def test_posinf(self): test(self, 'isposinf', True, data_set=TEST_META_DATA2) - test(self, 'isposinf', True, data_set=TEST_META_DATA3) def test_neginf(self): test(self, 'isneginf', True, data_set=TEST_META_DATA2) + + +@unittest.skipIf( + not base.core.is_compiled_with_cuda() + or not base.core.is_float16_supported(base.core.CUDAPlace(0)), + "core is not compiled with CUDA and not support the float16", +) +class TestCUDAFP16(unittest.TestCase): + def test_posinf(self): test(self, 'isposinf', True, data_set=TEST_META_DATA3) + def test_neginf(self): + test(self, 'isneginf', True, data_set=TEST_META_DATA3) + class TestError(unittest.TestCase): @test_with_pir_api From bdfa3f319f48e8682f1e4784cbe59a07880b03bf Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Tue, 16 Apr 2024 12:25:27 +0800 Subject: [PATCH 05/15] update en docs --- python/paddle/tensor/math.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 55dc8cedd8e80e..49661c9b8ca405 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -7735,7 +7735,7 @@ def isposinf(x, name=None): Tests if each element of input is positive infinity or not. Args: - x (Tensor): The input Tensor. Must be one of the following types: float32, float64, int32, int64. + x (Tensor): The input Tensor. Must be one of the following types: float16, float32, float64, int8, int16, int32, int64, uint8. name (str, optional): Name for the operation (optional, default is None).For more information, please refer to :ref:`api_guide_Name`. Returns: @@ -7781,7 +7781,7 @@ def isneginf(x, name=None): Tests if each element of input is negative infinity or not. Args: - x (Tensor): The input Tensor. Must be one of the following types: float32, float64, int32, int64. + x (Tensor): The input Tensor. Must be one of the following types: float16, float32, float64, int8, int16, int32, int64, uint8. name (str, optional): Name for the operation (optional, default is None).For more information, please refer to :ref:`api_guide_Name`. Returns: From 3581c461ae0aad891029c38f65e4fe3ff3ed4170 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Fri, 19 Apr 2024 22:29:11 +0800 Subject: [PATCH 06/15] add bf16 --- python/paddle/tensor/math.py | 2 ++ test/legacy_test/test_isfinite_v2_op.py | 28 ++++++++++++++++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 49661c9b8ca405..23060c6cd4a613 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -7760,6 +7760,7 @@ def isposinf(x, name=None): x, "x", [ + 'bfloat16', 'float16', 'float32', 'float64', @@ -7806,6 +7807,7 @@ def isneginf(x, name=None): x, "x", [ + 'bfloat16', 'float16', 'float32', 'float64', diff --git a/test/legacy_test/test_isfinite_v2_op.py b/test/legacy_test/test_isfinite_v2_op.py index d98cd119608a9a..5c0e434cdb0955 100644 --- a/test/legacy_test/test_isfinite_v2_op.py +++ b/test/legacy_test/test_isfinite_v2_op.py @@ -60,7 +60,10 @@ def run_eager(x_np, op_str, use_gpu=True): def np_data_generator( low, high, np_shape, type, sv_list, op_str, *args, **kwargs ): - x_np = np.random.uniform(low, high, np_shape).astype(getattr(np, type)) + if type == 'bfloat16': + x_np = np.random.uniform(low, high, np_shape).astype(np.uint16) + else: + x_np = np.random.uniform(low, high, np_shape).astype(getattr(np, type)) # x_np.shape[0] >= len(sv_list) if type in ['float16', 'float32', 'float64']: for i, v in enumerate(sv_list): @@ -197,6 +200,16 @@ def np_data_generator( }, ] +TEST_META_DATA4 = [ + { + 'low': 0.1, + 'high': 1, + 'np_shape': [8, 17, 5, 6, 7], + 'type': 'bfloat16', + 'sv_list': [-np.inf, np.inf], + }, +] + def test(test_case, op_str, use_gpu=False, data_set=TEST_META_DATA): for meta_data in data_set: @@ -273,6 +286,19 @@ def test_neginf(self): test(self, 'isneginf', True, data_set=TEST_META_DATA3) +@unittest.skipIf( + not base.core.is_compiled_with_cuda() + or not base.core.is_bfloat16_supported(base.core.CUDAPlace(0)), + "core is not compiled with CUDA and not support the bfloat16", +) +class TestCUDABFP16(unittest.TestCase): + def test_posinf(self): + test(self, 'isposinf', True, data_set=TEST_META_DATA4) + + def test_neginf(self): + test(self, 'isneginf', True, data_set=TEST_META_DATA4) + + class TestError(unittest.TestCase): @test_with_pir_api def test_bad_input(self): From 39aaa335e89e5ddf1f8c4ea0505138eb7f007fb5 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Fri, 19 Apr 2024 22:33:23 +0800 Subject: [PATCH 07/15] update docs --- python/paddle/tensor/math.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 23060c6cd4a613..c400766b0fa0e6 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -7735,7 +7735,7 @@ def isposinf(x, name=None): Tests if each element of input is positive infinity or not. Args: - x (Tensor): The input Tensor. Must be one of the following types: float16, float32, float64, int8, int16, int32, int64, uint8. + x (Tensor): The input Tensor. Must be one of the following types: bfloat16, float16, float32, float64, int8, int16, int32, int64, uint8. name (str, optional): Name for the operation (optional, default is None).For more information, please refer to :ref:`api_guide_Name`. Returns: @@ -7782,7 +7782,7 @@ def isneginf(x, name=None): Tests if each element of input is negative infinity or not. Args: - x (Tensor): The input Tensor. Must be one of the following types: float16, float32, float64, int8, int16, int32, int64, uint8. + x (Tensor): The input Tensor. Must be one of the following types: bfloat16, float16, float32, float64, int8, int16, int32, int64, uint8. name (str, optional): Name for the operation (optional, default is None).For more information, please refer to :ref:`api_guide_Name`. Returns: From 049b0f61d39e431345086b3aaf435d7c84f3dfa7 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Mon, 22 Apr 2024 10:38:07 +0800 Subject: [PATCH 08/15] fix test --- test/legacy_test/test_isfinite_v2_op.py | 31 ++++++++++++------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/test/legacy_test/test_isfinite_v2_op.py b/test/legacy_test/test_isfinite_v2_op.py index 5c0e434cdb0955..97702e4cfb41ae 100644 --- a/test/legacy_test/test_isfinite_v2_op.py +++ b/test/legacy_test/test_isfinite_v2_op.py @@ -60,10 +60,7 @@ def run_eager(x_np, op_str, use_gpu=True): def np_data_generator( low, high, np_shape, type, sv_list, op_str, *args, **kwargs ): - if type == 'bfloat16': - x_np = np.random.uniform(low, high, np_shape).astype(np.uint16) - else: - x_np = np.random.uniform(low, high, np_shape).astype(getattr(np, type)) + x_np = np.random.uniform(low, high, np_shape).astype(getattr(np, type)) # x_np.shape[0] >= len(sv_list) if type in ['float16', 'float32', 'float64']: for i, v in enumerate(sv_list): @@ -200,16 +197,6 @@ def np_data_generator( }, ] -TEST_META_DATA4 = [ - { - 'low': 0.1, - 'high': 1, - 'np_shape': [8, 17, 5, 6, 7], - 'type': 'bfloat16', - 'sv_list': [-np.inf, np.inf], - }, -] - def test(test_case, op_str, use_gpu=False, data_set=TEST_META_DATA): for meta_data in data_set: @@ -233,6 +220,18 @@ def test_static_or_pir_mode(): test_static_or_pir_mode() +def test_bf16(test_case, op_str): + x_np = np.array([float('inf'), -float('inf'), 2.0, 3.0]) + result_np = getattr(np, op_str)(x_np) + + place = paddle.CUDAPlace(0) + paddle.disable_static(place) + x = paddle.to_tensor(x_np, dtype='bfloat16') + dygraph_result = getattr(paddle, op_str)(x).numpy() + + test_case.assertTrue((dygraph_result == result_np).all()) + + class TestCPUNormal(unittest.TestCase): def test_inf(self): test(self, 'isinf') @@ -293,10 +292,10 @@ def test_neginf(self): ) class TestCUDABFP16(unittest.TestCase): def test_posinf(self): - test(self, 'isposinf', True, data_set=TEST_META_DATA4) + test_bf16(self, 'isposinf') def test_neginf(self): - test(self, 'isneginf', True, data_set=TEST_META_DATA4) + test_bf16(self, 'isneginf') class TestError(unittest.TestCase): From 7eb9a535972d7c3ba7e2eddf3ba13fe9e23aa514 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Mon, 22 Apr 2024 13:34:16 +0800 Subject: [PATCH 09/15] rerun ci --- python/paddle/tensor/math.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index c400766b0fa0e6..b6f84fdf741306 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -7736,7 +7736,7 @@ def isposinf(x, name=None): Args: x (Tensor): The input Tensor. Must be one of the following types: bfloat16, float16, float32, float64, int8, int16, int32, int64, uint8. - name (str, optional): Name for the operation (optional, default is None).For more information, please refer to :ref:`api_guide_Name`. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: out (Tensor): The output Tensor. Each element of output indicates whether the input element is positive infinity or not. @@ -7783,7 +7783,7 @@ def isneginf(x, name=None): Args: x (Tensor): The input Tensor. Must be one of the following types: bfloat16, float16, float32, float64, int8, int16, int32, int64, uint8. - name (str, optional): Name for the operation (optional, default is None).For more information, please refer to :ref:`api_guide_Name`. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: out (Tensor): The output Tensor. Each element of output indicates whether the input element is negative infinity or not. @@ -7830,7 +7830,7 @@ def isreal(x, name=None): Args: x (Tensor): The input Tensor. - name (str, optional): Name for the operation (optional, default is None).For more information, please refer to :ref:`api_guide_Name`. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: out (Tensor): The output Tensor. Each element of output indicates whether the input element is a real number or not. From ed2d7e758d4af91b3d1b495895a567aa88ccec4d Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Mon, 15 Apr 2024 15:40:40 +0800 Subject: [PATCH 10/15] update fp16 --- test/legacy_test/test_isfinite_v2_op.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/legacy_test/test_isfinite_v2_op.py b/test/legacy_test/test_isfinite_v2_op.py index 97702e4cfb41ae..2f240542e0a0bb 100644 --- a/test/legacy_test/test_isfinite_v2_op.py +++ b/test/legacy_test/test_isfinite_v2_op.py @@ -267,9 +267,11 @@ def test_inf_additional(self): def test_posinf(self): test(self, 'isposinf', True, data_set=TEST_META_DATA2) + test(self, 'isposinf', True, data_set=TEST_META_DATA3) def test_neginf(self): test(self, 'isneginf', True, data_set=TEST_META_DATA2) + test(self, 'isposinf', True, data_set=TEST_META_DATA3) @unittest.skipIf( From f5cef59211f495d93fd46f9aef3a77cc9a101073 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Mon, 15 Apr 2024 17:39:43 +0800 Subject: [PATCH 11/15] fix fp16 test --- test/legacy_test/test_isfinite_v2_op.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/test/legacy_test/test_isfinite_v2_op.py b/test/legacy_test/test_isfinite_v2_op.py index 2f240542e0a0bb..5a06cdf4065efd 100644 --- a/test/legacy_test/test_isfinite_v2_op.py +++ b/test/legacy_test/test_isfinite_v2_op.py @@ -267,12 +267,23 @@ def test_inf_additional(self): def test_posinf(self): test(self, 'isposinf', True, data_set=TEST_META_DATA2) - test(self, 'isposinf', True, data_set=TEST_META_DATA3) def test_neginf(self): test(self, 'isneginf', True, data_set=TEST_META_DATA2) + + +@unittest.skipIf( + not base.core.is_compiled_with_cuda() + or not base.core.is_float16_supported(base.core.CUDAPlace(0)), + "core is not compiled with CUDA and not support the float16", +) +class TestCUDAFP16(unittest.TestCase): + def test_posinf(self): test(self, 'isposinf', True, data_set=TEST_META_DATA3) + def test_neginf(self): + test(self, 'isneginf', True, data_set=TEST_META_DATA3) + @unittest.skipIf( not base.core.is_compiled_with_cuda() From 08e1f2f7185f548a43ed10b2e8b43f743e417265 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Fri, 19 Apr 2024 22:29:11 +0800 Subject: [PATCH 12/15] add bf16 --- test/legacy_test/test_isfinite_v2_op.py | 28 ++++++++++++------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/test/legacy_test/test_isfinite_v2_op.py b/test/legacy_test/test_isfinite_v2_op.py index 5a06cdf4065efd..fcb909579d7e0e 100644 --- a/test/legacy_test/test_isfinite_v2_op.py +++ b/test/legacy_test/test_isfinite_v2_op.py @@ -60,7 +60,10 @@ def run_eager(x_np, op_str, use_gpu=True): def np_data_generator( low, high, np_shape, type, sv_list, op_str, *args, **kwargs ): - x_np = np.random.uniform(low, high, np_shape).astype(getattr(np, type)) + if type == 'bfloat16': + x_np = np.random.uniform(low, high, np_shape).astype(np.uint16) + else: + x_np = np.random.uniform(low, high, np_shape).astype(getattr(np, type)) # x_np.shape[0] >= len(sv_list) if type in ['float16', 'float32', 'float64']: for i, v in enumerate(sv_list): @@ -197,6 +200,16 @@ def np_data_generator( }, ] +TEST_META_DATA4 = [ + { + 'low': 0.1, + 'high': 1, + 'np_shape': [8, 17, 5, 6, 7], + 'type': 'bfloat16', + 'sv_list': [-np.inf, np.inf], + }, +] + def test(test_case, op_str, use_gpu=False, data_set=TEST_META_DATA): for meta_data in data_set: @@ -285,19 +298,6 @@ def test_neginf(self): test(self, 'isneginf', True, data_set=TEST_META_DATA3) -@unittest.skipIf( - not base.core.is_compiled_with_cuda() - or not base.core.is_float16_supported(base.core.CUDAPlace(0)), - "core is not compiled with CUDA and not support the float16", -) -class TestCUDAFP16(unittest.TestCase): - def test_posinf(self): - test(self, 'isposinf', True, data_set=TEST_META_DATA3) - - def test_neginf(self): - test(self, 'isneginf', True, data_set=TEST_META_DATA3) - - @unittest.skipIf( not base.core.is_compiled_with_cuda() or not base.core.is_bfloat16_supported(base.core.CUDAPlace(0)), From 54a227eddb24d542ad740e86d8b0eaffede420f8 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Fri, 19 Apr 2024 22:33:23 +0800 Subject: [PATCH 13/15] update docs --- python/paddle/tensor/math.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index b6f84fdf741306..ffb5efab4e6e4f 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -7736,7 +7736,11 @@ def isposinf(x, name=None): Args: x (Tensor): The input Tensor. Must be one of the following types: bfloat16, float16, float32, float64, int8, int16, int32, int64, uint8. +<<<<<<< HEAD name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. +======= + name (str, optional): Name for the operation (optional, default is None).For more information, please refer to :ref:`api_guide_Name`. +>>>>>>> e7eb8c0a37 (update docs) Returns: out (Tensor): The output Tensor. Each element of output indicates whether the input element is positive infinity or not. From d7f689e23f956df6d215181cae47285191055fe1 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Mon, 22 Apr 2024 10:38:07 +0800 Subject: [PATCH 14/15] fix test --- test/legacy_test/test_isfinite_v2_op.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/test/legacy_test/test_isfinite_v2_op.py b/test/legacy_test/test_isfinite_v2_op.py index fcb909579d7e0e..97702e4cfb41ae 100644 --- a/test/legacy_test/test_isfinite_v2_op.py +++ b/test/legacy_test/test_isfinite_v2_op.py @@ -60,10 +60,7 @@ def run_eager(x_np, op_str, use_gpu=True): def np_data_generator( low, high, np_shape, type, sv_list, op_str, *args, **kwargs ): - if type == 'bfloat16': - x_np = np.random.uniform(low, high, np_shape).astype(np.uint16) - else: - x_np = np.random.uniform(low, high, np_shape).astype(getattr(np, type)) + x_np = np.random.uniform(low, high, np_shape).astype(getattr(np, type)) # x_np.shape[0] >= len(sv_list) if type in ['float16', 'float32', 'float64']: for i, v in enumerate(sv_list): @@ -200,16 +197,6 @@ def np_data_generator( }, ] -TEST_META_DATA4 = [ - { - 'low': 0.1, - 'high': 1, - 'np_shape': [8, 17, 5, 6, 7], - 'type': 'bfloat16', - 'sv_list': [-np.inf, np.inf], - }, -] - def test(test_case, op_str, use_gpu=False, data_set=TEST_META_DATA): for meta_data in data_set: From 98824f0fc50667403c785e59a216f8704101e6e7 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Mon, 22 Apr 2024 13:34:16 +0800 Subject: [PATCH 15/15] rerun ci --- python/paddle/tensor/math.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index ffb5efab4e6e4f..b6f84fdf741306 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -7736,11 +7736,7 @@ def isposinf(x, name=None): Args: x (Tensor): The input Tensor. Must be one of the following types: bfloat16, float16, float32, float64, int8, int16, int32, int64, uint8. -<<<<<<< HEAD name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. -======= - name (str, optional): Name for the operation (optional, default is None).For more information, please refer to :ref:`api_guide_Name`. ->>>>>>> e7eb8c0a37 (update docs) Returns: out (Tensor): The output Tensor. Each element of output indicates whether the input element is positive infinity or not.