Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions paddle/phi/ops/yaml/python_api_info.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,18 @@
args_alias :
use_default_mapping : True

- op : greater_than
name : [paddle.greater_than, paddle.Tensor.greater_than]
args_alias :
use_default_mapping : True

- op : expand_as
name : [paddle.expand_as,paddle.Tensor.expand_as]
args_alias :
use_default_mapping : True
pre_process :
func : ExpandAsPreProcess(x,y,target_shape)

- op : logical_and
name : [paddle.logical_and, paddle.Tensor.logical_and]
args_alias:
Expand Down
5 changes: 4 additions & 1 deletion python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,12 +866,15 @@ def __dir__(self):
ger = outer
div = divide
div_ = divide_
eq = equal
gt = greater_than
swapdims = transpose
swapaxes = transpose


__all__ = [
'block_diag',
'gt',
'eq',
'iinfo',
'finfo',
'dtype',
Expand Down
39 changes: 38 additions & 1 deletion python/paddle/_paddle_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,44 @@ def expand_as(x: Tensor, y: Tensor, name: str | None = None) -> Tensor
# shenwei

# zhouxin
add_doc_and_signature(
"greater_than",
"""
Returns the truth value of :math:`x > y` elementwise, which is equivalent function to the overloaded operator `>`.

Note:
The output has no gradient.

Args:
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, bfloat16, float16, float32, float64, uint8, int8, int16, int32, int64, complex64, complex128.
Alias: ``input``.
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, bfloat16, float16, float32, float64, uint8, int8, int16, int32, int64, complex64, complex128.
Alias: ``other``.
name (str|None, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
out (Tensor, optional): The output tensor. If provided, the result will be stored in this tensor.
Returns:
Tensor: The output shape is same as input :attr:`x`. The output data type is bool.

Examples:
.. code-block:: python

>>> import paddle

>>> x = paddle.to_tensor([1, 2, 3])
>>> y = paddle.to_tensor([1, 3, 2])
>>> result1 = paddle.greater_than(x, y)
>>> print(result1)
Tensor(shape=[3], dtype=bool, place=Place(cpu), stop_gradient=True,
[False, False, True ])
""",
"""
def greater_than(
x: Tensor, y: Tensor, name: str | None = None, *, out: Tensor | None = None
) -> Tensor
""",
)

add_doc_and_signature(
"sin",
"""
Expand Down Expand Up @@ -1085,7 +1123,6 @@ def floor(
) -> Tensor
""",
)

# hehongyu

# lousiyu
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@
# API alias
div = divide
div_ = divide_
take_along_dim = take_along_axis
swapdims = transpose
swapaxes = transpose

Expand Down Expand Up @@ -829,6 +830,7 @@
'moveaxis',
'repeat_interleave',
'take_along_axis',
'take_along_dim',
'scatter_reduce',
'put_along_axis',
'scatter_add',
Expand Down
90 changes: 9 additions & 81 deletions python/paddle/tensor/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import paddle
from paddle import _C_ops
from paddle._C_ops import ( # noqa: F401
greater_than,
logical_and,
logical_not,
logical_or,
Expand Down Expand Up @@ -373,7 +374,10 @@ def allclose(
return out


def equal(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
@param_two_alias(["x", "input"], ["y", "other"])
def equal(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个按装饰器先实现吧,可以用这个param_two_alias

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

x: Tensor, y: Tensor, name: str | None = None, *, out: Tensor | None = None
) -> Tensor:
"""

This layer returns the truth value of :math:`x == y` elementwise.
Expand All @@ -383,9 +387,12 @@ def equal(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:

Args:
x (Tensor): Tensor, data type is bool, float16, float32, float64, uint8, int8, int16, int32, int64, complex64, complex128.
alias: ``input``
y (Tensor): Tensor, data type is bool, float16, float32, float64, uint8, int8, int16, int32, int64, complex64, complex128.
alias: ``other``
name (str|None, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
out (Tensor, optional): Output tensor. If provided, the result will be stored in this tensor.

Returns:
Tensor: output Tensor, it's shape is the same as the input's Tensor,
Expand Down Expand Up @@ -417,7 +424,7 @@ def equal(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
y = paddle.to_tensor(y)

if in_dynamic_or_pir_mode():
return _C_ops.equal(x, y)
return _C_ops.equal(x, y, out=out)
else:
check_variable_and_dtype(
x,
Expand Down Expand Up @@ -577,85 +584,6 @@ def greater_equal_(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
return _C_ops.greater_equal_(x, y)


def greater_than(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
"""
Returns the truth value of :math:`x > y` elementwise, which is equivalent function to the overloaded operator `>`.

Note:
The output has no gradient.

Args:
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, bfloat16, float16, float32, float64, uint8, int8, int16, int32, int64, complex64, complex128.
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, bfloat16, float16, float32, float64, uint8, int8, int16, int32, int64, complex64, complex128.
name (str|None, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor: The output shape is same as input :attr:`x`. The output data type is bool.

Examples:
.. code-block:: python

>>> import paddle

>>> x = paddle.to_tensor([1, 2, 3])
>>> y = paddle.to_tensor([1, 3, 2])
>>> result1 = paddle.greater_than(x, y)
>>> print(result1)
Tensor(shape=[3], dtype=bool, place=Place(cpu), stop_gradient=True,
[False, False, True ])
"""
if in_dynamic_or_pir_mode():
return _C_ops.greater_than(x, y)
else:
check_variable_and_dtype(
x,
"x",
[
"bool",
"float16",
"float32",
"float64",
"uint8",
"int8",
"int16",
"int32",
"int64",
"uint16",
"complex64",
"complex128",
],
"greater_than",
)
check_variable_and_dtype(
y,
"y",
[
"bool",
"float16",
"float32",
"float64",
"uint8",
"int8",
"int16",
"int32",
"int64",
"uint16",
"complex64",
"complex128",
],
"greater_than",
)
helper = LayerHelper("greater_than", **locals())
out = helper.create_variable_for_type_inference(dtype='bool')
out.stop_gradient = True
helper.append_op(
type='greater_than',
inputs={'X': [x], 'Y': [y]},
outputs={'Out': [out]},
)
return out


@inplace_apis_in_dygraph_only
def greater_than_(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
r"""
Expand Down
13 changes: 10 additions & 3 deletions python/paddle/tensor/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,8 @@ def where(
x: Tensor | float | None = None,
y: Tensor | float | None = None,
name: str | None = None,
*,
out: Tensor | None = None,
) -> Tensor:
r"""
Return a Tensor of elements selected from either :attr:`x` or :attr:`y` according to corresponding elements of :attr:`condition`. Concretely,
Expand All @@ -691,6 +693,7 @@ def where(
y (Tensor|scalar|None, optional): A Tensor or scalar to choose when the condition is False with data type of bfloat16, float16, float32, float64, int32 or int64. Either both or neither of x and y should be given.
alias: ``other``.
name (str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
out (Tensor|None, optional): The output tensor. If set, the result will be stored to this tensor. Default is None.

Returns:
Tensor, A Tensor with the same shape as :attr:`condition` and same data type as :attr:`x` and :attr:`y`. If :attr:`x` and :attr:`y` have different data types, type promotion rules will be applied (see `Auto Type Promotion <https://www.paddlepaddle.org.cn/documentation/docs/en/develop/guides/advanced/auto_type_promotion_en.html#introduction-to-data-type-promotion>`_).
Expand Down Expand Up @@ -721,7 +724,7 @@ def where(
y = paddle.to_tensor(y)

if x is None and y is None:
return nonzero(condition, as_tuple=True)
return nonzero(condition, as_tuple=True, out=out)

if x is None or y is None:
raise ValueError("either both or neither of x and y should be given")
Expand Down Expand Up @@ -758,7 +761,9 @@ def where(
if y_shape != broadcast_shape:
broadcast_y = paddle.broadcast_to(broadcast_y, broadcast_shape)

return _C_ops.where(broadcast_condition, broadcast_x, broadcast_y)
return _C_ops.where(
broadcast_condition, broadcast_x, broadcast_y, out=out
)

else:
# for PIR and old IR
Expand All @@ -781,7 +786,9 @@ def where(
broadcast_condition = paddle.cast(broadcast_condition, 'bool')

if in_pir_mode():
return _C_ops.where(broadcast_condition, broadcast_x, broadcast_y)
return _C_ops.where(
broadcast_condition, broadcast_x, broadcast_y, out=out
)
else:
check_variable_and_dtype(condition, 'condition', ['bool'], 'where')
check_variable_and_dtype(
Expand Down
21 changes: 21 additions & 0 deletions test/legacy_test/test_compare_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,27 @@ def test_place_2(self):
self.assertEqual((result.numpy() == np.array([False])).all(), True)


class TestCompareOut(unittest.TestCase):
def setUp(self) -> None:
self.shape = [2, 3, 4, 5]
self.apis = [paddle.eq, paddle.gt]
self.np_apis = [np.equal, np.greater]
self.input = np.random.rand(*self.shape).astype(np.float32)
self.other = np.random.rand(*self.shape).astype(np.float32)
self.other[0, 0, 3, 0] = self.input[0, 0, 3, 0]

def test_dygraph(self):
paddle.disable_static()
for api, np_api in zip(self.apis, self.np_apis):
x = paddle.to_tensor(self.input)
y = paddle.to_tensor(self.other)
out_holder = paddle.zeros_like(x)
api(x, y, out=out_holder)
np.testing.assert_allclose(
out_holder.numpy(), np_api(self.input, self.other)
)


if __name__ == '__main__':
paddle.enable_static()
unittest.main()
61 changes: 61 additions & 0 deletions test/legacy_test/test_take_along_dim.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,5 +77,66 @@ def test_take_along_dim(self):
)


class TestTensorTakeAlongAxisParamDecorator(unittest.TestCase):
def setUp(self):
paddle.disable_static()

self.input_shape = [2, 3, 4]
self.axis = 1
self.out_shape = [2, 2, 4]

self.x_np = np.random.rand(*self.input_shape).astype(np.float32)

self.indices_np = np.random.randint(
0, self.input_shape[self.axis], size=self.out_shape
).astype('int64')

self.method_names = [
'take_along_dim',
'take_along_axis',
]

self.test_types = ["kwargs"]

def do_test(self, method_name, test_type):
x = paddle.to_tensor(self.x_np, stop_gradient=False)
indices = paddle.to_tensor(self.indices_np)
out_tensor = paddle.empty(self.out_shape, dtype='float32')
out_tensor.stop_gradient = False

api_to_call = getattr(x, method_name)

if test_type == 'raw':
result = api_to_call(indices, self.axis)
elif test_type == 'kwargs':
result = api_to_call(indices=indices, axis=self.axis)
else:
raise ValueError(f"Unknown test type: {test_type}")

result.mean().backward()

return result, x.grad

def test_tensor_methods(self):
for method in self.method_names:
out_std, grad_std = self.do_test(method, 'raw')

for test_type in self.test_types:
with self.subTest(method=method, type=test_type):
out, grad = self.do_test(method, test_type)

np.testing.assert_allclose(
out.numpy(),
out_std.numpy(),
rtol=1e-20,
)

np.testing.assert_allclose(
grad.numpy(),
grad_std.numpy(),
rtol=1e-20,
)


if __name__ == "__main__":
unittest.main()
Loading