Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,9 @@
isfinite,
isinf,
isnan,
isneginf,
isposinf,
isreal,
kron,
lcm,
lcm_,
Expand Down Expand Up @@ -717,6 +720,9 @@
'to_tensor',
'gather_nd',
'isinf',
'isneginf',
'isposinf',
'isreal',
'uniform',
'floor_divide',
'floor_divide_',
Expand Down
6 changes: 6 additions & 0 deletions python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,9 @@
isfinite,
isinf,
isnan,
isneginf,
isposinf,
isreal,
kron,
lcm,
lcm_,
Expand Down Expand Up @@ -580,6 +583,9 @@
'isfinite',
'isinf',
'isnan',
'isneginf',
'isposinf',
'isreal',
'broadcast_shape',
'conj',
'neg',
Expand Down
143 changes: 143 additions & 0 deletions python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -7728,3 +7728,146 @@ def signbit(x, name=None):
x = paddle.sign(neg_zero_x)
out = paddle.cast(x < 0, dtype='bool')
return out


def isposinf(x, name=None):
r"""
Tests if each element of input is positive infinity or not.

Args:
x (Tensor): The input Tensor. Must be one of the following types: bfloat16, float16, float32, float64, int8, int16, int32, int64, uint8.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

Returns:
out (Tensor): The output Tensor. Each element of output indicates whether the input element is positive infinity or not.

Examples:
.. code-block:: python

>>> import paddle
>>> paddle.set_device('cpu')
>>> x = paddle.to_tensor([-0., float('inf'), -2.1, -float('inf'), 2.5], dtype='float32')
>>> res = paddle.isposinf(x)
>>> print(res)
Tensor(shape=[5], dtype=bool, place=Place(cpu), stop_gradient=True,
[False, True, False, False, False])

"""
if not isinstance(x, (paddle.Tensor, Variable, paddle.pir.Value)):
raise TypeError(f"x must be tensor type, but got {type(x)}")

check_variable_and_dtype(
x,
"x",
[
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we support data type of bf16?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes.
bf16 type added.

'bfloat16',
'float16',
'float32',
'float64',
'int8',
'int16',
'int32',
'int64',
'uint8',
],
"isposinf",
) ## dtype is the intersection of dtypes supported by isinf and signbit
is_inf = paddle.isinf(x)
signbit = ~paddle.signbit(x)
return paddle.logical_and(is_inf, signbit)


def isneginf(x, name=None):
r"""
Tests if each element of input is negative infinity or not.

Args:
x (Tensor): The input Tensor. Must be one of the following types: bfloat16, float16, float32, float64, int8, int16, int32, int64, uint8.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

Returns:
out (Tensor): The output Tensor. Each element of output indicates whether the input element is negative infinity or not.

Examples:
.. code-block:: python

>>> import paddle
>>> paddle.set_device('cpu')
>>> x = paddle.to_tensor([-0., float('inf'), -2.1, -float('inf'), 2.5], dtype='float32')
>>> res = paddle.isneginf(x)
>>> print(res)
Tensor(shape=[5], dtype=bool, place=Place(cpu), stop_gradient=True,
[False, False, False, True, False])

"""
if not isinstance(x, (paddle.Tensor, Variable, paddle.pir.Value)):
raise TypeError(f"x must be tensor type, but got {type(x)}")

check_variable_and_dtype(
x,
"x",
[
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we support data type of bf16?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes.
bf16 type added.

'bfloat16',
'float16',
'float32',
'float64',
'int8',
'int16',
'int32',
'int64',
'uint8',
],
"isneginf",
)
is_inf = paddle.isinf(x)
signbit = paddle.signbit(x)
return paddle.logical_and(is_inf, signbit)


def isreal(x, name=None):
r"""
Tests if each element of input is a real number or not.

Args:
x (Tensor): The input Tensor.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

Returns:
out (Tensor): The output Tensor. Each element of output indicates whether the input element is a real number or not.

Examples:
.. code-block:: python

>>> import paddle
>>> paddle.set_device('cpu')
>>> x = paddle.to_tensor([-0., -2.1, 2.5], dtype='float32')
>>> res = paddle.isreal(x)
>>> print(res)
Tensor(shape=[3], dtype=bool, place=Place(cpu), stop_gradient=True,
[True, True, True])

>>> x = paddle.to_tensor([(-0.+1j), (-2.1+0.2j), (2.5-3.1j)])
>>> res = paddle.isreal(x)
>>> print(res)
Tensor(shape=[3], dtype=bool, place=Place(cpu), stop_gradient=True,
[False, False, False])

>>> x = paddle.to_tensor([(-0.+1j), (-2.1+0j), (2.5-0j)])
>>> res = paddle.isreal(x)
>>> print(res)
Tensor(shape=[3], dtype=bool, place=Place(cpu), stop_gradient=True,
[False, True, True])
"""
if not isinstance(x, (paddle.Tensor, Variable, paddle.pir.Value)):
raise TypeError(f"x must be tensor type, but got {type(x)}")
dtype = x.dtype
is_real_dtype = not (
dtype == core.VarDesc.VarType.COMPLEX64
or dtype == core.VarDesc.VarType.COMPLEX128
or dtype == core.DataType.COMPLEX64
or dtype == core.DataType.COMPLEX128
)
if is_real_dtype:
return paddle.ones_like(x, dtype='bool')

