Commit 94dc71c

[API Compatibility] add out for topk (#74887)

* update
* fix
* update
* update
* fix
* fix
* fix docs
* restore sqrt
* fix
* fix
* fix
* revert
* update
* update
* update
* revert minimum
1 parent 1b1cf09 commit 94dc71c

4 files changed: +152 -2 lines changed


python/paddle/_paddle_docs.py

Lines changed: 6 additions & 0 deletions
@@ -80,6 +80,9 @@ def add_doc_and_signature(func_name: str, docstr: str, func_def: str) -> None:
             output Tensor. The result tensor will have one fewer dimension
             than the `x` unless :attr:`keepdim` is true, default
             value is False.
+        out (Tensor|None, optional): Output tensor. If provided in dynamic graph, the result will
+            be written to this tensor and also returned. The returned tensor and `out` share memory
+            and autograd meta. Default: None.
         name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

     Returns:
@@ -216,6 +219,9 @@ def amin(
             output Tensor. The result tensor will have one fewer dimension
             than the `x` unless :attr:`keepdim` is true, default
             value is False.
+        out (Tensor|None, optional): Output tensor. If provided in dynamic graph, the result will
+            be written to this tensor and also returned. The returned tensor and `out` share memory
+            and autograd meta. Default: None.
         name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

     Returns:
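
Taken together with the docstring text added above, here is a minimal dynamic-graph sketch of the new `out` argument for `paddle.amax`/`paddle.amin`; the input values and the 0-d buffer shape are illustrative, not from the commit:

import paddle

paddle.disable_static()
x = paddle.to_tensor([[0.1, 0.9], [0.9, 0.7]])
buf = paddle.zeros([], dtype=x.dtype)  # 0-d buffer for a full reduction

# The result is written into `buf` and also returned; per the docstring,
# the returned tensor and `buf` share memory and autograd meta.
y = paddle.amax(x, out=buf)
assert float(y) == float(buf)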

python/paddle/tensor/search.py

Lines changed: 11 additions & 2 deletions
@@ -26,6 +26,7 @@
     ParamAliasDecorator,
     index_select_decorator,
     param_one_alias,
+    param_two_alias,
 )
 from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only

@@ -1042,13 +1043,16 @@ def masked_select(x: Tensor, mask: Tensor, name: str | None = None) -> Tensor:
     return out


+@param_two_alias(["x", "input"], ["axis", "dim"])
 def topk(
     x: Tensor,
     k: int | Tensor,
     axis: int | None = None,
     largest: bool = True,
     sorted: bool = True,
     name: str | None = None,
+    *,
+    out: tuple[Tensor, Tensor] | None = None,
 ) -> tuple[Tensor, Tensor]:
     """
     Return values and indices of the k largest or smallest at the optional axis.
@@ -1120,8 +1124,13 @@ def topk(
     if in_dynamic_or_pir_mode():
         if axis is None:
             axis = -1
-        out, indices = _C_ops.topk(x, k, axis, largest, sorted)
-        return out, indices
+        values, indices = _C_ops.topk(x, k, axis, largest, sorted)
+        if out is not None:
+            out_values, out_indices = out
+            out_values = paddle.assign(values, output=out_values)
+            out_indices = paddle.assign(indices, output=out_indices)
+            return out_values, out_indices
+        return values, indices
     else:
         helper = LayerHelper("top_k_v2", **locals())
         inputs = {"X": [x]}
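
Given the implementation above, a short usage sketch of the keyword-only `out` pair and of the `input`/`dim` aliases enabled by the `param_two_alias` decorator (the example data is illustrative). Note that the `out` handling lives only in the `in_dynamic_or_pir_mode()` branch; the legacy static-graph branch below it is unchanged:

import paddle

paddle.disable_static()
x = paddle.to_tensor([[1.0, 4.0, 5.0, 7.0], [2.0, 6.0, 2.0, 5.0]])
k = 2
values = paddle.zeros([x.shape[0], k], dtype=x.dtype)
indices = paddle.zeros([x.shape[0], k], dtype='int64')

# Results are written into the provided (values, indices) pair via
# paddle.assign and are also returned.
v, i = paddle.topk(x, k, out=(values, indices))

# The decorator also accepts the torch-style aliases `input` and `dim`.
v2, i2 = paddle.topk(input=x, k=k, dim=-1)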

test/legacy_test/test_max_min_amax_amin_op.py

Lines changed: 71 additions & 0 deletions
@@ -280,5 +280,76 @@ def init_case(self):
         self.keepdim = True


+class TestAmaxAminOutAPI(unittest.TestCase):
+    def _run_api(self, api, x, case):
+        out_buf = paddle.zeros([], dtype=x.dtype)
+        out_buf.stop_gradient = False
+        if case == 'return':
+            y = api(x)
+        elif case == 'input_out':
+            api(x, out=out_buf)
+            y = out_buf
+        elif case == 'both_return':
+            y = api(x, out=out_buf)
+        elif case == 'both_input_out':
+            _ = api(x, out=out_buf)
+            y = out_buf
+        else:
+            raise AssertionError
+        return y
+
+    def test_amax_out_in_dygraph(self):
+        paddle.disable_static()
+        x = paddle.to_tensor(
+            np.array([[0.1, 0.9, 0.9, 0.9], [0.9, 0.9, 0.6, 0.7]]).astype(
+                'float64'
+            ),
+            stop_gradient=False,
+        )
+        ref = paddle._C_ops.amax(x, None, False)
+        outs = []
+        grads = []
+        for case in ['return', 'input_out', 'both_return', 'both_input_out']:
+            y = self._run_api(paddle.amax, x, case)
+            np.testing.assert_allclose(
+                y.numpy(), ref.numpy(), rtol=1e-6, atol=1e-6
+            )
+            loss = (y * 2).mean()
+            loss.backward()
+            outs.append(y.numpy())
+            grads.append(x.grad.numpy())
+            x.clear_gradient()
+        for i in range(1, 4):
+            np.testing.assert_allclose(outs[0], outs[i], rtol=1e-6, atol=1e-6)
+            np.testing.assert_allclose(grads[0], grads[i], rtol=1e-6, atol=1e-6)
+        paddle.enable_static()
+
+    def test_amin_out_in_dygraph(self):
+        paddle.disable_static()
+        x = paddle.to_tensor(
+            np.array([[0.2, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.7]]).astype(
+                'float64'
+            ),
+            stop_gradient=False,
+        )
+        ref = paddle._C_ops.amin(x, None, False)
+        outs = []
+        grads = []
+        for case in ['return', 'input_out', 'both_return', 'both_input_out']:
+            y = self._run_api(paddle.amin, x, case)
+            np.testing.assert_allclose(
+                y.numpy(), ref.numpy(), rtol=1e-6, atol=1e-6
+            )
+            loss = (y * 2).mean()
+            loss.backward()
+            outs.append(y.numpy())
+            grads.append(x.grad.numpy())
+            x.clear_gradient()
+        for i in range(1, 4):
+            np.testing.assert_allclose(outs[0], outs[i], rtol=1e-6, atol=1e-6)
+            np.testing.assert_allclose(grads[0], grads[i], rtol=1e-6, atol=1e-6)
+        paddle.enable_static()
+
+
 if __name__ == '__main__':
     unittest.main()
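
The four cases above all reduce to the same property: because the returned tensor and `out` share autograd meta, backward through either handle must produce the same input gradient. A condensed sketch of that check, with illustrative values:

import numpy as np
import paddle

paddle.disable_static()
x = paddle.to_tensor([[0.1, 0.9], [0.9, 0.7]], stop_gradient=False)
buf = paddle.zeros([], dtype=x.dtype)
buf.stop_gradient = False

y = paddle.amax(x, out=buf)     # out-variant
(y * 2).mean().backward()
grad_out_variant = x.grad.numpy()

x.clear_gradient()
z = paddle.amax(x)              # plain variant, no out
(z * 2).mean().backward()
np.testing.assert_allclose(grad_out_variant, x.grad.numpy())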

test/legacy_test/test_top_k_op.py

Lines changed: 64 additions & 0 deletions
@@ -66,6 +66,70 @@ def test_check_grad(self):
         self.check_grad({'X'}, 'Out', check_cinn=self.check_cinn)


+class TestTopkOutAPI(unittest.TestCase):
+    def test_out_in_dygraph(self):
+        paddle.disable_static()
+        x = paddle.to_tensor(
+            np.array([[1, 4, 5, 7], [2, 6, 2, 5]]).astype('float32'),
+            stop_gradient=False,
+        )
+        k = 2
+
+        def run_case(case):
+            out_values = paddle.zeros_like(x[:, :k])
+            out_indices = paddle.zeros([x.shape[0], k], dtype='int64')
+            out_values.stop_gradient = False
+            out_indices.stop_gradient = False
+
+            if case == 'return':
+                values, indices = paddle.topk(x, k)
+            elif case == 'input_out':
+                paddle.topk(x, k, out=(out_values, out_indices))
+                values, indices = out_values, out_indices
+            elif case == 'both_return':
+                values, indices = paddle.topk(
+                    x, k, out=(out_values, out_indices)
+                )
+            elif case == 'both_input_out':
+                _ = paddle.topk(x, k, out=(out_values, out_indices))
+                values, indices = out_values, out_indices
+            else:
+                raise AssertionError
+
+            ref_values, ref_indices = paddle._C_ops.topk(x, k, -1, True, True)
+            np.testing.assert_allclose(
+                values.numpy(), ref_values.numpy(), rtol=1e-6, atol=1e-6
+            )
+            np.testing.assert_allclose(
+                indices.numpy(), ref_indices.numpy(), rtol=1e-6, atol=1e-6
+            )
+
+            loss = (values.mean() + indices.float().mean()).mean()
+            loss.backward()
+            return values.numpy(), indices.numpy(), x.grad.numpy()
+
+        # run four scenarios
+        v1, i1, g1 = run_case('return')
+        x.clear_gradient()
+        v2, i2, g2 = run_case('input_out')
+        x.clear_gradient()
+        v3, i3, g3 = run_case('both_return')
+        x.clear_gradient()
+        v4, i4, g4 = run_case('both_input_out')
+
+        np.testing.assert_allclose(v1, v2, rtol=1e-6, atol=1e-6)
+        np.testing.assert_allclose(v1, v3, rtol=1e-6, atol=1e-6)
+        np.testing.assert_allclose(v1, v4, rtol=1e-6, atol=1e-6)
+        np.testing.assert_allclose(i1, i2, rtol=1e-6, atol=1e-6)
+        np.testing.assert_allclose(i1, i3, rtol=1e-6, atol=1e-6)
+        np.testing.assert_allclose(i1, i4, rtol=1e-6, atol=1e-6)
+        np.testing.assert_allclose(g1, g2, rtol=1e-6, atol=1e-6)
+        np.testing.assert_allclose(g1, g3, rtol=1e-6, atol=1e-6)
+        np.testing.assert_allclose(g1, g4, rtol=1e-6, atol=1e-6)
+
+        paddle.enable_static()
+
+
 if __name__ == "__main__":
     paddle.enable_static()
     unittest.main()
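
As an independent sanity check outside the framework, the top-k values (not the indices, which may differ under ties) can also be cross-checked against numpy; a small sketch using the same data as the test above:

import numpy as np
import paddle

paddle.disable_static()
a = np.array([[1, 4, 5, 7], [2, 6, 2, 5]], dtype='float32')
v, _ = paddle.topk(paddle.to_tensor(a), k=2)
# Sort descending along the last axis, then keep the first k columns.
np.testing.assert_allclose(v.numpy(), np.sort(a, axis=-1)[:, ::-1][:, :2])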
