@@ -2676,6 +2676,51 @@ def test_dygraph_compatibility(self):
26762676 self .np_out , out .numpy (), rtol = 1e-10
26772677 )
26782678
2679+ def test_dygraph_out (self ):
2680+ def run_any (test_type ):
2681+ x = paddle .to_tensor (self .np_input )
2682+ x .stop_gradient = False
2683+ out = (
2684+ paddle .zeros (self .np_out .shape )
2685+ if test_type in ["with_out" , "both" ]
2686+ else None
2687+ )
2688+ if test_type == "return" :
2689+ out = paddle .any (x , axis = self .axis , keepdim = True )
2690+ elif test_type == "with_out" :
2691+ paddle .any (x , axis = self .axis , keepdim = True , out = out )
2692+ elif test_type == "both" :
2693+ out = paddle .any (x , axis = self .axis , keepdim = True , out = out )
2694+ else :
2695+ raise ValueError (f"Invalid test_mode: { test_type } " )
2696+
2697+ expected = paddle ._C_ops .any (x , self .axis , True )
2698+ np .testing .assert_array_equal (out .numpy (), expected .numpy ())
2699+ loss = out .sum ().astype ('float32' )
2700+ loss .backward ()
2701+ return out , x .grad
2702+
2703+ def assert_outputs_equal (outputs , rtol : float = 1e-10 ):
2704+ for out in outputs [1 :]:
2705+ np .testing .assert_allclose (
2706+ outputs [0 ].numpy (), out .numpy (), rtol = rtol
2707+ )
2708+
2709+ with dygraph_guard ():
2710+ for place in self .places :
2711+ paddle .device .set_device (place )
2712+ out1 , grad1 = run_any ("return" )
2713+ out2 , grad2 = run_any ("with_out" )
2714+ out3 , grad3 = run_any ("both" )
2715+
2716+ assert_outputs_equal ([out1 , out2 , out3 ])
2717+ if (
2718+ grad1 is not None
2719+ and grad2 is not None
2720+ and grad3 is not None
2721+ ):
2722+ assert_outputs_equal ([grad1 , grad2 , grad3 ])
2723+
26792724 def test_static_compatibility (self ):
26802725 with static_guard ():
26812726 for place in self .places :
0 commit comments