Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions python/paddle/amp/debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
from paddle.base import core
from paddle.base.framework import dygraph_only

from ..framework import LayerHelper, in_dynamic_mode

from ..framework import LayerHelper, in_dynamic_or_pir_mode
__all__ = [
"DebugMode",
"TensorCheckerConfig",
Expand Down Expand Up @@ -372,7 +371,7 @@ def check_numerics(
stack_height_limit = -1
output_dir = ""

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.check_numerics(
tensor,
op_type,
Expand Down
27 changes: 18 additions & 9 deletions test/legacy_test/test_nan_inf.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import numpy as np

import paddle
from paddle.framework import in_pir_mode
from paddle.pir_utils import test_with_pir_api


class TestNanInfBase(unittest.TestCase):
Expand Down Expand Up @@ -299,6 +301,7 @@ def test_eager(self):
debug_mode=paddle.amp.debugging.DebugMode.CHECK_ALL,
)

@test_with_pir_api

This comment was marked as resolved.

def test_static(self):
paddle.enable_static()
shape = [8, 8]
Expand All @@ -310,16 +313,22 @@ def test_static(self):
x = paddle.static.data(name='x', shape=[8, 8], dtype="float32")
y = paddle.static.data(name='y', shape=[8, 8], dtype="float32")
out = paddle.add(x, y)
paddle.amp.debugging.check_numerics(
tensor=out,
op_type="elementwise_add",
var_name=out.name,
debug_mode=paddle.amp.debugging.DebugMode.CHECK_ALL,
)
if in_pir_mode():
paddle.amp.debugging.check_numerics(
tensor=out,
op_type="elementwise_add",
var_name=out.id,
debug_mode=paddle.amp.debugging.DebugMode.CHECK_ALL,
)
else:
paddle.amp.debugging.check_numerics(
tensor=out,
op_type="elementwise_add",
var_name=out.name,
debug_mode=paddle.amp.debugging.DebugMode.CHECK_ALL,
)
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(
main_program, feed={"x": x_np, "y": y_np}, fetch_list=[out.name]
)
exe.run(main_program, feed={"x": x_np, "y": y_np}, fetch_list=[out])
paddle.disable_static()


Expand Down