Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,13 @@ def transform(self):
self.visit(self.root)

def visit_Break(self, node):
function_def_node_index = _find_ancestor_function_def_index(
self.ancestor_nodes
)
loop_node_index = _find_ancestor_loop_index(node, self.ancestor_nodes)
assert loop_node_index != -1, "SyntaxError: 'break' outside loop"
loop_node = self.ancestor_nodes[loop_node_index]
function_def_node = self.ancestor_nodes[function_def_node_index]

# 1. Map the 'break/continue' stmt with an unique boolean variable V.
variable_name = unique_name.generate(BREAK_NAME_PREFIX)
Expand All @@ -140,7 +144,7 @@ def visit_Break(self, node):

# 4. For 'break' add break into condition of the loop.
assign_false_node = create_bool_node(variable_name, False)
self._add_stmt_before_cur_node(loop_node_index, assign_false_node)
function_def_node.body.insert(0, assign_false_node)

cond_var_node = gast.UnaryOp(
op=gast.Not(),
Expand All @@ -164,9 +168,13 @@ def visit_Break(self, node):
for_to_while.transform()

def visit_Continue(self, node):
function_def_node_index = _find_ancestor_function_def_index(
self.ancestor_nodes
)
loop_node_index = _find_ancestor_loop_index(node, self.ancestor_nodes)
assert loop_node_index != -1, "SyntaxError: 'continue' outside loop"
loop_node = self.ancestor_nodes[loop_node_index]
function_def_node = self.ancestor_nodes[function_def_node_index]

# 1. Map the 'break/continue' stmt with an unique boolean variable V.
variable_name = unique_name.generate(CONTINUE_NAME_PREFIX)
Expand All @@ -185,6 +193,9 @@ def visit_Continue(self, node):
# 4. For 'continue', set continue to False at the beginning of each loop
assign_false_node = create_bool_node(variable_name, False)
loop_node.body.insert(0, assign_false_node)
# Add a same assign statement to the beginning of function body to avoid
# generate the UndefinedVar
function_def_node.body.insert(0, assign_false_node)

def _remove_stmts_after_break_continue(
self, break_continue_node, break_continue_name, loop_node_index
Expand Down Expand Up @@ -298,6 +309,13 @@ def _find_ancestor_loop_index(node, ancestor_nodes):
return -1


def _find_ancestor_function_def_index(ancestor_nodes):
for i in range(len(ancestor_nodes) - 1, -1, -1):
if isinstance(ancestor_nodes[i], gast.FunctionDef):
return i
return -1


class BreakTransformOptimizer(BaseNodeVisitor):
"""
In specific pattern, the transformed code could be optimized by joining the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from paddle.base import unique_name
from paddle.jit.dy2static.variable_trans_func import create_bool_node
from paddle.utils import gast

from ..utils import (
Expand Down Expand Up @@ -232,6 +233,7 @@ def transform(self):
return node

# Prepend initialization of final return and append final return statement
return_flag_names = self.return_name
value_name = self.return_value_name
if value_name is not None:
node.body.append(
Expand All @@ -257,6 +259,10 @@ def transform(self):
)
node.body.insert(0, assign_return_value_node)

for return_flag_name in return_flag_names:
assign_return_flag_node = create_bool_node(return_flag_name, False)
node.body.insert(0, assign_return_flag_node)

# Prepend no value placeholders
return node

Expand Down
8 changes: 5 additions & 3 deletions python/paddle/jit/dy2static/variable_trans_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import paddle
from paddle.base.framework import Variable
from paddle.framework import use_pir_api
from paddle.pir import Value
from paddle.utils import gast, is_sequence, map_structure

from .utils import UndefinedVar, create_undefined_variable
Expand All @@ -36,7 +38,7 @@ def to_static_variable(x):
return paddle.full(shape=[], dtype='float64', fill_value=x)
if isinstance(x, int):
return paddle.full(shape=[], dtype='int64', fill_value=x)
if isinstance(x, UndefinedVar) or x is None:
if not use_pir_api() and (isinstance(x, UndefinedVar) or x is None):
"""
for early return case, we need a variable to represent None, current we use data_layer_not_check.
"""
Expand All @@ -50,8 +52,8 @@ def create_bool_as_type(x, value=True):
'''
Create a bool variable, which type is the same as x.
'''
if isinstance(x, Variable):
return paddle.full(shape=[1], fill_value=value, dtype="bool")
if isinstance(x, (Variable, Value)):
return paddle.full(shape=[], fill_value=value, dtype="bool")
else:
return value

Expand Down
200 changes: 195 additions & 5 deletions python/paddle/static/nn/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import warnings
from functools import partial, reduce
from functools import cached_property, partial, reduce

import paddle
from paddle import _C_ops
Expand All @@ -39,6 +39,7 @@
convert_dtype,
in_dygraph_mode,
)
from paddle.framework import use_pir_api
from paddle.utils import (
assert_same_structure,
copy_mutable_vars,
Expand Down Expand Up @@ -1169,6 +1170,167 @@ def _check_args(branch_index, branch_fns, default):
return final_fn()


class OutputSelector:
def __init__(
self, if_op, flattened_true_output, flattened_false_output, names
):
self.if_op = if_op
self.true_output = flattened_true_output
self.false_output = flattened_false_output
self.names = names
self.num_output = len(flattened_true_output)
assert len(flattened_false_output) == self.num_output
assert len(names) == self.num_output

@cached_property
def unified_output(self):
unified_true_output = []
unified_false_output = []
variable_indices = []
for true_out, false_out, name in zip(
self.true_output, self.false_output, self.names
):
(
true_out,
false_out,
) = OutputSelector.constant_to_variable_promotion(
[
(true_out, self.if_op.true_block),
(false_out, self.if_op.false_block),
],
name,
)
if isinstance(true_out, paddle.pir.Value):
assert isinstance(
false_out, paddle.pir.Value
), "true_out and false_out should be both paddle.pir.Value"
variable_indices.append(len(unified_true_output))
unified_true_output.append(true_out)
unified_false_output.append(false_out)
return unified_true_output, unified_false_output, variable_indices

@property
def unified_true_output(self):
return self.unified_output[0]

@property
def unified_false_output(self):
return self.unified_output[1]

@property
def variable_indices(self):
return self.unified_output[2]

@property
def constant_indices(self):
return [
i
for i in range(len(self.true_output))
if i not in self.variable_indices
]

def get_variable_outputs(self):
variable_true_output = self.select_by_indices(
self.unified_true_output,
self.variable_indices,
)
variable_false_output = self.select_by_indices(
self.unified_false_output,
self.variable_indices,
)
return variable_true_output, variable_false_output

def restore_outputs_by_variable_results(self, variable_results):
constant_output = self.select_by_indices(
self.unified_true_output,
self.constant_indices,
)
restored_output = [None for _ in range(self.num_output)]
self.fill_to_indices(
restored_output,
variable_results,
self.variable_indices,
)
self.fill_to_indices(
restored_output,
constant_output,
self.constant_indices,
)
return restored_output

@staticmethod
def select_by_indices(unified_args, indices):
return [unified_args[i] for i in indices]

@staticmethod
def fill_to_indices(outputs, partial_outputs, partial_indices):
for i, out in zip(partial_indices, partial_outputs):
outputs[i] = out
return outputs

@staticmethod
def constant_to_variable_promotion(out_with_blocks, name):
from paddle.jit.dy2static.variable_trans_func import to_static_variable

promotion_builtin_types = (bool, int, float)
outs, _ = zip(*out_with_blocks)

def all_has_same_value(outs):
if len(outs) <= 1:
return True
return all(out == outs[0] for out in outs[1:])

def all_has_same_type(outs):
if len(outs) <= 1:
return True
return all(type(out) is type(outs[0]) for out in outs[1:])

def constant_to_variable_with_block(constant, block_context_manager):
with block_context_manager():
return to_static_variable(constant)

if all(isinstance(out, paddle.pir.Value) for out in outs):
return outs

if all(arg is None for arg in outs):
return outs

if all(
isinstance(out, promotion_builtin_types) for out in outs
) and all_has_same_type(outs):
if all_has_same_value(outs):
return outs
else:
warnings.warn(
f"Return results from different branches in cond has same type: {type(outs[0])}, "
f"but has different value: true value is '{outs[0]}' and false value is '{outs[1]}', "
"so we will promote the constant to variable."
)
return [
constant_to_variable_with_block(out, block)
for out, block in out_with_blocks
]

if any(isinstance(out, paddle.pir.Value) for out in outs) and all(
isinstance(out, (paddle.pir.Value,) + promotion_builtin_types)
for out in outs
):
warnings.warn(
"Return results from different branches in cond are not same type: "
f"false_var returned by false_fn is '{type(outs[1])}' and true_var of true_fn is "
f"'{type(outs[0])}'"
)
return [
constant_to_variable_with_block(out, block)
for out, block in out_with_blocks
]

raise TypeError(
"Unsupported return type of true_fn and false_fn in cond: false_var "
f"returned `{name}` by false_fn is `{outs[0]}` and true_var of true_fn is `{outs[1]}`"
)


def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None):
"""
This API returns ``true_fn()`` if the predicate ``pred`` is true else
Expand Down Expand Up @@ -1440,18 +1602,46 @@ def check_ret_none(seq_true, seq_false, seq_names):
_to_sequence_except_dict(return_names),
)

if is_dy2static:
if is_dy2static and not use_pir_api():
true_output, false_output = change_none_to_undefinedvar(
true_output, false_output
)

if in_pir_mode():
flattened_true_output, flattened_false_output = flatten(
true_output
), flatten(false_output)
flattened_return_names = [
name
for seq_out, name in zip(
_to_sequence_except_dict(true_output),
_to_sequence_except_dict(return_names),
)
for _ in flatten(seq_out)
]
output_selector = OutputSelector(
if_op,
flattened_true_output,
flattened_false_output,
names=flattened_return_names,
)
(
variable_true_output,
variable_false_output,
) = output_selector.get_variable_outputs()

with if_op.true_block():
cf_yield(flatten(true_output))
cf_yield(variable_true_output)
with if_op.false_block():
cf_yield(flatten(false_output))
cf_yield(variable_false_output)

if_op.update_output()
return pack_sequence_as(true_output, flatten(if_op.results()))
variable_results = flatten(if_op.results())

restored_output = output_selector.restore_outputs_by_variable_results(
variable_results
)
return pack_sequence_as(true_output, restored_output)

mask = paddle.cast(pred, dtype='int32')
merge_func = (
Expand Down
9 changes: 5 additions & 4 deletions test/dygraph_to_static/test_ifelse.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def _run_dygraph(self, to_static=False):
return ret.numpy()

# Why add test_legacy_only? : PIR not support if true and false branch output with different dtype
@test_legacy_only
@test_legacy_and_pt_and_pir
def test_ast_to_func(self):
self.assertTrue((self._run_dygraph() == self._run_static()).all())

Expand Down Expand Up @@ -523,8 +523,8 @@ def get_dy2stat_out(self):
return out

# Why add test_legacy_only? : PIR not support if true and false branch output with different rank
@test_legacy_only
@test_ast_only
@test_legacy_and_pt_and_pir
def test_ast_to_func(self):
self.setUp()
self.assertIsInstance(self.out[0], paddle.Tensor)
Expand All @@ -533,13 +533,13 @@ def test_ast_to_func(self):

class TestDy2StIfElseRetInt3(TestDy2StIfElseRetInt1):
def setUp(self):
self.x = np.random.random([5]).astype('float32')
self.x = np.random.random([5]).astype('int64')
self.dyfunc = paddle.jit.to_static(dyfunc_ifelse_ret_int3)
self.out = self.get_dy2stat_out()

# Why add test_legacy_only? : PIR not support if true and false branch output with different rank
@test_legacy_only
@test_ast_only
@test_legacy_and_pt_and_pir
def test_ast_to_func(self):
self.setUp()
self.assertIsInstance(self.out, paddle.Tensor)
Expand All @@ -551,6 +551,7 @@ def setUp(self):
self.dyfunc = paddle.jit.to_static(dyfunc_ifelse_ret_int4)

@test_ast_only
@test_legacy_and_pt_and_pir
def test_ast_to_func(self):
with enable_to_static_guard(True):
with self.assertRaises(Dygraph2StaticException):
Expand Down
Loading