diff --git a/test/dygraph_to_static/test_return.py b/test/dygraph_to_static/test_return.py
index a464becfec4933..e16aafbd399d94 100644
--- a/test/dygraph_to_static/test_return.py
+++ b/test/dygraph_to_static/test_return.py
@@ -24,24 +24,20 @@
 from ifelse_simple_func import dyfunc_with_if_else

 import paddle
-from paddle import base
 from paddle.base import core
-from paddle.jit import to_static
 from paddle.jit.dy2static.utils import Dygraph2StaticException

 SEED = 2020
 np.random.seed(SEED)


-@to_static(full_graph=True)
 def test_return_base(x):
-    x = base.dygraph.to_variable(x)
+    x = paddle.to_tensor(x)
     return x


-@to_static(full_graph=True)
 def test_inside_func_base(x):
-    x = base.dygraph.to_variable(x)
+    x = paddle.to_tensor(x)

     def inner_func(x):
         return x
@@ -49,9 +45,8 @@ def inner_func(x):
     return inner_func(x)


-@to_static(full_graph=True)
 def test_return_if(x):
-    x = base.dygraph.to_variable(x)
+    x = paddle.to_tensor(x)
     if x < 0:
         x -= 1
         return -x
@@ -59,9 +54,8 @@ def test_return_if(x):
     return x


-@to_static(full_graph=True)
 def test_return_if_else(x):
-    x = base.dygraph.to_variable(x)
+    x = paddle.to_tensor(x)
     if x > 0:
         x += 10086
         return x
@@ -72,9 +66,8 @@ def test_return_if_else(x):
         x -= 8888  # useless statement to test our code can handle it.


-@to_static(full_graph=True)
 def test_return_in_while(x):
-    x = base.dygraph.to_variable(x)
+    x = paddle.to_tensor(x)
     i = paddle.tensor.fill_constant(shape=[1], dtype='int32', value=0)
     while i < 10:
         i += 1
@@ -85,9 +78,8 @@ def test_return_in_while(x):
     return x


-@to_static(full_graph=True)
 def test_return_in_for(x):
-    x = base.dygraph.to_variable(x)
+    x = paddle.to_tensor(x)
     for i in range(10):
         if i <= 4:
             x += 1
@@ -97,88 +89,78 @@ def test_return_in_for(x):
     return x - 1


-@to_static(full_graph=True)
 def test_recursive_return(x):
-    x = base.dygraph.to_variable(x)
+    x = paddle.to_tensor(x)
     return dyfunc_with_if_else(x)


-@to_static(full_graph=True)
 def test_return_different_length_if_body(x):
-    x = base.dygraph.to_variable(x)
+    x = paddle.to_tensor(x)
     y = x + 1
     if x > 0:
-        # x = to_variable(np.ones(1)) so it will return here
+        # x = paddle.to_tensor(np.ones(1)) so it will return here
         return x, y
     else:
         return x


-@to_static(full_graph=True)
 def test_return_different_length_else(x):
-    x = base.dygraph.to_variable(x)
+    x = paddle.to_tensor(x)
     y = x + 1
     if x < 0:
         return x, y
     else:
-        # x = to_variable(np.ones(1)) so it will return here
+        # x = paddle.to_tensor(np.ones(1)) so it will return here
         return x


-@to_static(full_graph=True)
 def test_no_return(x):
-    x = base.dygraph.to_variable(x)
+    x = paddle.to_tensor(x)
     y = x + 1


-@to_static(full_graph=True)
 def test_return_none(x):
-    x = base.dygraph.to_variable(x)
+    x = paddle.to_tensor(x)
     y = x + 1
     if x > 0:
-        # x = to_variable(np.ones(1)) so it will return here
+        # x = paddle.to_tensor(np.ones(1)) so it will return here
         return None
     else:
         return x, y


-@to_static(full_graph=True)
 def test_return_no_variable(x):
-    x = base.dygraph.to_variable(x)
+    x = paddle.to_tensor(x)
     y = x + 1
     if x < 0:
         return x, y
     else:
-        # x = to_variable(np.ones(1)) so it will return here
+        # x = paddle.to_tensor(np.ones(1)) so it will return here
         return


-@to_static(full_graph=True)
 def test_return_list_one_value(x):
-    x = base.dygraph.to_variable(x)
+    x = paddle.to_tensor(x)
     x += 1
     return [x]


-@to_static(full_graph=True)
 def test_return_list_many_values(x):
-    x = base.dygraph.to_variable(x)
+    x = paddle.to_tensor(x)
     x += 1
     y = x * 2
     z = x * x
     return [x, y, z]


-@to_static(full_graph=True)
 def test_return_tuple_one_value(x):
-    x = base.dygraph.to_variable(x)
+    x = paddle.to_tensor(x)
     x += 1
     return (x,)


-@to_static(full_graph=True)
 def test_return_tuple_many_values(x):
-    x = base.dygraph.to_variable(x)
+    x = paddle.to_tensor(x)
     x += 1
     y = x * 2
     z = x * x
@@ -194,7 +176,6 @@ def inner_func(x):
     return y


-@to_static(full_graph=True)
 def test_return_without_paddle_cond(x):
     # y shape is [10]
     y = paddle.ones([10])
@@ -218,7 +199,6 @@ def diff_return_hepler(x):
     return two_value(x)


-@to_static(full_graph=True)
 def test_diff_return(x):
     x = paddle.to_tensor(x)
     y, z = diff_return_hepler(x)
@@ -227,7 +207,6 @@ def test_diff_return(x):
     return y, z


-@to_static(full_graph=True)
 def test_return_if_else_2(x):
     rr = 0
     if True:
@@ -237,7 +216,6 @@ def test_return_if_else_2(x):
         a = 0


-@to_static(full_graph=True)
 def test_return_in_while_2(x):
     while True:
         a = 12
@@ -245,7 +223,6 @@ def test_return_in_while_2(x):
     return 10


-@to_static(full_graph=True)
 def test_return_in_for_2(x):
     a = 12
     for i in range(10):
@@ -253,7 +230,6 @@ def test_return_in_for_2(x):
     return 10


-@to_static(full_graph=True)
 def test_return_nested(x):
     def func():
         rr = 0
@@ -272,24 +248,17 @@ def func():
 class TestReturnBase(Dy2StTestBase):
     def setUp(self):
         self.input = np.ones(1).astype('int32')
-        self.place = (
-            base.CUDAPlace(0)
-            if base.is_compiled_with_cuda()
-            else base.CPUPlace()
-        )
-        self.init_dygraph_func()

     def init_dygraph_func(self):
         self.dygraph_func = test_return_base

     def _run(self):
-        with base.dygraph.guard():
-            res = self.dygraph_func(self.input)
-            if isinstance(res, (tuple, list)):
-                return tuple(r.numpy() for r in res)
-            elif isinstance(res, core.eager.Tensor):
-                return res.numpy()
-            return res
+        res = paddle.jit.to_static(self.dygraph_func)(self.input)
+        if isinstance(res, (tuple, list)):
+            return tuple(r.numpy() for r in res)
+        elif isinstance(res, core.eager.Tensor):
+            return res.numpy()
+        return res

     def _test_value_impl(self):
         with enable_to_static_guard(False):
@@ -309,6 +278,7 @@ def _test_value_impl(self):

     @test_ast_only
     def test_transformed_static_result(self):
+        self.init_dygraph_func()
         if hasattr(self, "error"):
             with self.assertRaisesRegex(Dygraph2StaticException, self.error):
                 self._test_value_impl()
@@ -324,30 +294,22 @@ def init_dygraph_func(self):
 class TestReturnIf(Dy2StTestBase):
     def setUp(self):
         self.input = np.ones(1).astype('int32')
-        self.place = (
-            base.CUDAPlace(0)
-            if base.is_compiled_with_cuda()
-            else base.CPUPlace()
-        )
-        self.init_dygraph_func()

     def init_dygraph_func(self):
         self.dygraph_func = test_return_if

     def _run(self):
-        with base.dygraph.guard():
-            res = self.dygraph_func(self.input)
-            if isinstance(res, (tuple, list)):
-                return tuple(r.numpy() for r in res)
-            elif isinstance(res, core.eager.Tensor):
-                return res.numpy()
-            return res
+        res = paddle.jit.to_static(self.dygraph_func)(self.input)
+        if isinstance(res, (tuple, list)):
+            return tuple(r.numpy() for r in res)
+        elif isinstance(res, core.eager.Tensor):
+            return res.numpy()
+        return res

     def _test_value_impl(self):
         with enable_to_static_guard(False):
             dygraph_res = self._run()
-        with enable_to_static_guard(True):
-            static_res = self._run()
+        static_res = self._run()
         if isinstance(dygraph_res, tuple):
             self.assertTrue(isinstance(static_res, tuple))
             self.assertEqual(len(dygraph_res), len(static_res))
@@ -364,6 +326,7 @@ def _test_value_impl(self):
     @test_legacy_only
     @test_ast_only
     def test_transformed_static_result(self):
+        self.init_dygraph_func()
         if hasattr(self, "error"):
             with self.assertRaisesRegex(Dygraph2StaticException, self.error):
                 self._test_value_impl()
@@ -384,30 +347,22 @@ def init_dygraph_func(self):
 class TestReturnInWhile(Dy2StTestBase):
     def setUp(self):
         self.input = np.ones(1).astype('int32')
-        self.place = (
-            base.CUDAPlace(0)
-            if base.is_compiled_with_cuda()
-            else base.CPUPlace()
-        )
-        self.init_dygraph_func()

     def init_dygraph_func(self):
         self.dygraph_func = test_return_in_while

     def _run(self):
-        with base.dygraph.guard():
-            res = self.dygraph_func(self.input)
-            if isinstance(res, (tuple, list)):
-                return tuple(r.numpy() for r in res)
-            elif isinstance(res, core.eager.Tensor):
-                return res.numpy()
-            return res
+        res = paddle.jit.to_static(self.dygraph_func)(self.input)
+        if isinstance(res, (tuple, list)):
+            return tuple(r.numpy() for r in res)
+        elif isinstance(res, core.eager.Tensor):
+            return res.numpy()
+        return res

     def _test_value_impl(self):
         with enable_to_static_guard(False):
             dygraph_res = self._run()
-        with enable_to_static_guard(True):
-            static_res = self._run()
+        static_res = self._run()
         if isinstance(dygraph_res, tuple):
             self.assertTrue(isinstance(static_res, tuple))
             self.assertEqual(len(dygraph_res), len(static_res))
@@ -424,6 +379,7 @@ def _test_value_impl(self):
     @test_legacy_only
     @test_ast_only
     def test_transformed_static_result(self):
+        self.init_dygraph_func()
         if hasattr(self, "error"):
             with self.assertRaisesRegex(Dygraph2StaticException, self.error):
                 self._test_value_impl()
@@ -439,30 +395,22 @@ def init_dygraph_func(self):
 class TestReturnIfElse(Dy2StTestBase):
     def setUp(self):
         self.input = np.ones(1).astype('int32')
-        self.place = (
-            base.CUDAPlace(0)
-            if base.is_compiled_with_cuda()
-            else base.CPUPlace()
-        )
-        self.init_dygraph_func()

     def init_dygraph_func(self):
         self.dygraph_func = test_return_if_else

     def _run(self):
-        with base.dygraph.guard():
-            res = self.dygraph_func(self.input)
-            if isinstance(res, (tuple, list)):
-                return tuple(r.numpy() for r in res)
-            elif isinstance(res, core.eager.Tensor):
-                return res.numpy()
-            return res
+        res = paddle.jit.to_static(self.dygraph_func)(self.input)
+        if isinstance(res, (tuple, list)):
+            return tuple(r.numpy() for r in res)
+        elif isinstance(res, core.eager.Tensor):
+            return res.numpy()
+        return res

     def _test_value_impl(self):
         with enable_to_static_guard(False):
             dygraph_res = self._run()
-        with enable_to_static_guard(True):
-            static_res = self._run()
+        static_res = self._run()
         if isinstance(dygraph_res, tuple):
             self.assertTrue(isinstance(static_res, tuple))
             self.assertEqual(len(dygraph_res), len(static_res))
@@ -479,6 +427,7 @@ def _test_value_impl(self):
     @test_legacy_only
     @test_ast_only
     def test_transformed_static_result(self):
+        self.init_dygraph_func()
         if hasattr(self, "error"):
             with self.assertRaisesRegex(Dygraph2StaticException, self.error):
                 self._test_value_impl()
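
For context, a minimal standalone sketch (not part of the patch) of the call pattern the updated tests rely on: instead of decorating each test function with @to_static(full_graph=True) and converting inputs via base.dygraph.to_variable, the function is wrapped on demand with paddle.jit.to_static and inputs are created with paddle.to_tensor. The add_one helper and its input below are illustrative only, not code from this repository.

import numpy as np

import paddle


def add_one(x):
    x = paddle.to_tensor(x)  # replaces base.dygraph.to_variable(x)
    return x + 1


# Wrap at the call site instead of decorating the definition,
# mirroring _run(): paddle.jit.to_static(self.dygraph_func)(self.input)
static_add_one = paddle.jit.to_static(add_one)
out = static_add_one(np.ones(1).astype('int32'))
print(out.numpy())  # [2]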