Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion test/dygraph_to_static/test_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@
import numpy as np
from dygraph_to_static_utils import (
Dy2StTestBase,
test_legacy_and_pt_and_pir,
)

import paddle
import paddle.nn.functional as F
from paddle import base
from paddle.base.framework import use_pir_api
from paddle.jit.dy2static.transformers.loop_transformer import NameVisitor
from paddle.utils import gast

Expand Down Expand Up @@ -251,6 +253,7 @@ def setUp(self):

self.nested_for_loop_func = nested_for_loop_dyfunc

@test_legacy_and_pt_and_pir
def test_loop_vars(self):
for i in range(len(self.loop_funcs)):
func = self.loop_funcs[i]
Expand All @@ -266,6 +269,7 @@ def test_loop_vars(self):
self.assertEqual(loop_var_names, self.loop_var_names[i])
self.assertEqual(create_var_names, self.create_var_names[i])

@test_legacy_and_pt_and_pir
def test_nested_loop_vars(self):
func = self.nested_for_loop_func
test_func = inspect.getsource(func)
Expand Down Expand Up @@ -334,6 +338,7 @@ def _run(self, to_static):
else:
return ret

@test_legacy_and_pt_and_pir
def test_ast_to_func(self):
static_numpy = self._run_static()
dygraph_numpy = self._run_dygraph()
Expand All @@ -350,6 +355,13 @@ class TestTransformWhileLoopWithConflicVar(TestTransformWhileLoop):
def _init_dyfunc(self):
self.dyfunc = while_loop_dyfun_with_conflict_var

# This test raises an error about UndefinedVar in pir mode,
# it can be removed after the bug is fixed.
def test_ast_to_func(self):
static_numpy = self._run_static()
dygraph_numpy = self._run_dygraph()
np.testing.assert_allclose(dygraph_numpy, static_numpy, rtol=1e-05)


class TestTransformWhileLoopWithNone(TestTransformWhileLoop):
def _init_dyfunc(self):
Expand Down Expand Up @@ -407,6 +419,7 @@ def _run(self, to_static):
ret = self.dyfunc(self.len)
return ret.numpy()

@test_legacy_and_pt_and_pir
def test_ast_to_func(self):
np.testing.assert_allclose(
self._run_dygraph(), self._run_static(), rtol=1e-05
Expand All @@ -432,6 +445,13 @@ class TestClassVarInForLoop(TestTransformForLoop):
def _init_dyfunc(self):
self.dyfunc = for_loop_class_var

# This test raises an error about UndefinedVar in pir mode,
# it can be removed after the bug is fixed.
def test_ast_to_func(self):
np.testing.assert_allclose(
self._run_dygraph(), self._run_static(), rtol=1e-05
)


class TestVarCreateInForLoop(TestTransformForLoop):
def _init_dyfunc(self):
Expand Down Expand Up @@ -463,6 +483,7 @@ def forward(self, x):


class TestForLoopMeetDict(Dy2StTestBase):
@test_legacy_and_pt_and_pir
def test_start(self):
net = Net()
model = paddle.jit.to_static(
Expand All @@ -474,7 +495,9 @@ def test_start(self):
],
)
temp_dir = tempfile.TemporaryDirectory()
paddle.jit.save(model, temp_dir.name)
# TODO(pir-save-load): Fix this after we support save/load in PIR
if not use_pir_api():
paddle.jit.save(model, temp_dir.name)
temp_dir.cleanup()


Expand Down