6 changes: 6 additions & 0 deletions python/paddle/_paddle_docs.py
@@ -79,6 +79,9 @@ def add_doc_and_signature(func_name: str, docstr: str, func_def: str) -> None:
output Tensor. The result tensor will have one fewer dimension
than the `x` unless :attr:`keepdim` is true, default
value is False.
out (Tensor|None, optional): Output tensor. If provided in dynamic graph mode, the result will
be written to this tensor and also returned. The returned tensor and `out` share memory
and autograd meta. Default: None.
name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

Returns:
@@ -214,6 +217,9 @@ def amin(
output Tensor. The result tensor will have one fewer dimension
than the `x` unless :attr:`keepdim` is true, default
value is False.
out (Tensor|None, optional): Output tensor. If provided in dynamic graph mode, the result will
be written to this tensor and also returned. The returned tensor and `out` share memory
and autograd meta. Default: None.
name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

Returns:
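For reference, a minimal dynamic-graph usage sketch of the `out=` semantics documented above for `paddle.amax`/`paddle.amin`; the buffer shape and values below are purely illustrative:

import paddle

paddle.disable_static()
x = paddle.to_tensor([[0.1, 0.9], [0.7, 0.3]])

# Preallocate a 0-D buffer; per the doc above, the result is written into it
# and is also returned (the two share memory and autograd meta).
buf = paddle.zeros([], dtype=x.dtype)
ret = paddle.amax(x, out=buf)
print(float(buf), float(ret))  # both print 0.9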
20 changes: 15 additions & 5 deletions python/paddle/tensor/math.py
@@ -1368,7 +1368,10 @@ def _divide_with_axis(x, y, axis=-1, name=None):
return _elementwise_op(LayerHelper(op_type, **locals()))


def maximum(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
@ParamAliasDecorator({"x": ["input"], "y": ["other"]})
Contributor
@zhwesky2010 zhwesky2010 Aug 20, 2025

The API-benchmark check did not pass; try switching to the functional decorator param_two_alias, which performs better.

Contributor Author

done

def maximum(
x: Tensor, y: Tensor, name: str | None = None, *, out: Tensor | None = None
) -> Tensor:
"""
Compare two tensors and returns a new tensor containing the element-wise maxima. The equation is:

@@ -1425,12 +1428,15 @@ def maximum(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
[5. , 3. , inf.])
"""
if in_dynamic_or_pir_mode():
return _C_ops.maximum(x, y)
return _C_ops.maximum(x, y, out=out)
else:
return _elementwise_op(LayerHelper('elementwise_max', **locals()))


def minimum(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
@ParamAliasDecorator({"x": ["input"], "y": ["other"]})
def minimum(
x: Tensor, y: Tensor, name: str | None = None, *, out: Tensor | None = None
) -> Tensor:
"""
Compare two tensors and return a new tensor containing the element-wise minima. The equation is:

Expand Down Expand Up @@ -1487,7 +1493,7 @@ def minimum(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
[ 1. , -inf., 5. ])
"""
if in_dynamic_or_pir_mode():
return _C_ops.minimum(x, y)
return _C_ops.minimum(x, y, out=out)
else:
return _elementwise_op(LayerHelper('elementwise_min', **locals()))
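As a quick illustration of the new keyword-only `out=` and the `input`/`other` aliases registered above (a sketch; tensor values are made up):

import paddle

paddle.disable_static()
a = paddle.to_tensor([1.0, 5.0, 3.0])
b = paddle.to_tensor([4.0, 2.0, 6.0])

# The preallocated buffer is filled in place and also returned.
buf = paddle.zeros_like(a)
res = paddle.maximum(a, b, out=buf)
print(buf.numpy())   # [4. 5. 6.]

# PyTorch-style keyword names map onto x/y through the alias table.
res2 = paddle.minimum(input=a, other=b)
print(res2.numpy())  # [1. 2. 3.]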

@@ -3124,11 +3130,14 @@ def __check_input(x, y):
return out


@ParamAliasDecorator({"x": ["input"], "axis": ["dim"]})
def logsumexp(
x: Tensor,
axis: int | Sequence[int] | None = None,
keepdim: bool = False,
name: str | None = None,
*,
out: Tensor | None = None,
) -> Tensor:
r"""
Calculates the log of the sum of exponentials of ``x`` along ``axis`` .
@@ -3154,6 +3163,7 @@ def logsumexp(
the output Tensor is the same as ``x`` except in the reduced
dimensions(it is of size 1 in this case). Otherwise, the shape of
the output Tensor is squeezed in ``axis`` . Default is False.
out (Tensor|None, optional): The output tensor. Default: None.
name (str|None, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.

@@ -3181,7 +3191,7 @@ def logsumexp(
reduce_all, axis = _get_reduce_axis(axis, x)

if in_dynamic_or_pir_mode():
return _C_ops.logsumexp(x, axis, keepdim, reduce_all)
return _C_ops.logsumexp(x, axis, keepdim, reduce_all, out=out)
else:
check_variable_and_dtype(
x,
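A comparable sketch for `logsumexp`, assuming the `input`/`dim` aliases and the `out=` buffer behave as wired up above (illustrative values):

import numpy as np
import paddle

paddle.disable_static()
x = paddle.to_tensor(
    np.array([[-1.5, 0.0, 2.0], [3.0, 1.2, -2.4]], dtype='float32')
)

# Reduce over all elements into a preallocated 0-D buffer.
buf = paddle.zeros([], dtype='float32')
full = paddle.logsumexp(x, out=buf)

# The alias decorator also accepts PyTorch-style names for x/axis.
rows = paddle.logsumexp(input=x, dim=1, keepdim=True)
print(full.shape, rows.shape)  # [] and [2, 1]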
12 changes: 10 additions & 2 deletions python/paddle/tensor/search.py
@@ -1187,13 +1187,16 @@ def masked_select(x: Tensor, mask: Tensor, name: str | None = None) -> Tensor:
return out


@ParamAliasDecorator({"x": ["input"], "axis": ["dim"]})
def topk(
x: Tensor,
k: int | Tensor,
axis: int | None = None,
largest: bool = True,
sorted: bool = True,
name: str | None = None,
*,
out: tuple[Tensor, Tensor] | None = None,
) -> tuple[Tensor, Tensor]:
"""
Return values and indices of the k largest or smallest at the optional axis.
@@ -1265,8 +1268,13 @@ def topk(
if in_dynamic_or_pir_mode():
if axis is None:
axis = -1
out, indices = _C_ops.topk(x, k, axis, largest, sorted)
return out, indices
values, indices = _C_ops.topk(x, k, axis, largest, sorted)
if out is not None:
out_values, out_indices = out
out_values = paddle.assign(values, output=out_values)
out_indices = paddle.assign(indices, output=out_indices)
return out_values, out_indices
return values, indices
else:
helper = LayerHelper("top_k_v2", **locals())
inputs = {"X": [x]}
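A short dynamic-graph sketch of the tuple-valued `out=` handling above: the computed values and indices are copied into the supplied buffers via `paddle.assign` and those buffers are returned (the buffer shapes and dtypes below are assumptions for the illustration):

import paddle

paddle.disable_static()
x = paddle.to_tensor([1.0, 4.0, 5.0, 7.0])

# Preallocate both buffers with the expected shapes and dtypes.
vals_buf = paddle.zeros([2], dtype=x.dtype)
idx_buf = paddle.zeros([2], dtype='int64')

values, indices = paddle.topk(x, k=2, out=(vals_buf, idx_buf))
print(values.numpy())   # [7. 5.]
print(indices.numpy())  # [3 2]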
62 changes: 62 additions & 0 deletions test/legacy_test/test_elementwise_max_op.py
@@ -404,5 +404,67 @@ def setUp(self):
self.outputs = {'Out': np.maximum(self.inputs['X'], self.inputs['Y'])}


class TestMaximumOutAPI(unittest.TestCase):
def test_out_in_dygraph(self):
paddle.disable_static()
np.random.seed(2024)
x = paddle.to_tensor(
np.random.randn(5, 7).astype('float32'), stop_gradient=False
)
# shift y to avoid ties for stable gradient routing
y = paddle.to_tensor(
(np.random.randn(5, 7) + 0.1).astype('float32'), stop_gradient=False
)

def run_case(case_type):
out_buf = paddle.zeros_like(x)
out_buf.stop_gradient = False

if case_type == 'return':
z = paddle.maximum(x, y)
elif case_type == 'input_out':
paddle.maximum(x, y, out=out_buf)
z = out_buf
elif case_type == 'both_return':
z = paddle.maximum(x, y, out=out_buf)
elif case_type == 'both_input_out':
_ = paddle.maximum(x, y, out=out_buf)
z = out_buf
else:
raise AssertionError

ref = paddle._C_ops.maximum(x, y)
np.testing.assert_allclose(
z.numpy(), ref.numpy(), rtol=1e-6, atol=1e-6
)

loss = (z * 2).mean()
loss.backward()
return z.numpy(), x.grad.numpy(), y.grad.numpy()

z1, gx1, gy1 = run_case('return')
x.clear_gradient()
y.clear_gradient()
z2, gx2, gy2 = run_case('input_out')
x.clear_gradient()
y.clear_gradient()
z3, gx3, gy3 = run_case('both_return')
x.clear_gradient()
y.clear_gradient()
z4, gx4, gy4 = run_case('both_input_out')

np.testing.assert_allclose(z1, z2, rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(z1, z3, rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(z1, z4, rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(gx1, gx2, rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(gx1, gx3, rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(gx1, gx4, rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(gy1, gy2, rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(gy1, gy3, rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(gy1, gy4, rtol=1e-6, atol=1e-6)

paddle.enable_static()


if __name__ == '__main__':
unittest.main()
56 changes: 56 additions & 0 deletions test/legacy_test/test_elementwise_min_op.py
@@ -485,5 +485,61 @@ def setUp(self):
self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])}


class TestMinimumOutAPI(unittest.TestCase):
def test_out_in_dygraph(self):
paddle.disable_static()
x = paddle.to_tensor(
np.array([[1, 2], [7, 8]]), dtype='float32', stop_gradient=False
)
y = paddle.to_tensor(
np.array([[3, 4], [5, 6]]), dtype='float32', stop_gradient=False
)

def run_case(case):
out_buf = paddle.zeros_like(x)
out_buf.stop_gradient = False
if case == 'return':
z = paddle.minimum(x, y)
elif case == 'input_out':
paddle.minimum(x, y, out=out_buf)
z = out_buf
elif case == 'both_return':
z = paddle.minimum(x, y, out=out_buf)
elif case == 'both_input_out':
_ = paddle.minimum(x, y, out=out_buf)
z = out_buf
else:
raise AssertionError
ref = paddle._C_ops.minimum(x, y)
np.testing.assert_allclose(
z.numpy(), ref.numpy(), rtol=1e-6, atol=1e-6
)
(z.mean()).backward()
return z.numpy(), x.grad.numpy(), y.grad.numpy()

z1, gx1, gy1 = run_case('return')
x.clear_gradient()
y.clear_gradient()
z2, gx2, gy2 = run_case('input_out')
x.clear_gradient()
y.clear_gradient()
z3, gx3, gy3 = run_case('both_return')
x.clear_gradient()
y.clear_gradient()
z4, gx4, gy4 = run_case('both_input_out')

np.testing.assert_allclose(z1, z2, rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(z1, z3, rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(z1, z4, rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(gx1, gx2, rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(gx1, gx3, rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(gx1, gx4, rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(gy1, gy2, rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(gy1, gy3, rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(gy1, gy4, rtol=1e-6, atol=1e-6)

paddle.enable_static()


if __name__ == '__main__':
unittest.main()
48 changes: 48 additions & 0 deletions test/legacy_test/test_logsumexp.py
@@ -340,5 +340,53 @@ def set_attrs(self):
self.axis = [1] # out return shape [2, 0]


class TestLogsumexpOutAPI(unittest.TestCase):
def test_out_in_dygraph(self):
paddle.disable_static()
x = paddle.to_tensor(
np.array([[-1.5, 0.0, 2.0], [3.0, 1.2, -2.4]]).astype('float32'),
stop_gradient=False,
)

def run_case(case):
out_buf = paddle.zeros([], dtype='float32')
out_buf.stop_gradient = False
if case == 'return':
y = paddle.logsumexp(x)
elif case == 'input_out':
paddle.logsumexp(x, out=out_buf)
y = out_buf
elif case == 'both_return':
y = paddle.logsumexp(x, out=out_buf)
elif case == 'both_input_out':
_ = paddle.logsumexp(x, out=out_buf)
y = out_buf
else:
raise AssertionError
ref = paddle._C_ops.logsumexp(x, [], False, True)
np.testing.assert_allclose(
y.numpy(), ref.numpy(), rtol=1e-6, atol=1e-6
)
(y.mean()).backward()
return y.numpy(), x.grad.numpy()

y1, g1 = run_case('return')
x.clear_gradient()
y2, g2 = run_case('input_out')
x.clear_gradient()
y3, g3 = run_case('both_return')
x.clear_gradient()
y4, g4 = run_case('both_input_out')

np.testing.assert_allclose(y1, y2, rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(y1, y3, rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(y1, y4, rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(g1, g2, rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(g1, g3, rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(g1, g4, rtol=1e-6, atol=1e-6)

paddle.enable_static()


if __name__ == '__main__':
unittest.main()
71 changes: 71 additions & 0 deletions test/legacy_test/test_max_min_amax_amin_op.py
@@ -280,5 +280,76 @@ def init_case(self):
self.keepdim = True


class TestAmaxAminOutAPI(unittest.TestCase):
def _run_api(self, api, x, case):
out_buf = paddle.zeros([], dtype=x.dtype)
out_buf.stop_gradient = False
if case == 'return':
y = api(x)
elif case == 'input_out':
api(x, out=out_buf)
y = out_buf
elif case == 'both_return':
y = api(x, out=out_buf)
elif case == 'both_input_out':
_ = api(x, out=out_buf)
y = out_buf
else:
raise AssertionError
return y

def test_amax_out_in_dygraph(self):
paddle.disable_static()
x = paddle.to_tensor(
np.array([[0.1, 0.9, 0.9, 0.9], [0.9, 0.9, 0.6, 0.7]]).astype(
'float64'
),
stop_gradient=False,
)
ref = paddle._C_ops.amax(x, None, False)
outs = []
grads = []
for case in ['return', 'input_out', 'both_return', 'both_input_out']:
y = self._run_api(paddle.amax, x, case)
np.testing.assert_allclose(
y.numpy(), ref.numpy(), rtol=1e-6, atol=1e-6
)
loss = (y * 2).mean()
loss.backward()
outs.append(y.numpy())
grads.append(x.grad.numpy())
x.clear_gradient()
for i in range(1, 4):
np.testing.assert_allclose(outs[0], outs[i], rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(grads[0], grads[i], rtol=1e-6, atol=1e-6)
paddle.enable_static()

def test_amin_out_in_dygraph(self):
paddle.disable_static()
x = paddle.to_tensor(
np.array([[0.2, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.7]]).astype(
'float64'
),
stop_gradient=False,
)
ref = paddle._C_ops.amin(x, None, False)
outs = []
grads = []
for case in ['return', 'input_out', 'both_return', 'both_input_out']:
y = self._run_api(paddle.amin, x, case)
np.testing.assert_allclose(
y.numpy(), ref.numpy(), rtol=1e-6, atol=1e-6
)
loss = (y * 2).mean()
loss.backward()
outs.append(y.numpy())
grads.append(x.grad.numpy())
x.clear_gradient()
for i in range(1, 4):
np.testing.assert_allclose(outs[0], outs[i], rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(grads[0], grads[i], rtol=1e-6, atol=1e-6)
paddle.enable_static()


if __name__ == '__main__':
unittest.main()