diff --git a/python/paddle/jit/dy2static/transformers/break_continue_transformer.py b/python/paddle/jit/dy2static/transformers/break_continue_transformer.py index 39ea02f00db6ab..ab85500b5de727 100644 --- a/python/paddle/jit/dy2static/transformers/break_continue_transformer.py +++ b/python/paddle/jit/dy2static/transformers/break_continue_transformer.py @@ -120,9 +120,13 @@ def transform(self): self.visit(self.root) def visit_Break(self, node): + function_def_node_index = _find_ancestor_function_def_index( + self.ancestor_nodes + ) loop_node_index = _find_ancestor_loop_index(node, self.ancestor_nodes) assert loop_node_index != -1, "SyntaxError: 'break' outside loop" loop_node = self.ancestor_nodes[loop_node_index] + function_def_node = self.ancestor_nodes[function_def_node_index] # 1. Map the 'break/continue' stmt with an unique boolean variable V. variable_name = unique_name.generate(BREAK_NAME_PREFIX) @@ -140,7 +144,7 @@ def visit_Break(self, node): # 4. For 'break' add break into condition of the loop. assign_false_node = create_bool_node(variable_name, False) - self._add_stmt_before_cur_node(loop_node_index, assign_false_node) + function_def_node.body.insert(0, assign_false_node) cond_var_node = gast.UnaryOp( op=gast.Not(), @@ -164,9 +168,13 @@ def visit_Break(self, node): for_to_while.transform() def visit_Continue(self, node): + function_def_node_index = _find_ancestor_function_def_index( + self.ancestor_nodes + ) loop_node_index = _find_ancestor_loop_index(node, self.ancestor_nodes) assert loop_node_index != -1, "SyntaxError: 'continue' outside loop" loop_node = self.ancestor_nodes[loop_node_index] + function_def_node = self.ancestor_nodes[function_def_node_index] # 1. Map the 'break/continue' stmt with an unique boolean variable V. variable_name = unique_name.generate(CONTINUE_NAME_PREFIX) @@ -185,6 +193,9 @@ def visit_Continue(self, node): # 4. 
For 'continue', set continue to False at the beginning of each loop assign_false_node = create_bool_node(variable_name, False) loop_node.body.insert(0, assign_false_node) + # Add a same assign statement to the beginning of function body to avoid + # generate the UndefinedVar + function_def_node.body.insert(0, assign_false_node) def _remove_stmts_after_break_continue( self, break_continue_node, break_continue_name, loop_node_index @@ -298,6 +309,13 @@ def _find_ancestor_loop_index(node, ancestor_nodes): return -1 +def _find_ancestor_function_def_index(ancestor_nodes): + for i in range(len(ancestor_nodes) - 1, -1, -1): + if isinstance(ancestor_nodes[i], gast.FunctionDef): + return i + return -1 + + class BreakTransformOptimizer(BaseNodeVisitor): """ In specific pattern, the transformed code could be optimized by joining the diff --git a/python/paddle/jit/dy2static/transformers/return_transformer.py b/python/paddle/jit/dy2static/transformers/return_transformer.py index 6aafe1a9912153..fe8cd2cde15e61 100644 --- a/python/paddle/jit/dy2static/transformers/return_transformer.py +++ b/python/paddle/jit/dy2static/transformers/return_transformer.py @@ -13,6 +13,7 @@ # limitations under the License. 
from paddle.base import unique_name +from paddle.jit.dy2static.variable_trans_func import create_bool_node from paddle.utils import gast from ..utils import ( @@ -232,6 +233,7 @@ def transform(self): return node # Prepend initialization of final return and append final return statement + return_flag_names = self.return_name value_name = self.return_value_name if value_name is not None: node.body.append( @@ -257,6 +259,10 @@ def transform(self): ) node.body.insert(0, assign_return_value_node) + for return_flag_name in return_flag_names: + assign_return_flag_node = create_bool_node(return_flag_name, False) + node.body.insert(0, assign_return_flag_node) + # Prepend no value placeholders return node diff --git a/python/paddle/jit/dy2static/variable_trans_func.py b/python/paddle/jit/dy2static/variable_trans_func.py index 13bec054823158..0bbb388ae80f10 100644 --- a/python/paddle/jit/dy2static/variable_trans_func.py +++ b/python/paddle/jit/dy2static/variable_trans_func.py @@ -14,6 +14,8 @@ import paddle from paddle.base.framework import Variable +from paddle.framework import use_pir_api +from paddle.pir import Value from paddle.utils import gast, is_sequence, map_structure from .utils import UndefinedVar, create_undefined_variable @@ -36,7 +38,7 @@ def to_static_variable(x): return paddle.full(shape=[], dtype='float64', fill_value=x) if isinstance(x, int): return paddle.full(shape=[], dtype='int64', fill_value=x) - if isinstance(x, UndefinedVar) or x is None: + if not use_pir_api() and (isinstance(x, UndefinedVar) or x is None): """ for early return case, we need a variable to represent None, current we use data_layer_not_check. """ @@ -50,8 +52,8 @@ def create_bool_as_type(x, value=True): ''' Create a bool variable, which type is the same as x. 
class OutputSelector:
    """Separate the variable outputs of an if-op's two branches from constants.

    The flattened outputs of the true and false branches are paired up
    position by position. Each pair is normalised through
    ``constant_to_variable_promotion``. Positions whose normalised
    true-branch output is a ``paddle.pir.Value`` are *variable* positions;
    every other position is a *constant* position.

    ``get_variable_outputs`` returns only the variable positions, and
    ``restore_outputs_by_variable_results`` re-inserts the constants around
    a list of per-variable results to rebuild a full-length output list.

    Args:
        if_op: object exposing ``true_block`` and ``false_block`` callables
            that return context managers; promoted constants are created
            inside them.
        flattened_true_output: flat list of outputs of the true branch.
        flattened_false_output: flat list of outputs of the false branch;
            must have the same length as ``flattened_true_output``.
        names: one name per output position, used in error messages.
    """

    def __init__(
        self, if_op, flattened_true_output, flattened_false_output, names
    ):
        self.if_op = if_op
        self.true_output = flattened_true_output
        self.false_output = flattened_false_output
        self.names = names
        self.num_output = len(flattened_true_output)
        assert len(flattened_false_output) == self.num_output
        assert len(names) == self.num_output

    @cached_property
    def unified_output(self):
        """Promote every output pair once and classify each position.

        Returns:
            A 3-tuple ``(true_outputs, false_outputs, variable_indices)``.
            The first two lists have one entry per output position (after
            promotion); ``variable_indices`` lists the positions whose
            promoted true-branch output is a ``paddle.pir.Value``.
        """
        unified_true_output = []
        unified_false_output = []
        variable_indices = []
        for true_out, false_out, name in zip(
            self.true_output, self.false_output, self.names
        ):
            (
                true_out,
                false_out,
            ) = OutputSelector.constant_to_variable_promotion(
                [
                    (true_out, self.if_op.true_block),
                    (false_out, self.if_op.false_block),
                ],
                name,
            )
            if isinstance(true_out, paddle.pir.Value):
                assert isinstance(
                    false_out, paddle.pir.Value
                ), "true_out and false_out should be both paddle.pir.Value"
                # The index is taken before the append below, so it is the
                # position this pair is about to occupy.
                variable_indices.append(len(unified_true_output))
            unified_true_output.append(true_out)
            unified_false_output.append(false_out)
        return unified_true_output, unified_false_output, variable_indices

    @property
    def unified_true_output(self):
        """Promoted true-branch outputs, one per output position."""
        return self.unified_output[0]

    @property
    def unified_false_output(self):
        """Promoted false-branch outputs, one per output position."""
        return self.unified_output[1]

    @property
    def variable_indices(self):
        """Positions whose promoted output is a ``paddle.pir.Value``."""
        return self.unified_output[2]

    @property
    def constant_indices(self):
        """Positions that are not in ``variable_indices``, in ascending order."""
        # Build the lookup set once instead of scanning the list per index.
        variable_positions = set(self.variable_indices)
        return [
            i
            for i in range(len(self.true_output))
            if i not in variable_positions
        ]

    def get_variable_outputs(self):
        """Return ``(true_values, false_values)`` restricted to variable positions."""
        variable_true_output = self.select_by_indices(
            self.unified_true_output,
            self.variable_indices,
        )
        variable_false_output = self.select_by_indices(
            self.unified_false_output,
            self.variable_indices,
        )
        return variable_true_output, variable_false_output

    def restore_outputs_by_variable_results(self, variable_results):
        """Rebuild a full-length output list from per-variable results.

        Args:
            variable_results: one result per entry of ``variable_indices``,
                in the same order.

        Returns:
            A list of length ``num_output`` holding ``variable_results`` at
            the variable positions and the true-branch constants at the
            constant positions.
        """
        # Constants are read from the true branch. A position only stays
        # constant when promotion returned the pair unchanged, i.e. both
        # sides are None or equal builtins of the same type.
        constant_output = self.select_by_indices(
            self.unified_true_output,
            self.constant_indices,
        )
        restored_output = [None for _ in range(self.num_output)]
        self.fill_to_indices(
            restored_output,
            variable_results,
            self.variable_indices,
        )
        self.fill_to_indices(
            restored_output,
            constant_output,
            self.constant_indices,
        )
        return restored_output

    @staticmethod
    def select_by_indices(unified_args, indices):
        """Return ``[unified_args[i] for i in indices]``."""
        return [unified_args[i] for i in indices]

    @staticmethod
    def fill_to_indices(outputs, partial_outputs, partial_indices):
        """Write ``partial_outputs`` into ``outputs`` at ``partial_indices``.

        ``outputs`` is mutated in place and also returned.
        """
        for i, out in zip(partial_indices, partial_outputs):
            outputs[i] = out
        return outputs

    @staticmethod
    def constant_to_variable_promotion(out_with_blocks, name):
        """Make the outputs of one position compatible across branches.

        Args:
            out_with_blocks: sequence of ``(output, block_context_manager)``
                pairs, true branch first and false branch second.
                ``block_context_manager`` is called to obtain the context in
                which a promoted value is created.
            name: name of the output, used in the error message.

        Returns:
            The outputs unchanged when they are all ``paddle.pir.Value``,
            all ``None``, or builtins (bool/int/float) of identical type and
            value. Otherwise a list in which every output has been passed
            through ``to_static_variable`` inside its own block.

        Raises:
            TypeError: if the combination of output types is not one of the
                supported cases above.
        """
        from paddle.jit.dy2static.variable_trans_func import to_static_variable

        promotion_builtin_types = (bool, int, float)
        outs, _ = zip(*out_with_blocks)

        def all_has_same_value(outs):
            if len(outs) <= 1:
                return True
            return all(out == outs[0] for out in outs[1:])

        def all_has_same_type(outs):
            if len(outs) <= 1:
                return True
            return all(type(out) is type(outs[0]) for out in outs[1:])

        def constant_to_variable_with_block(constant, block_context_manager):
            # Create the promoted value inside the branch that produced it.
            with block_context_manager():
                return to_static_variable(constant)

        if all(isinstance(out, paddle.pir.Value) for out in outs):
            return outs

        if all(arg is None for arg in outs):
            return outs

        if all(
            isinstance(out, promotion_builtin_types) for out in outs
        ) and all_has_same_type(outs):
            if all_has_same_value(outs):
                return outs
            else:
                warnings.warn(
                    f"Return results from different branches in cond has same type: {type(outs[0])}, "
                    f"but has different value: true value is '{outs[0]}' and false value is '{outs[1]}', "
                    "so we will promote the constant to variable."
                )
                return [
                    constant_to_variable_with_block(out, block)
                    for out, block in out_with_blocks
                ]

        if any(isinstance(out, paddle.pir.Value) for out in outs) and all(
            isinstance(out, (paddle.pir.Value,) + promotion_builtin_types)
            for out in outs
        ):
            warnings.warn(
                "Return results from different branches in cond are not same type: "
                f"false_var returned by false_fn is '{type(outs[1])}' and true_var of true_fn is "
                f"'{type(outs[0])}'"
            )
            return [
                constant_to_variable_with_block(out, block)
                for out, block in out_with_blocks
            ]

        # outs is ordered (true, false): report outs[1] for false_fn and
        # outs[0] for true_fn, matching the warnings above.
        raise TypeError(
            "Unsupported return type of true_fn and false_fn in cond: false_var "
            f"returned `{name}` by false_fn is `{outs[1]}` and true_var of true_fn is `{outs[0]}`"
        )
_to_sequence_except_dict(true_output), + _to_sequence_except_dict(return_names), + ) + for _ in flatten(seq_out) + ] + output_selector = OutputSelector( + if_op, + flattened_true_output, + flattened_false_output, + names=flattened_return_names, + ) + ( + variable_true_output, + variable_false_output, + ) = output_selector.get_variable_outputs() + with if_op.true_block(): - cf_yield(flatten(true_output)) + cf_yield(variable_true_output) with if_op.false_block(): - cf_yield(flatten(false_output)) + cf_yield(variable_false_output) + if_op.update_output() - return pack_sequence_as(true_output, flatten(if_op.results())) + variable_results = flatten(if_op.results()) + + restored_output = output_selector.restore_outputs_by_variable_results( + variable_results + ) + return pack_sequence_as(true_output, restored_output) mask = paddle.cast(pred, dtype='int32') merge_func = ( diff --git a/test/dygraph_to_static/test_ifelse.py b/test/dygraph_to_static/test_ifelse.py index ce34cc841b5187..71740db8bfeb0a 100644 --- a/test/dygraph_to_static/test_ifelse.py +++ b/test/dygraph_to_static/test_ifelse.py @@ -321,7 +321,7 @@ def _run_dygraph(self, to_static=False): return ret.numpy() # Why add test_legacy_only? : PIR not support if true and false branch output with different dtype - @test_legacy_only + @test_legacy_and_pt_and_pir def test_ast_to_func(self): self.assertTrue((self._run_dygraph() == self._run_static()).all()) @@ -523,8 +523,8 @@ def get_dy2stat_out(self): return out # Why add test_legacy_only? 
: PIR not support if true and false branch output with different rank - @test_legacy_only @test_ast_only + @test_legacy_and_pt_and_pir def test_ast_to_func(self): self.setUp() self.assertIsInstance(self.out[0], paddle.Tensor) @@ -533,13 +533,13 @@ def test_ast_to_func(self): class TestDy2StIfElseRetInt3(TestDy2StIfElseRetInt1): def setUp(self): - self.x = np.random.random([5]).astype('float32') + self.x = np.random.random([5]).astype('int64') self.dyfunc = paddle.jit.to_static(dyfunc_ifelse_ret_int3) self.out = self.get_dy2stat_out() # Why add test_legacy_only? : PIR not support if true and false branch output with different rank - @test_legacy_only @test_ast_only + @test_legacy_and_pt_and_pir def test_ast_to_func(self): self.setUp() self.assertIsInstance(self.out, paddle.Tensor) @@ -551,6 +551,7 @@ def setUp(self): self.dyfunc = paddle.jit.to_static(dyfunc_ifelse_ret_int4) @test_ast_only + @test_legacy_and_pt_and_pir def test_ast_to_func(self): with enable_to_static_guard(True): with self.assertRaises(Dygraph2StaticException): diff --git a/test/dygraph_to_static/test_origin_info.py b/test/dygraph_to_static/test_origin_info.py index 002eee308fa1e6..24871ab6c1d468 100644 --- a/test/dygraph_to_static/test_origin_info.py +++ b/test/dygraph_to_static/test_origin_info.py @@ -70,7 +70,7 @@ def set_test_func(self): self.func = simple_func def set_static_lineno(self): - self.static_abs_lineno_list = [9, 11, 12] + self.static_abs_lineno_list = [9, 12, 13] def set_dygraph_info(self): self.line_num = 3 @@ -158,7 +158,7 @@ def set_test_func(self): self.func = nested_func def set_static_lineno(self): - self.static_abs_lineno_list = [9, 12, 14, 16, 17] + self.static_abs_lineno_list = [9, 13, 16, 18, 19] def set_dygraph_info(self): self.line_num = 5 @@ -187,7 +187,7 @@ def set_test_func(self): self.func = decorated_func def set_static_lineno(self): - self.static_abs_lineno_list = [9, 11] + self.static_abs_lineno_list = [9, 12] def set_dygraph_info(self): self.line_num = 2 @@ 
-209,7 +209,7 @@ def set_test_func(self): self.func = decorated_func2 def set_static_lineno(self): - self.static_abs_lineno_list = [9, 11] + self.static_abs_lineno_list = [9, 12] def set_dygraph_info(self): self.line_num = 2