Skip to content

Commit 1875fc1

Browse files
committed
add out test
1 parent a457ac7 commit 1875fc1

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed

test/legacy_test/test_reduce_op.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2676,6 +2676,51 @@ def test_dygraph_compatibility(self):
26762676
self.np_out, out.numpy(), rtol=1e-10
26772677
)
26782678

2679+
def test_dygraph_out(self):
2680+
def run_any(test_type):
2681+
x = paddle.to_tensor(self.np_input)
2682+
x.stop_gradient = False
2683+
out = (
2684+
paddle.zeros(self.np_out.shape)
2685+
if test_type in ["with_out", "both"]
2686+
else None
2687+
)
2688+
if test_type == "return":
2689+
out = paddle.any(x, axis=self.axis, keepdim=True)
2690+
elif test_type == "with_out":
2691+
paddle.any(x, axis=self.axis, keepdim=True, out=out)
2692+
elif test_type == "both":
2693+
out = paddle.any(x, axis=self.axis, keepdim=True, out=out)
2694+
else:
2695+
raise ValueError(f"Invalid test_mode: {test_type}")
2696+
2697+
expected = paddle._C_ops.any(x, self.axis, True)
2698+
np.testing.assert_array_equal(out.numpy(), expected.numpy())
2699+
loss = out.sum().astype('float32')
2700+
loss.backward()
2701+
return out, x.grad
2702+
2703+
def assert_outputs_equal(outputs, rtol: float = 1e-10):
2704+
for out in outputs[1:]:
2705+
np.testing.assert_allclose(
2706+
outputs[0].numpy(), out.numpy(), rtol=rtol
2707+
)
2708+
2709+
with dygraph_guard():
2710+
for place in self.places:
2711+
paddle.device.set_device(place)
2712+
out1, grad1 = run_any("return")
2713+
out2, grad2 = run_any("with_out")
2714+
out3, grad3 = run_any("both")
2715+
2716+
assert_outputs_equal([out1, out2, out3])
2717+
if (
2718+
grad1 is not None
2719+
and grad2 is not None
2720+
and grad3 is not None
2721+
):
2722+
assert_outputs_equal([grad1, grad2, grad3])
2723+
26792724
def test_static_compatibility(self):
26802725
with static_guard():
26812726
for place in self.places:

0 commit comments

Comments
 (0)