Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ def get_for_stmt_nodes(self, node):
# append return values for loop body
body_stmts.append(
gast.Return(value=generate_name_node(
loop_var_names, ctx=gast.Load())))
loop_var_names, ctx=gast.Load(), gen_tuple_if_single=True)))
body_func_node = gast.FunctionDef(
name=unique_name.generate(FOR_BODY_PREFIX),
args=gast.arguments(
Expand Down
12 changes: 9 additions & 3 deletions python/paddle/fluid/dygraph/dygraph_to_static/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,9 +381,15 @@ def get_attribute_full_name(node):
return astor.to_source(gast.gast_to_ast(node)).strip()


def generate_name_node(name_ids, ctx=gast.Load()):
def generate_name_node(name_ids, ctx=gast.Load(), gen_tuple_if_single=False):
"""
Generate list or gast.Tuple of ast.Name for Return statement.
If name_ids is list or tuple or set with multiple strings, this function
generates gast.Tuple of gast.Name.
If the name_ids is single string or contains only 1 string, this function
returns gast.Name if gen_tuple_if_single==False else returns gast.Tuple
with only one gast.Name

This function is used at several gast.Return statements.
"""
if isinstance(name_ids, six.string_types):
name_ids = [name_ids]
Expand All @@ -395,7 +401,7 @@ def generate_name_node(name_ids, ctx=gast.Load()):
id=name_id, ctx=ctx, annotation=None, type_comment=None)
for name_id in name_ids
]
if len(gast_names) == 1:
if len(gast_names) == 1 and not gen_tuple_if_single:
name_node = gast_names[0]
else:
name_node = gast.Tuple(elts=gast_names, ctx=ctx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def for_iter_var_idx(x_array):
return z


# 17. for a,b,c in z: (a, b, c) is a tuple
@paddle.jit.to_static
def for_tuple_as_iter_var(x_array):
x = paddle.to_tensor(x_array)
Expand All @@ -250,6 +251,7 @@ def for_tuple_as_iter_var(x_array):
return a_result, b_result, c_result


# 18. for t in enumerate(collection): t is tuple of (idx, element)
@paddle.jit.to_static
def for_tuple_as_enumerate_iter(x_array):
x = paddle.to_tensor(x_array)
Expand All @@ -263,6 +265,7 @@ def for_tuple_as_enumerate_iter(x_array):
return a_result


# 19. for i, (a, b, c, d, e) in enumerate(collection): (a, b, c, d, e) is a tuple
@paddle.jit.to_static
def for_tuple_as_enumerate_value(x_array):
x = paddle.to_tensor(x_array)
Expand All @@ -284,6 +287,23 @@ def for_tuple_as_enumerate_value(x_array):
return a_result


# 20. test for function in a class
class ForwardContainsForLayer(paddle.nn.Layer):
    """Layer whose ``forward`` method contains a plain Python ``for`` loop.

    Exercises the dygraph-to-static transform on a ``for`` statement that
    lives inside a class method (decorated with ``@paddle.jit.to_static``)
    rather than inside a free function.
    """

    def __init__(self):
        super(ForwardContainsForLayer, self).__init__()
        # Loop bounds read by forward(); range(high - low) gives 2 iterations.
        self.high = 5
        self.low = 3

    @paddle.jit.to_static
    def forward(self, x):
        # just for test case, x is useless in this method
        y = paddle.zeros([10, 2, 3])
        z = []
        # NOTE: this loop over a range derived from instance attributes is
        # the construct under test; do not rewrite it as a comprehension.
        for i in range(self.high - self.low):
            z.append(y[i].clone())
        # Returns a Python list of tensors (not a tuple); the test harness
        # must therefore accept list outputs as well.
        return z


class TestTransformBase(unittest.TestCase):
def setUp(self):
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
Expand Down Expand Up @@ -313,11 +333,11 @@ def get_static_output(self):
class TestTransform(TestTransformBase):
def transformed_result_compare(self):
dy_outs = self.get_dygraph_output()
if not isinstance(dy_outs, tuple):
if not isinstance(dy_outs, (tuple, list)):
dy_outs = (dy_outs, )

st_outs = self.get_static_output()
if not isinstance(st_outs, tuple):
if not isinstance(st_outs, (tuple, list)):
st_outs = (st_outs, )

for x, y in zip(dy_outs, st_outs):
Expand Down Expand Up @@ -446,5 +466,10 @@ def set_test_func(self):
self.dygraph_func = for_tuple_as_enumerate_value


class TestForwardContainsForLayer(TestForIterVarNumpy):
    # Reuses TestForIterVarNumpy's input data and dygraph/static comparison,
    # but targets a for-loop inside a Layer's forward() method.
    def set_test_func(self):
        # A Layer *instance* is assigned, so calling self.dygraph_func(...)
        # dispatches to ForwardContainsForLayer.forward.
        self.dygraph_func = ForwardContainsForLayer()


if __name__ == '__main__':
unittest.main()