diff --git a/python/paddle/amp/debugging.py b/python/paddle/amp/debugging.py index f42045e3f94604..f49c7b18d36d44 100644 --- a/python/paddle/amp/debugging.py +++ b/python/paddle/amp/debugging.py @@ -23,7 +23,7 @@ from paddle.base import core from paddle.base.framework import dygraph_only -from ..framework import LayerHelper, in_dynamic_mode +from ..framework import LayerHelper, in_dynamic_or_pir_mode __all__ = [ "DebugMode", @@ -372,7 +372,7 @@ def check_numerics( stack_height_limit = -1 output_dir = "" - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.check_numerics( tensor, op_type, diff --git a/test/legacy_test/test_nan_inf.py b/test/legacy_test/test_nan_inf.py index e6f6ddb7701440..6db010ece73e73 100644 --- a/test/legacy_test/test_nan_inf.py +++ b/test/legacy_test/test_nan_inf.py @@ -21,6 +21,8 @@ import numpy as np import paddle +from paddle.framework import in_pir_mode +from paddle.pir_utils import test_with_pir_api class TestNanInfBase(unittest.TestCase): @@ -299,6 +301,7 @@ def test_eager(self): debug_mode=paddle.amp.debugging.DebugMode.CHECK_ALL, ) + @test_with_pir_api def test_static(self): paddle.enable_static() shape = [8, 8] @@ -310,16 +313,22 @@ def test_static(self): x = paddle.static.data(name='x', shape=[8, 8], dtype="float32") y = paddle.static.data(name='y', shape=[8, 8], dtype="float32") out = paddle.add(x, y) - paddle.amp.debugging.check_numerics( - tensor=out, - op_type="elementwise_add", - var_name=out.name, - debug_mode=paddle.amp.debugging.DebugMode.CHECK_ALL, - ) + if in_pir_mode(): + paddle.amp.debugging.check_numerics( + tensor=out, + op_type="elementwise_add", + var_name=out.id, + debug_mode=paddle.amp.debugging.DebugMode.CHECK_ALL, + ) + else: + paddle.amp.debugging.check_numerics( + tensor=out, + op_type="elementwise_add", + var_name=out.name, + debug_mode=paddle.amp.debugging.DebugMode.CHECK_ALL, + ) exe = paddle.static.Executor(paddle.CPUPlace()) - exe.run( - main_program, feed={"x": x_np, "y": y_np}, fetch_list=[out.name] - ) + exe.run(main_program, feed={"x": x_np, "y": y_np}, fetch_list=[out]) paddle.disable_static()