2020import paddle
2121from paddle import base
2222from paddle .base import core
23+ from paddle .pir_utils import test_with_pir_api
2324
2425paddle .enable_static ()
2526
@@ -113,13 +114,11 @@ def get_reduce_dims(x, y):
113114 return x_grad , y_grad
114115
115116 def test_check_output (self ):
116- self .check_output ()
117+ self .check_output (check_pir = True )
117118
118119 def test_check_grad (self ):
119120 self .check_grad (
120- ["X" , "Y" ],
121- "Out" ,
122- user_defined_grads = self .gradient ,
121+ ["X" , "Y" ], "Out" , user_defined_grads = self .gradient , check_pir = True
123122 )
124123
125124
@@ -244,10 +243,11 @@ def init_data_type(self):
244243 'float32' if core .is_compiled_with_rocm () else 'float64'
245244 )
246245
246+ @test_with_pir_api
247247 def test_api (self ):
248248 self .init_data_type ()
249- main_program = base .Program ()
250- startup_program = base .Program ()
249+ main_program = paddle . static .Program ()
250+ startup_program = paddle . static .Program ()
251251 with base .program_guard (main_program , startup_program ):
252252 x = paddle .static .data (
253253 name = 'x' , shape = [2 , 3 , 4 , 5 ], dtype = self .data_type
@@ -266,7 +266,7 @@ def test_api(self):
266266 )
267267 exe = base .Executor (place )
268268 out = exe .run (
269- base . default_main_program () ,
269+ main_program ,
270270 feed = {'x' : x_i , 'y' : y_i },
271271 fetch_list = [result ],
272272 )
0 commit comments