diff --git a/test/legacy_test/test_linear_interp_op.py b/test/legacy_test/test_linear_interp_op.py index 5c3b1d2814a129..f5bd1e7e103d10 100755 --- a/test/legacy_test/test_linear_interp_op.py +++ b/test/legacy_test/test_linear_interp_op.py @@ -20,7 +20,8 @@ import paddle from paddle import base -from paddle.base import Program, core, program_guard +from paddle.base import core +from paddle.pir_utils import test_with_pir_api def linear_interp_np( @@ -325,8 +326,12 @@ def init_test_case(self): class TestLinearInterpOpError(unittest.TestCase): + @test_with_pir_api def test_error(self): - with program_guard(Program(), Program()): + paddle.enable_static() + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): def input_shape_error(): x1 = paddle.static.data(name="x1", shape=[1], dtype="float32") @@ -369,6 +374,7 @@ def out_shape_error(): self.assertRaises(ValueError, input_shape_error) self.assertRaises(ValueError, data_format_error) self.assertRaises(ValueError, out_shape_error) + paddle.disable_static() if __name__ == "__main__": diff --git a/test/legacy_test/test_linear_interp_v2_op.py b/test/legacy_test/test_linear_interp_v2_op.py index b6a37f4500b009..97effe92de2ce9 100755 --- a/test/legacy_test/test_linear_interp_v2_op.py +++ b/test/legacy_test/test_linear_interp_v2_op.py @@ -20,8 +20,9 @@ import paddle from paddle import base -from paddle.base import Program, core, program_guard +from paddle.base import core from paddle.nn.functional import interpolate +from paddle.pir_utils import test_with_pir_api def create_test_case0(self): @@ -528,9 +529,12 @@ def init_test_case(self): class TestLinearInterpOpError(unittest.TestCase): + @test_with_pir_api def test_error(self): with paddle_static_guard(): - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): def input_shape_error(): x1 = paddle.static.data(