2020import paddle
2121from paddle import base
2222from paddle .base import core
23+ from paddle .pir_utils import test_with_pir_api
2324
2425
2526def reference_unique_consecutive (
@@ -203,6 +204,7 @@ def setUp(self):
203204 if core .is_compiled_with_cuda ():
204205 self .places .append (base .CUDAPlace (0 ))
205206
207+ @test_with_pir_api
206208 def check_static_result (self , place ):
207209 with base .program_guard (base .Program (), base .Program ()):
208210 paddle .enable_static ()
@@ -217,7 +219,6 @@ def check_static_result(self, place):
217219 x_np = np .random .randint (20 , size = 100 ).astype ("float32" )
218220 exe = base .Executor (place )
219221 fetches = exe .run (
220- base .default_main_program (),
221222 feed = {"input_x" : x_np },
222223 fetch_list = [result ],
223224 )
@@ -240,6 +241,7 @@ def setUp(self):
240241 if core .is_compiled_with_cuda ():
241242 self .places .append (base .CUDAPlace (0 ))
242243
244+ @test_with_pir_api
243245 def check_static_result (self , place ):
244246 with base .program_guard (base .Program (), base .Program ()):
245247 paddle .enable_static ()
@@ -256,7 +258,6 @@ def check_static_result(self, place):
256258 x_np = np .random .randint (20 , size = 100 ).astype ("float32" )
257259 exe = base .Executor (place )
258260 fetches = exe .run (
259- base .default_main_program (),
260261 feed = {"input_x" : x_np },
261262 fetch_list = [result ],
262263 )
@@ -281,6 +282,7 @@ def setUp(self):
281282 if core .is_compiled_with_cuda ():
282283 self .places .append (base .CUDAPlace (0 ))
283284
285+ @test_with_pir_api
284286 def check_static_result (self , place ):
285287 with base .program_guard (base .Program (), base .Program ()):
286288 paddle .enable_static ()
@@ -297,7 +299,6 @@ def check_static_result(self, place):
297299 x_np = np .random .randint (20 , size = 100 ).astype ("float32" )
298300 exe = base .Executor (place )
299301 fetches = exe .run (
300- base .default_main_program (),
301302 feed = {"input_x" : x_np },
302303 fetch_list = [result ],
303304 )
@@ -347,7 +348,7 @@ def setUp(self):
347348 }
348349
349350 def test_check_output (self ):
350- self .check_output ()
351+ self .check_output (check_pir = True )
351352
352353
353354if __name__ == "__main__" :
0 commit comments