Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 51 additions & 102 deletions test/dygraph_to_static/test_return.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,44 +24,38 @@
from ifelse_simple_func import dyfunc_with_if_else

import paddle
from paddle import base
from paddle.base import core
from paddle.jit import to_static
from paddle.jit.dy2static.utils import Dygraph2StaticException

SEED = 2020
np.random.seed(SEED)


@to_static(full_graph=True)
def test_return_base(x):
x = base.dygraph.to_variable(x)
x = paddle.to_tensor(x)
return x


@to_static(full_graph=True)
def test_inside_func_base(x):
x = base.dygraph.to_variable(x)
x = paddle.to_tensor(x)

def inner_func(x):
return x

return inner_func(x)


@to_static(full_graph=True)
def test_return_if(x):
x = base.dygraph.to_variable(x)
x = paddle.to_tensor(x)
if x < 0:
x -= 1
return -x
x += 3
return x


@to_static(full_graph=True)
def test_return_if_else(x):
x = base.dygraph.to_variable(x)
x = paddle.to_tensor(x)
if x > 0:
x += 10086
return x
Expand All @@ -72,9 +66,8 @@ def test_return_if_else(x):
x -= 8888 # useless statement to test our code can handle it.


@to_static(full_graph=True)
def test_return_in_while(x):
x = base.dygraph.to_variable(x)
x = paddle.to_tensor(x)
i = paddle.tensor.fill_constant(shape=[1], dtype='int32', value=0)
while i < 10:
i += 1
Expand All @@ -85,9 +78,8 @@ def test_return_in_while(x):
return x


@to_static(full_graph=True)
def test_return_in_for(x):
x = base.dygraph.to_variable(x)
x = paddle.to_tensor(x)
for i in range(10):
if i <= 4:
x += 1
Expand All @@ -97,88 +89,78 @@ def test_return_in_for(x):
return x - 1


@to_static(full_graph=True)
def test_recursive_return(x):
x = base.dygraph.to_variable(x)
x = paddle.to_tensor(x)
return dyfunc_with_if_else(x)


@to_static(full_graph=True)
def test_return_different_length_if_body(x):
x = base.dygraph.to_variable(x)
x = paddle.to_tensor(x)
y = x + 1
if x > 0:
# x = to_variable(np.ones(1)) so it will return here
# x = paddle.to_tensor(np.ones(1)) so it will return here
return x, y
else:
return x


@to_static(full_graph=True)
def test_return_different_length_else(x):
x = base.dygraph.to_variable(x)
x = paddle.to_tensor(x)
y = x + 1
if x < 0:
return x, y
else:
# x = to_variable(np.ones(1)) so it will return here
# x = paddle.to_tensor(np.ones(1)) so it will return here
return x


@to_static(full_graph=True)
def test_no_return(x):
x = base.dygraph.to_variable(x)
x = paddle.to_tensor(x)
y = x + 1


@to_static(full_graph=True)
def test_return_none(x):
x = base.dygraph.to_variable(x)
x = paddle.to_tensor(x)
y = x + 1
if x > 0:
# x = to_variable(np.ones(1)) so it will return here
# x = paddle.to_tensor(np.ones(1)) so it will return here
return None
else:
return x, y


@to_static(full_graph=True)
def test_return_no_variable(x):
x = base.dygraph.to_variable(x)
x = paddle.to_tensor(x)
y = x + 1
if x < 0:
return x, y
else:
# x = to_variable(np.ones(1)) so it will return here
# x = paddle.to_tensor(np.ones(1)) so it will return here
return


@to_static(full_graph=True)
def test_return_list_one_value(x):
x = base.dygraph.to_variable(x)
x = paddle.to_tensor(x)
x += 1
return [x]


@to_static(full_graph=True)
def test_return_list_many_values(x):
x = base.dygraph.to_variable(x)
x = paddle.to_tensor(x)
x += 1
y = x * 2
z = x * x
return [x, y, z]


@to_static(full_graph=True)
def test_return_tuple_one_value(x):
x = base.dygraph.to_variable(x)
x = paddle.to_tensor(x)
x += 1
return (x,)


@to_static(full_graph=True)
def test_return_tuple_many_values(x):
x = base.dygraph.to_variable(x)
x = paddle.to_tensor(x)
x += 1
y = x * 2
z = x * x
Expand All @@ -194,7 +176,6 @@ def inner_func(x):
return y


@to_static(full_graph=True)
def test_return_without_paddle_cond(x):
# y shape is [10]
y = paddle.ones([10])
Expand All @@ -218,7 +199,6 @@ def diff_return_hepler(x):
return two_value(x)


@to_static(full_graph=True)
def test_diff_return(x):
x = paddle.to_tensor(x)
y, z = diff_return_hepler(x)
Expand All @@ -227,7 +207,6 @@ def test_diff_return(x):
return y, z


@to_static(full_graph=True)
def test_return_if_else_2(x):
rr = 0
if True:
Expand All @@ -237,23 +216,20 @@ def test_return_if_else_2(x):
a = 0


@to_static(full_graph=True)
def test_return_in_while_2(x):
while True:
a = 12
return 12
return 10


@to_static(full_graph=True)
def test_return_in_for_2(x):
a = 12
for i in range(10):
return 12
return 10


@to_static(full_graph=True)
def test_return_nested(x):
def func():
rr = 0
Expand All @@ -272,24 +248,17 @@ def func():
class TestReturnBase(Dy2StTestBase):
def setUp(self):
self.input = np.ones(1).astype('int32')
self.place = (
base.CUDAPlace(0)
if base.is_compiled_with_cuda()
else base.CPUPlace()
)
self.init_dygraph_func()

