Skip to content

Commit 649868f

Browse files
authored
[Dy2stat] Fix the bug that loop_body_func may return single element (#31806)
Our old `loop_body` function may return single element when `loop_vars` just contains only 1 element, which can cause bug. The key point of this PR is forcing `loop_body` functions always return tuple.
1 parent e5f7a83 commit 649868f

File tree

3 files changed

+37
-6
lines changed

3 files changed

+37
-6
lines changed

python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,7 @@ def get_for_stmt_nodes(self, node):
594594
# append return values for loop body
595595
body_stmts.append(
596596
gast.Return(value=generate_name_node(
597-
loop_var_names, ctx=gast.Load())))
597+
loop_var_names, ctx=gast.Load(), gen_tuple_if_single=True)))
598598
body_func_node = gast.FunctionDef(
599599
name=unique_name.generate(FOR_BODY_PREFIX),
600600
args=gast.arguments(

python/paddle/fluid/dygraph/dygraph_to_static/utils.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -381,9 +381,15 @@ def get_attribute_full_name(node):
381381
return astor.to_source(gast.gast_to_ast(node)).strip()
382382

383383

384-
def generate_name_node(name_ids, ctx=gast.Load()):
384+
def generate_name_node(name_ids, ctx=gast.Load(), gen_tuple_if_single=False):
385385
"""
386-
Generate list or gast.Tuple of ast.Name for Return statement.
386+
If name_ids is list or tuple or set with multiple strings, this function
387+
generates gast.Tuple of gast.Name.
388+
If the name_ids is single string or contains only 1 string, this function
389+
returns gast.Name if gen_tuple_if_single==False else returns gast.Tuple
390+
with only one gast.Name
391+
392+
This function is used at several gast.Return statements.
387393
"""
388394
if isinstance(name_ids, six.string_types):
389395
name_ids = [name_ids]
@@ -395,7 +401,7 @@ def generate_name_node(name_ids, ctx=gast.Load()):
395401
id=name_id, ctx=ctx, annotation=None, type_comment=None)
396402
for name_id in name_ids
397403
]
398-
if len(gast_names) == 1:
404+
if len(gast_names) == 1 and not gen_tuple_if_single:
399405
name_node = gast_names[0]
400406
else:
401407
name_node = gast.Tuple(elts=gast_names, ctx=ctx)

python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ def for_iter_var_idx(x_array):
233233
return z
234234

235235

236+
# 17. for a,b,c in z: (a, b, c) is a tuple
236237
@paddle.jit.to_static
237238
def for_tuple_as_iter_var(x_array):
238239
x = paddle.to_tensor(x_array)
@@ -250,6 +251,7 @@ def for_tuple_as_iter_var(x_array):
250251
return a_result, b_result, c_result
251252

252253

254+
# 18. for t in enumerate(collection): t is tuple of (idx, element)
253255
@paddle.jit.to_static
254256
def for_tuple_as_enumerate_iter(x_array):
255257
x = paddle.to_tensor(x_array)
@@ -263,6 +265,7 @@ def for_tuple_as_enumerate_iter(x_array):
263265
return a_result
264266

265267

268+
# 19. for i, (a, b, c, d, e) in enumerate(collection): (a, b, c, d, e) is a tuple
266269
@paddle.jit.to_static
267270
def for_tuple_as_enumerate_value(x_array):
268271
x = paddle.to_tensor(x_array)
@@ -284,6 +287,23 @@ def for_tuple_as_enumerate_value(x_array):
284287
return a_result
285288

286289

290+
# 20. test for function in a class
291+
class ForwardContainsForLayer(paddle.nn.Layer):
292+
def __init__(self):
293+
super(ForwardContainsForLayer, self).__init__()
294+
self.high = 5
295+
self.low = 3
296+
297+
@paddle.jit.to_static
298+
def forward(self, x):
299+
# just for test case, x is useless in this method
300+
y = paddle.zeros([10, 2, 3])
301+
z = []
302+
for i in range(self.high - self.low):
303+
z.append(y[i].clone())
304+
return z
305+
306+
287307
class TestTransformBase(unittest.TestCase):
288308
def setUp(self):
289309
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
@@ -313,11 +333,11 @@ def get_static_output(self):
313333
class TestTransform(TestTransformBase):
314334
def transformed_result_compare(self):
315335
dy_outs = self.get_dygraph_output()
316-
if not isinstance(dy_outs, tuple):
336+
if not isinstance(dy_outs, (tuple, list)):
317337
dy_outs = (dy_outs, )
318338

319339
st_outs = self.get_static_output()
320-
if not isinstance(st_outs, tuple):
340+
if not isinstance(st_outs, (tuple, list)):
321341
st_outs = (st_outs, )
322342

323343
for x, y in zip(dy_outs, st_outs):
@@ -446,5 +466,10 @@ def set_test_func(self):
446466
self.dygraph_func = for_tuple_as_enumerate_value
447467

448468

469+
class TestForwardContainsForLayer(TestForIterVarNumpy):
470+
def set_test_func(self):
471+
self.dygraph_func = ForwardContainsForLayer()
472+
473+
449474
if __name__ == '__main__':
450475
unittest.main()

0 commit comments

Comments
 (0)