1919
2020import paddle
2121from paddle import base
22- from paddle .base import Program , core , program_guard
22+ from paddle .base import core
23+ from paddle .pir_utils import test_with_pir_api
2324
2425
2526class TestCrossOp (OpTest ):
@@ -47,10 +48,10 @@ def init_output(self):
4748 self .outputs = {'Out' : np .array (z_list ).reshape (self .shape )}
4849
4950 def test_check_output (self ):
50- self .check_output ()
51+ self .check_output (check_pir = True )
5152
5253 def test_check_grad_normal (self ):
53- self .check_grad (['X' , 'Y' ], 'Out' )
54+ self .check_grad (['X' , 'Y' ], 'Out' , check_pir = True )
5455
5556
5657class TestCrossOpCase1 (TestCrossOp ):
@@ -116,13 +117,15 @@ def test_check_output(self):
116117 if core .is_compiled_with_cuda ():
117118 place = core .CUDAPlace (0 )
118119 if core .is_bfloat16_supported (place ):
119- self .check_output_with_place (place )
120+ self .check_output_with_place (place , check_pir = True )
120121
121122 def test_check_grad_normal (self ):
122123 if core .is_compiled_with_cuda ():
123124 place = core .CUDAPlace (0 )
124125 if core .is_bfloat16_supported (place ):
125- self .check_grad_with_place (place , ['X' , 'Y' ], 'Out' )
126+ self .check_grad_with_place (
127+ place , ['X' , 'Y' ], 'Out' , check_pir = True
128+ )
126129
127130
128131class TestCrossAPI (unittest .TestCase ):
@@ -134,43 +137,56 @@ def input_data(self):
134137 [[1.0 , 1.0 , 1.0 ], [1.0 , 1.0 , 1.0 ], [1.0 , 1.0 , 1.0 ]]
135138 ).astype ('float32' )
136139
140+ @test_with_pir_api
137141 def test_cross_api (self ):
138142 self .input_data ()
139143
144+ main = paddle .static .Program ()
145+ startup = paddle .static .Program ()
140146 # case 1:
141- with program_guard (Program (), Program () ):
147+ with paddle . static . program_guard (main , startup ):
142148 x = paddle .static .data (name = 'x' , shape = [- 1 , 3 ], dtype = "float32" )
143149 y = paddle .static .data (name = 'y' , shape = [- 1 , 3 ], dtype = "float32" )
144150 z = paddle .cross (x , y , axis = 1 )
145151 exe = base .Executor (base .CPUPlace ())
146152 (res ,) = exe .run (
153+ main ,
147154 feed = {'x' : self .data_x , 'y' : self .data_y },
148- fetch_list = [z . name ],
155+ fetch_list = [z ],
149156 return_numpy = False ,
150157 )
151158 expect_out = np .array (
152159 [[0.0 , 0.0 , 0.0 ], [0.0 , 0.0 , 0.0 ], [0.0 , 0.0 , 0.0 ]]
153160 )
154161 np .testing .assert_allclose (expect_out , np .array (res ), rtol = 1e-05 )
155162
163+ main = paddle .static .Program ()
164+ startup = paddle .static .Program ()
156165 # case 2:
157- with program_guard (Program (), Program () ):
166+ with paddle . static . program_guard (main , startup ):
158167 x = paddle .static .data (name = 'x' , shape = [- 1 , 3 ], dtype = "float32" )
159168 y = paddle .static .data (name = 'y' , shape = [- 1 , 3 ], dtype = "float32" )
160169 z = paddle .cross (x , y )
161170 exe = base .Executor (base .CPUPlace ())
162171 (res ,) = exe .run (
172+ main ,
163173 feed = {'x' : self .data_x , 'y' : self .data_y },
164- fetch_list = [z . name ],
174+ fetch_list = [z ],
165175 return_numpy = False ,
166176 )
167177 expect_out = np .array (
168178 [[- 1.0 , - 1.0 , - 1.0 ], [2.0 , 2.0 , 2.0 ], [- 1.0 , - 1.0 , - 1.0 ]]
169179 )
170180 np .testing .assert_allclose (expect_out , np .array (res ), rtol = 1e-05 )
171181
172- # case 3:
173- with program_guard (Program (), Program ()):
182+ def test_cross_api1 (self ):
183+ self .input_data ()
184+
185+ main = paddle .static .Program ()
186+ startup = paddle .static .Program ()
187+
188+ # case 1:
189+ with paddle .static .program_guard (main , startup ):
174190 x = paddle .static .data (name = "x" , shape = [- 1 , 3 ], dtype = "float32" )
175191 y = paddle .static .data (name = 'y' , shape = [- 1 , 3 ], dtype = 'float32' )
176192
0 commit comments