def init_dygraph_func(self):
self.dygraph_func = test_return_base

def _run(self):
with base.dygraph.guard():
res = self.dygraph_func(self.input)
if isinstance(res, (tuple, list)):
return tuple(r.numpy() for r in res)
elif isinstance(res, core.eager.Tensor):
return res.numpy()
return res
res = paddle.jit.to_static(self.dygraph_func)(self.input)
if isinstance(res, (tuple, list)):
return tuple(r.numpy() for r in res)
elif isinstance(res, core.eager.Tensor):
return res.numpy()
return res

def _test_value_impl(self):
with enable_to_static_guard(False):
Expand All @@ -309,6 +278,7 @@ def _test_value_impl(self):

@test_ast_only
def test_transformed_static_result(self):
self.init_dygraph_func()
if hasattr(self, "error"):
with self.assertRaisesRegex(Dygraph2StaticException, self.error):
self._test_value_impl()
Expand All @@ -324,30 +294,22 @@ def init_dygraph_func(self):
class TestReturnIf(Dy2StTestBase):
def setUp(self):
self.input = np.ones(1).astype('int32')
self.place = (
base.CUDAPlace(0)
if base.is_compiled_with_cuda()
else base.CPUPlace()
)
self.init_dygraph_func()

def init_dygraph_func(self):
self.dygraph_func = test_return_if

def _run(self):
with base.dygraph.guard():
res = self.dygraph_func(self.input)
if isinstance(res, (tuple, list)):
return tuple(r.numpy() for r in res)
elif isinstance(res, core.eager.Tensor):
return res.numpy()
return res
res = paddle.jit.to_static(self.dygraph_func)(self.input)
if isinstance(res, (tuple, list)):
return tuple(r.numpy() for r in res)
elif isinstance(res, core.eager.Tensor):
return res.numpy()
return res

def _test_value_impl(self):
with enable_to_static_guard(False):
dygraph_res = self._run()
with enable_to_static_guard(True):
static_res = self._run()
static_res = self._run()
if isinstance(dygraph_res, tuple):
self.assertTrue(isinstance(static_res, tuple))
self.assertEqual(len(dygraph_res), len(static_res))
Expand All @@ -364,6 +326,7 @@ def _test_value_impl(self):
@test_legacy_only
@test_ast_only
def test_transformed_static_result(self):
self.init_dygraph_func()
if hasattr(self, "error"):
with self.assertRaisesRegex(Dygraph2StaticException, self.error):
self._test_value_impl()
Expand All @@ -384,30 +347,22 @@ def init_dygraph_func(self):
class TestReturnInWhile(Dy2StTestBase):
def setUp(self):
self.input = np.ones(1).astype('int32')
self.place = (
base.CUDAPlace(0)
if base.is_compiled_with_cuda()
else base.CPUPlace()
)
self.init_dygraph_func()

def init_dygraph_func(self):
self.dygraph_func = test_return_in_while

def _run(self):
with base.dygraph.guard():
res = self.dygraph_func(self.input)
if isinstance(res, (tuple, list)):
return tuple(r.numpy() for r in res)
elif isinstance(res, core.eager.Tensor):
return res.numpy()
return res
res = paddle.jit.to_static(self.dygraph_func)(self.input)
if isinstance(res, (tuple, list)):
return tuple(r.numpy() for r in res)
elif isinstance(res, core.eager.Tensor):
return res.numpy()
return res

def _test_value_impl(self):
with enable_to_static_guard(False):
dygraph_res = self._run()
with enable_to_static_guard(True):
static_res = self._run()
static_res = self._run()
if isinstance(dygraph_res, tuple):
self.assertTrue(isinstance(static_res, tuple))
self.assertEqual(len(dygraph_res), len(static_res))
Expand All @@ -424,6 +379,7 @@ def _test_value_impl(self):
@test_legacy_only
@test_ast_only
def test_transformed_static_result(self):
self.init_dygraph_func()
if hasattr(self, "error"):
with self.assertRaisesRegex(Dygraph2StaticException, self.error):
self._test_value_impl()
Expand All @@ -439,30 +395,22 @@ def init_dygraph_func(self):
class TestReturnIfElse(Dy2StTestBase):
def setUp(self):
self.input = np.ones(1).astype('int32')
self.place = (
base.CUDAPlace(0)
if base.is_compiled_with_cuda()
else base.CPUPlace()
)
self.init_dygraph_func()

def init_dygraph_func(self):
self.dygraph_func = test_return_if_else

def _run(self):
with base.dygraph.guard():
res = self.dygraph_func(self.input)
if isinstance(res, (tuple, list)):
return tuple(r.numpy() for r in res)
elif isinstance(res, core.eager.Tensor):
return res.numpy()
return res
res = paddle.jit.to_static(self.dygraph_func)(self.input)
if isinstance(res, (tuple, list)):
return tuple(r.numpy() for r in res)
elif isinstance(res, core.eager.Tensor):
return res.numpy()
return res

def _test_value_impl(self):
with enable_to_static_guard(False):
dygraph_res = self._run()
with enable_to_static_guard(True):
static_res = self._run()
static_res = self._run()
if isinstance(dygraph_res, tuple):
self.assertTrue(isinstance(static_res, tuple))
self.assertEqual(len(dygraph_res), len(static_res))
Expand All @@ -479,6 +427,7 @@ def _test_value_impl(self):
@test_legacy_only
@test_ast_only
def test_transformed_static_result(self):
self.init_dygraph_func()
if hasattr(self, "error"):
with self.assertRaisesRegex(Dygraph2StaticException, self.error):
self._test_value_impl()
Expand Down