return paddle.equal(paddle.imag(x), 0)
124 changes: 124 additions & 0 deletions test/legacy_test/test_isfinite_v2_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,68 @@ def np_data_generator(
},
]

TEST_META_DATA2 = [
{
'low': 0.1,
'high': 1,
'np_shape': [11, 17],
'type': 'float32',
'sv_list': [-np.inf, np.inf],
},
{
'low': 0.1,
'high': 1,
'np_shape': [2, 3, 4, 5],
'type': 'float64',
'sv_list': [np.inf, -np.inf],
},
{
'low': 0,
'high': 999,
'np_shape': [132],
'type': 'uint8',
'sv_list': [-np.inf, np.inf],
},
{
'low': 0.1,
'high': 1,
'np_shape': [2, 3, 4, 5],
'type': 'int8',
'sv_list': [-np.inf, np.inf],
},
{
'low': 0,
'high': 100,
'np_shape': [11, 17, 10],
'type': 'int16',
'sv_list': [np.inf, -np.inf],
},
{
'low': 0,
'high': 100,
'np_shape': [11, 17, 10],
'type': 'int32',
'sv_list': [-np.inf, np.inf],
},
{
'low': 0,
'high': 999,
'np_shape': [132],
'type': 'int64',
'sv_list': [np.inf, -np.inf],
},
]

TEST_META_DATA3 = [
{
'low': 0.1,
'high': 1,
'np_shape': [8, 17, 5, 6, 7],
'type': 'float16',
'sv_list': [np.inf, -np.inf],
},
]


def test(test_case, op_str, use_gpu=False, data_set=TEST_META_DATA):
for meta_data in data_set:
Expand All @@ -158,6 +220,18 @@ def test_static_or_pir_mode():
test_static_or_pir_mode()


def test_bf16(test_case, op_str):
x_np = np.array([float('inf'), -float('inf'), 2.0, 3.0])
result_np = getattr(np, op_str)(x_np)

place = paddle.CUDAPlace(0)
paddle.disable_static(place)
x = paddle.to_tensor(x_np, dtype='bfloat16')
dygraph_result = getattr(paddle, op_str)(x).numpy()

test_case.assertTrue((dygraph_result == result_np).all())


class TestCPUNormal(unittest.TestCase):
def test_inf(self):
test(self, 'isinf')
Expand All @@ -171,6 +245,12 @@ def test_finite(self):
def test_inf_additional(self):
test(self, 'isinf', data_set=TEST_META_DATA_ADDITIONAL)

def test_posinf(self):
test(self, 'isposinf', data_set=TEST_META_DATA2)

def test_neginf(self):
test(self, 'isneginf', data_set=TEST_META_DATA2)


class TestCUDANormal(unittest.TestCase):
def test_inf(self):
Expand All @@ -185,6 +265,38 @@ def test_finite(self):
def test_inf_additional(self):
test(self, 'isinf', True, data_set=TEST_META_DATA_ADDITIONAL)

def test_posinf(self):
test(self, 'isposinf', True, data_set=TEST_META_DATA2)

def test_neginf(self):
test(self, 'isneginf', True, data_set=TEST_META_DATA2)


@unittest.skipIf(
not base.core.is_compiled_with_cuda()
or not base.core.is_float16_supported(base.core.CUDAPlace(0)),
"core is not compiled with CUDA and not support the float16",
)
class TestCUDAFP16(unittest.TestCase):
def test_posinf(self):
test(self, 'isposinf', True, data_set=TEST_META_DATA3)

def test_neginf(self):
test(self, 'isneginf', True, data_set=TEST_META_DATA3)


@unittest.skipIf(
not base.core.is_compiled_with_cuda()
or not base.core.is_bfloat16_supported(base.core.CUDAPlace(0)),
"core is not compiled with CUDA and not support the bfloat16",
)
class TestCUDABFP16(unittest.TestCase):
def test_posinf(self):
test_bf16(self, 'isposinf')

def test_neginf(self):
test_bf16(self, 'isneginf')


class TestError(unittest.TestCase):
@test_with_pir_api
Expand All @@ -210,6 +322,18 @@ def test_isfinite_bad_x():

self.assertRaises(TypeError, test_isfinite_bad_x)

def test_isposinf_bad_x():
x = [1, 2, 3]
result = paddle.isposinf(x)

self.assertRaises(TypeError, test_isposinf_bad_x)

def test_isneginf_bad_x():
x = [1, 2, 3]
result = paddle.isneginf(x)

self.assertRaises(TypeError, test_isneginf_bad_x)


if __name__ == '__main__':
paddle.enable_static()
Expand Down
Loading