diff --git a/test/dygraph_to_static/test_loop.py b/test/dygraph_to_static/test_loop.py index fb2600b8ac2dc0..edaa9f32fb96fd 100644 --- a/test/dygraph_to_static/test_loop.py +++ b/test/dygraph_to_static/test_loop.py @@ -19,11 +19,13 @@ import numpy as np from dygraph_to_static_utils import ( Dy2StTestBase, + test_legacy_and_pt_and_pir, ) import paddle import paddle.nn.functional as F from paddle import base +from paddle.base.framework import use_pir_api from paddle.jit.dy2static.transformers.loop_transformer import NameVisitor from paddle.utils import gast @@ -251,6 +253,7 @@ def setUp(self): self.nested_for_loop_func = nested_for_loop_dyfunc + @test_legacy_and_pt_and_pir def test_loop_vars(self): for i in range(len(self.loop_funcs)): func = self.loop_funcs[i] @@ -266,6 +269,7 @@ def test_loop_vars(self): self.assertEqual(loop_var_names, self.loop_var_names[i]) self.assertEqual(create_var_names, self.create_var_names[i]) + @test_legacy_and_pt_and_pir def test_nested_loop_vars(self): func = self.nested_for_loop_func test_func = inspect.getsource(func) @@ -334,6 +338,7 @@ def _run(self, to_static): else: return ret + @test_legacy_and_pt_and_pir def test_ast_to_func(self): static_numpy = self._run_static() dygraph_numpy = self._run_dygraph() @@ -350,6 +355,13 @@ class TestTransformWhileLoopWithConflicVar(TestTransformWhileLoop): def _init_dyfunc(self): self.dyfunc = while_loop_dyfun_with_conflict_var + # This test raises an error about UndefinedVar in pir mode, + # it can be removed after the bug is fixed. + def test_ast_to_func(self): + static_numpy = self._run_static() + dygraph_numpy = self._run_dygraph() + np.testing.assert_allclose(dygraph_numpy, static_numpy, rtol=1e-05) + class TestTransformWhileLoopWithNone(TestTransformWhileLoop): def _init_dyfunc(self): @@ -407,6 +419,7 @@ def _run(self, to_static): ret = self.dyfunc(self.len) return ret.numpy() + @test_legacy_and_pt_and_pir def test_ast_to_func(self): np.testing.assert_allclose( self._run_dygraph(), self._run_static(), rtol=1e-05 @@ -432,6 +445,13 @@ class TestClassVarInForLoop(TestTransformForLoop): def _init_dyfunc(self): self.dyfunc = for_loop_class_var + # This test raises an error about UndefinedVar in pir mode, + # it can be removed after the bug is fixed. + def test_ast_to_func(self): + np.testing.assert_allclose( + self._run_dygraph(), self._run_static(), rtol=1e-05 + ) + class TestVarCreateInForLoop(TestTransformForLoop): def _init_dyfunc(self): @@ -463,6 +483,7 @@ def forward(self, x): class TestForLoopMeetDict(Dy2StTestBase): + @test_legacy_and_pt_and_pir def test_start(self): net = Net() model = paddle.jit.to_static( @@ -474,7 +495,9 @@ def test_start(self): ], ) temp_dir = tempfile.TemporaryDirectory() - paddle.jit.save(model, temp_dir.name) + # TODO(pir-save-load): Fix this after we support save/load in PIR + if not use_pir_api(): + paddle.jit.save(model, temp_dir.name) temp_dir.cleanup()