2121import numpy as np
2222
2323import paddle
24+ from paddle .framework import in_pir_mode
25+ from paddle .pir_utils import test_with_pir_api
2426
2527
2628class TestNanInfBase (unittest .TestCase ):
@@ -299,6 +301,7 @@ def test_eager(self):
299301 debug_mode = paddle .amp .debugging .DebugMode .CHECK_ALL ,
300302 )
301303
304+ @test_with_pir_api
302305 def test_static (self ):
303306 paddle .enable_static ()
304307 shape = [8 , 8 ]
@@ -310,16 +313,22 @@ def test_static(self):
310313 x = paddle .static .data (name = 'x' , shape = [8 , 8 ], dtype = "float32" )
311314 y = paddle .static .data (name = 'y' , shape = [8 , 8 ], dtype = "float32" )
312315 out = paddle .add (x , y )
313- paddle .amp .debugging .check_numerics (
314- tensor = out ,
315- op_type = "elementwise_add" ,
316- var_name = out .name ,
317- debug_mode = paddle .amp .debugging .DebugMode .CHECK_ALL ,
318- )
316+ if in_pir_mode ():
317+ paddle .amp .debugging .check_numerics (
318+ tensor = out ,
319+ op_type = "elementwise_add" ,
320+ var_name = out .id ,
321+ debug_mode = paddle .amp .debugging .DebugMode .CHECK_ALL ,
322+ )
323+ else :
324+ paddle .amp .debugging .check_numerics (
325+ tensor = out ,
326+ op_type = "elementwise_add" ,
327+ var_name = out .name ,
328+ debug_mode = paddle .amp .debugging .DebugMode .CHECK_ALL ,
329+ )
319330 exe = paddle .static .Executor (paddle .CPUPlace ())
320- exe .run (
321- main_program , feed = {"x" : x_np , "y" : y_np }, fetch_list = [out .name ]
322- )
331+ exe .run (main_program , feed = {"x" : x_np , "y" : y_np }, fetch_list = [out ])
323332 paddle .disable_static ()
324333
325334
0 commit comments