diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index f1cdedb7a8ddf8..32fb639f9d3a01 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -1950,7 +1950,7 @@ def __init__( def forward(self, x): weight = x - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.spectral_norm( weight, self.weight_u, diff --git a/python/paddle/static/nn/common.py b/python/paddle/static/nn/common.py index fa6f6e1bb80bcb..58eda30f44f811 100644 --- a/python/paddle/static/nn/common.py +++ b/python/paddle/static/nn/common.py @@ -26,6 +26,7 @@ Variable, default_main_program, in_dygraph_mode, + in_dynamic_or_pir_mode, name_scope, program_guard, static_only, @@ -3481,7 +3482,7 @@ def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None): ) v.stop_gradient = True - if in_dygraph_mode(): + if in_dynamic_or_pir_mode(): return paddle._C_ops.spectral_norm(weight, u, v, dim, power_iters, eps) inputs = {'Weight': weight} diff --git a/test/legacy_test/test_layers.py b/test/legacy_test/test_layers.py index 75a746c72b5bb8..7fc809e2fe2446 100644 --- a/test/legacy_test/test_layers.py +++ b/test/legacy_test/test_layers.py @@ -986,6 +986,7 @@ def test_type(): _test_errors() + @test_with_pir_api def test_spectral_norm(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) diff --git a/test/legacy_test/test_spectral_norm_op.py b/test/legacy_test/test_spectral_norm_op.py index 912965f94755e4..299204483bd40d 100644 --- a/test/legacy_test/test_spectral_norm_op.py +++ b/test/legacy_test/test_spectral_norm_op.py @@ -20,6 +20,7 @@ import paddle from paddle import _C_ops from paddle.base.framework import Program, program_guard +from paddle.pir_utils import test_with_pir_api paddle.enable_static() @@ -84,7 +85,7 @@ def setUp(self): self.outputs = {"Out": output} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def initTestCase(self): self.weight_shape = (10, 12) @@ -116,6 +117,7 @@ def test_check_grad_ignore_uv(self): ['Weight'], 'Out', no_grad_set={"U", "V"}, + check_pir=True, ) def initTestCase(self): @@ -138,6 +140,7 @@ def initTestCase(self): class TestSpectralNormOpError(unittest.TestCase): + @test_with_pir_api def test_errors(self): with program_guard(Program(), Program()):