diff --git a/paddle/fluid/framework/new_executor/instruction/control_flow/if_instruction.cc b/paddle/fluid/framework/new_executor/instruction/control_flow/if_instruction.cc
index 624ce6221cd5e7..d7ad210102b94b 100644
--- a/paddle/fluid/framework/new_executor/instruction/control_flow/if_instruction.cc
+++ b/paddle/fluid/framework/new_executor/instruction/control_flow/if_instruction.cc
@@ -215,8 +215,36 @@ void IfInstruction::CopyBranchOutput(const std::vector<std::string>& var_names,
 }
 
 void IfInstruction::Run() {
-  DeviceContext().Wait();
-  if (cond_var_->Get<phi::DenseTensor>().data<bool>()[0]) {
+  bool cond = true;
+  if (cond_var_->IsType<phi::DenseTensor>()) {
+    auto& cond_tensor = cond_var_->Get<phi::DenseTensor>();
+    if (paddle::platform::is_cpu_place(cond_tensor.place())) {
+      cond = cond_tensor.data<bool>()[0];
+    } else {
+      // when platform::is_gpu_place(cond.place()) or
+      // platform::is_xpu_place(cond.place()) is true
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
+    defined(PADDLE_WITH_XPU) || defined(PADDLE_WITH_CUSTOM_DEVICE)
+      DeviceContext().Wait();
+      phi::DenseTensor cpu_cond;
+      paddle::framework::TensorCopySync(
+          cond_tensor, platform::CPUPlace(), &cpu_cond);
+      cond = cpu_cond.data<bool>()[0];
+#else
+      PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
+          "This version of PaddlePaddle does NOT support GPU/XPU but got "
+          "GPU/XPU tensor Cond in WhileOp. Please compile WITH_GPU or "
+          "WITH_XPU option."));
+#endif
+    }
+  } else if (cond_var_->IsType<VariableRefArray>()) {
+    auto& cond_array = cond_var_->Get<VariableRefArray>();
+    cond = std::all_of(
+        cond_array.begin(), cond_array.end(), [](const Variable* t) {
+          return t->Get<phi::DenseTensor>().numel() != 0;
+        });
+  }
+  if (cond) {
     true_branch_inter_->Run({}, false);
     CopyBranchOutput(true_branch_outputs_, true_branch_inter_);
   } else {
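# Illustrative usage sketch (not part of the patch): with the IfInstruction::Run()
# change above, the branch predicate no longer has to live on the CPU. A bool
# DenseTensor produced on GPU/XPU is synchronized and copied to a CPU tensor before
# the branch is selected. A minimal static-graph program exercising this path could
# look like the following (assuming a CUDA build; on a CPU-only build the
# is_cpu_place branch is taken instead).
import paddle

paddle.enable_static()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.full([1], 0.3)
    y = paddle.full([1], 0.5)
    pred = paddle.less_than(x, y)  # bool tensor; lives on the GPU when run on a GPU place
    out = paddle.static.nn.cond(pred, lambda: x + 1.0, lambda: y - 1.0)

place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup)
(res,) = exe.run(main, fetch_list=[out])
print(res)  # [1.3], because 0.3 < 0.5 selects the true branch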
diff --git a/paddle/fluid/pir/dialect/operator/ir/api_builder.h b/paddle/fluid/pir/dialect/operator/ir/api_builder.h
index aa20ef34e17a64..70f048a0acf10f 100644
--- a/paddle/fluid/pir/dialect/operator/ir/api_builder.h
+++ b/paddle/fluid/pir/dialect/operator/ir/api_builder.h
@@ -43,7 +43,7 @@ class ApiBuilder {
   void SetParameter(const std::string& name,
                     std::unique_ptr<pir::Parameter>&& parameter);
 
-  std::shared_ptr<pir::Builder> GetBuilder() { return builder_; }
+  const std::shared_ptr<pir::Builder>& GetBuilder() const { return builder_; }
 
   const pir::InsertionPoint& GetCurrentInsertionPoint() const {
     return builder_->insertion_point();
diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
index 30d5ce5a1b685e..9e3f283c95ec16 100644
--- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
+++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
@@ -217,8 +217,14 @@ void IfOp::VerifyRegion() {
       1u,
       phi::errors::PreconditionNotMet("The size %d of true_region must be 1.",
                                       (*this)->region(0).size()));
-  if ((*this)->region(0).front().size() > 0) {
-    auto &true_last_op = (*this)->region(0).front().back();
+  if ((*this)->num_results() != 0) {
+    auto &true_block = (*this)->region(0).front();
+    PADDLE_ENFORCE_GT(
+        true_block.size(),
+        0u,
+        phi::errors::PreconditionNotMet(
+            "The true block must have at least one op (the yield op)."));
+    auto &true_last_op = true_block.back();
     PADDLE_ENFORCE_EQ(true,
                       true_last_op.isa<pir::YieldOp>(),
                       phi::errors::PreconditionNotMet(
@@ -228,15 +234,19 @@ void IfOp::VerifyRegion() {
                       phi::errors::PreconditionNotMet(
                           "The size of last of true block op's input must be "
                           "equal to IfOp's outputs num."));
-  }
-  VLOG(4) << "Start Verifying false branch.";
-  PADDLE_ENFORCE_EQ(
-      (*this)->region(1).size(),
-      1u,
-      phi::errors::PreconditionNotMet("The size %d of false_region must be 1.",
-                                      (*this)->region(0).size()));
-  if ((*this)->region(1).front().size() > 0) {
-    auto &false_last_op = (*this)->region(1).front().back();
+    VLOG(4) << "Start Verifying false branch.";
+    PADDLE_ENFORCE_EQ((*this)->region(1).size(),
+                      1u,
+                      phi::errors::PreconditionNotMet(
+                          "The size %d of false_region must be 1.",
+                          (*this)->region(0).size()));
+    auto &false_block = (*this)->region(1).front();
+    PADDLE_ENFORCE_GT(
+        false_block.size(),
+        0u,
+        phi::errors::PreconditionNotMet(
+            "The false block must have at least one op (the yield op)."));
+    auto &false_last_op = false_block.back();
     PADDLE_ENFORCE_EQ(true,
                       false_last_op.isa<pir::YieldOp>(),
                       phi::errors::PreconditionNotMet(
diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
index d6eced4fb66a9e..82c02b04e7a394 100644
--- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
+++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
@@ -1124,27 +1124,6 @@ void HandleForIfOp(
           "[%d]'s input of [%s] op MUST in map pair", 0, op_item->name()));
   auto new_cond = map_value_pair->at(old_cond);
 
-  // NOTE(zhangbo): IfOp's input cond should be a cpu type.
-  AllocatedDenseTensorType new_cond_type =
-      new_cond.type().dyn_cast<AllocatedDenseTensorType>();
-  if (new_cond_type) {
-    if (new_cond_type.place().GetType() == phi::AllocationType::GPU) {
-      auto out_type = AllocatedDenseTensorType::get(
-          ctx, phi::CPUPlace(), old_cond.type().dyn_cast<DenseTensorType>());
-      phi::KernelKey kernel_key(
-          phi::Backend::GPU, phi::DataLayout::ALL_LAYOUT, phi::DataType::BOOL);
-      new_cond = AddPlaceTransferOp(new_cond,
-                                    out_type,
-                                    new_cond_type.place(),
-                                    phi::CPUPlace(),
-                                    kernel_key,
-                                    block);
-    }
-  } else {
-    PADDLE_THROW(
-        phi::errors::Unimplemented("IfOp onlu support DenseTensorType"));
-  }
-
   // Create IfOp and insert to kernel dialect program
   pir::Builder builder(ctx, block);
   auto old_ifop = op_item->dyn_cast<IfOp>();
diff --git a/paddle/fluid/pybind/control_flow_api.cc b/paddle/fluid/pybind/control_flow_api.cc
index 42beed478d8219..7533aac122a138 100644
--- a/paddle/fluid/pybind/control_flow_api.cc
+++ b/paddle/fluid/pybind/control_flow_api.cc
@@ -44,6 +44,7 @@ using paddle::pybind::PyIfOp;
 using paddle::pybind::PyWhileOp;
 using pir::Block;
 using pir::Builder;
+using pir::CombineOp;
 using pir::Operation;
 using pir::Program;
 using pir::Region;
@@ -60,6 +61,11 @@ void BindIfOp(py::module* m) {
     return PyIfOp(ApiBuilder::Instance().GetBuilder()->Build<paddle::dialect::IfOp>(
         cond, std::vector<pir::Type>{}));
   });
+  m->def("build_if_op", [](const std::vector<pir::Value>& cond) {
+    auto& builder = ApiBuilder::Instance().GetBuilder();
+    auto new_cond = builder->Build<CombineOp>(cond).out();
+    return PyIfOp(builder->Build<paddle::dialect::IfOp>(new_cond, std::vector<pir::Type>{}));
+  });
   py::class_<PyIfOp> if_op(*m, "IfOp", R"DOC(
     The PyIfOp is a encapsulation of IfOp. Compared with ifOp, it provides an additional 'update_output' interface.
     The 'update_output' interface will construct a new IfOp operation to replace its underlying IfOp. In the process, the original
@@ -67,6 +73,7 @@ void BindIfOp(py::module* m) {
   )DOC");
   if_op.def("true_block", &PyIfOp::true_block, return_value_policy::reference)
       .def("false_block", &PyIfOp::false_block, return_value_policy::reference)
+      .def("cond", &PyIfOp::cond)
       .def("update_output", &PyIfOp::UpdateOutput)
       .def("as_operation", &PyIfOp::operation, return_value_policy::reference)
      .def("results", [](PyIfOp& self) -> py::list {
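# Illustrative sketch (not part of the patch): the new build_if_op overload bound above
# accepts a list of values, packs them with a builtin.combine op, and builds the IfOp on
# the combined value; the added cond() method exposes the predicate of an existing IfOp.
# The import path below assumes the binding is exported through paddle.base.libpaddle.pir,
# which is how the Python layer in this patch consumes it.
import paddle
from paddle.base.libpaddle.pir import build_if_op

paddle.enable_static()
with paddle.pir_utils.IrGuard():
    main = paddle.static.Program()
    with paddle.static.program_guard(main):
        a = paddle.full([1], True, dtype='bool')
        b = paddle.full([1], False, dtype='bool')
        if_op_single = build_if_op(a)      # single bool value: IfOp built directly
        if_op_multi = build_if_op([a, b])  # list of values: combine op, then IfOp
        pred_value = if_op_multi.cond()    # the new cond() accessor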
diff --git a/python/paddle/base/layer_helper.py b/python/paddle/base/layer_helper.py
index 8f4b068d4e8978..add33e0a4edeed 100644
--- a/python/paddle/base/layer_helper.py
+++ b/python/paddle/base/layer_helper.py
@@ -22,6 +22,7 @@
     Parameter,
     dtype_is_floating,
     in_dygraph_mode,
+    in_pir_mode,
 )
 from .layer_helper_base import LayerHelperBase
 from .param_attr import ParamAttr
@@ -132,6 +133,8 @@ def append_bias_op(self, input_var, dim_start=1, dim_end=None):
         b = self.create_parameter(
             attr=bias_attr, shape=size, dtype=input_var.dtype, is_bias=True
         )
+        if in_pir_mode():
+            return input_var + b
         tmp = self.create_variable_for_type_inference(dtype=input_var.dtype)
         self.append_op(
             type='elementwise_add',
diff --git a/python/paddle/base/layer_helper_base.py b/python/paddle/base/layer_helper_base.py
index 003aef14655bbc..197782813ad608 100644
--- a/python/paddle/base/layer_helper_base.py
+++ b/python/paddle/base/layer_helper_base.py
@@ -19,6 +19,7 @@
 import paddle
 
 from . import core, unique_name
+from .data_feeder import convert_dtype
 from .framework import (
     Variable,
     _current_expected_place,
@@ -359,6 +360,8 @@ def create_parameter(
         # set global dtype
         if not dtype:
             dtype = self.__dtype
+        if isinstance(dtype, core.DataType):
+            dtype = convert_dtype(dtype)
         if is_bias:
             suffix = 'b'
             default_initializer = (
diff --git a/python/paddle/static/nn/common.py b/python/paddle/static/nn/common.py
index 65f21a169cb63b..88cd6ab9e0b5a8 100644
--- a/python/paddle/static/nn/common.py
+++ b/python/paddle/static/nn/common.py
@@ -27,6 +27,7 @@
     default_main_program,
     in_dygraph_mode,
     in_dynamic_or_pir_mode,
+    in_pir_mode,
     name_scope,
     program_guard,
     static_only,
@@ -191,10 +192,17 @@ def fc_base(
         name=None,
     ):
         helper = LayerHelper("fc", **locals())
-        check_type(input, 'input', (list, tuple, Variable), 'fc')
+        check_type(
+            input, 'input', (list, tuple, Variable, paddle.pir.Value), 'fc'
+        )
         if isinstance(input, (list, tuple)):
             for i, input_x in enumerate(input):
-                check_type(input_x, 'input[' + str(i) + ']', Variable, 'fc')
+                check_type(
+                    input_x,
+                    'input[' + str(i) + ']',
+                    (Variable, paddle.pir.Value),
+                    'fc',
+                )
         dtype = helper.input_dtype()
         check_dtype(
             dtype, 'input', ['float16', 'uint16', 'float32', 'float64'], 'fc'
@@ -210,17 +218,25 @@ def fc_base(
             w = helper.create_parameter(
                 attr=param_attr, shape=param_shape, dtype=dtype, is_bias=False
             )
-            tmp = helper.create_variable_for_type_inference(dtype)
-            helper.append_op(
-                type="mul",
-                inputs={"X": input_var, "Y": w},
-                outputs={"Out": tmp},
-                attrs={"x_num_col_dims": num_flatten_dims, "y_num_col_dims": 1},
-            )
+            if in_pir_mode():
+                tmp = paddle.matmul(input_var, w)
+            else:
+                tmp = helper.create_variable_for_type_inference(dtype)
+                helper.append_op(
+                    type="mul",
+                    inputs={"X": input_var, "Y": w},
+                    outputs={"Out": tmp},
+                    attrs={
+                        "x_num_col_dims": num_flatten_dims,
+                        "y_num_col_dims": 1,
+                    },
+                )
             mul_results.append(tmp)
 
         if len(mul_results) == 1:
             pre_bias = mul_results[0]
+        elif in_pir_mode():
+            pre_bias = paddle.add_n(mul_results)
         else:
             pre_bias = helper.create_variable_for_type_inference(dtype)
             helper.append_op(
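# Illustrative sketch (not part of the patch): with the changes above, paddle.static.nn.fc
# accepts pir.Value inputs and, under PIR, lowers to paddle.matmul plus a plain "+ bias"
# (via LayerHelper.append_bias_op) instead of appending the legacy mul/elementwise_add ops.
# The IrGuard context below is used only to force PIR mode for the illustration.
import paddle

paddle.enable_static()
with paddle.pir_utils.IrGuard():
    main = paddle.static.Program()
    with paddle.static.program_guard(main):
        x = paddle.static.data(name='x', shape=[8, 16], dtype='float32')
        y = paddle.static.nn.fc(x, size=4)  # builds matmul + bias add in the PIR program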
diff --git a/python/paddle/static/nn/control_flow.py b/python/paddle/static/nn/control_flow.py
index 1302a808cecc2b..fefb16a8379c4c 100644
--- a/python/paddle/static/nn/control_flow.py
+++ b/python/paddle/static/nn/control_flow.py
@@ -175,6 +175,51 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         return super().__exit__(exc_type, exc_val, exc_tb)
 
 
+class If:
+    '''
+    **If**
+
+    If is an operator that binds two blocks (true_block and false_block) to a specific condition.
+    According to the condition, the corresponding block will be executed.
+
+    Args:
+        cond (Value): A value whose data type is bool, controlling which block is executed.
+
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+            >>> from paddle.static.nn.control_flow import If
+
+            >>> label = paddle.rand([1])
+            >>> limit = paddle.ones([1]) * 0.5
+            >>> cond = paddle.less_than(x=label, y=limit)
+            >>> if_op = If(cond)
+            >>> with if_op.true_block():
+            ...     pass
+            >>> with if_op.false_block():
+            ...     pass
+    '''
+
+    def __init__(self, cond):
+        if not isinstance(cond, list):
+            check_variable_and_dtype(cond, 'cond', ['bool'], 'static.nn.If')
+            if reduce(lambda a, b: a * b, cond.shape, 1) != 1:
+                raise TypeError(
+                    "condition expected shape as [1], but given shape as {}.".format(
+                        list(cond.shape)
+                    )
+                )
+        self.if_op = build_if_op(cond)
+        self.cond_var = self.if_op.cond()
+
+    def true_block(self):
+        return self.if_op.true_block()
+
+    def false_block(self):
+        return self.if_op.false_block()
+
+
 class ConditionalBlock:
     '''
     **ConditionalBlock**
@@ -208,13 +253,23 @@ class ConditionalBlock:
     '''
 
     def __init__(self, inputs, is_scalar_condition=False, name=None):
-        for each_input in inputs:
-            check_type(each_input, "input", Variable, "ConditionalBlock")
         self.inputs = inputs
+        if in_pir_mode():
+            if is_scalar_condition and len(inputs) != 1:
+                raise TypeError(
+                    "For the ConditionalBlock API, only one input is supported while is_scalar_condition is True"
+                )
+            return
+        else:
+            for each_input in inputs:
+                check_type(each_input, "input", Variable, "ConditionalBlock")
+
         self.is_scalar_condition = is_scalar_condition
         self.helper = LayerHelper('conditional_block', name=name)
 
     def block(self):
+        if in_pir_mode():
+            return If(self.inputs).true_block()
         return ConditionalBlockGuard(self)
 
     def complete(self):
@@ -1244,9 +1299,9 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None):
             return None
     true_output = None
     false_output = None
+    check_variable_and_dtype(pred, "pred", ['bool'], "base.layers.cond")
+    check_type(name, "name", (str, type(None)), "base.layers.cond")
     if in_pir_mode():
-        check_variable_and_dtype(pred, "pred", ['bool'], "base.layers.cond")
-        check_type(name, "name", (str, type(None)), "base.layers.cond")
         if_op = build_if_op(pred)
         if true_fn is not None:
             if not callable(true_fn):
@@ -1267,8 +1322,6 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None):
         with if_op.false_block():
             false_output = false_fn()
     else:
-        check_variable_and_dtype(pred, "pred", ['bool'], "base.layers.cond")
-        check_type(name, "name", (str, type(None)), "base.layers.cond")
         helper = LayerHelper('cond', **locals())
         copy_to_parent_func = lambda var: copy_var_to_parent_block(var, helper)
         if true_fn is not None:
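# Illustrative sketch (not part of the patch): the new If helper behaves like its
# docstring example above. It wraps build_if_op and exposes the two branch blocks
# as context managers, which is what ConditionalBlock.block() now returns in PIR mode.
import paddle
from paddle.static.nn.control_flow import If

paddle.enable_static()
with paddle.pir_utils.IrGuard():
    main = paddle.static.Program()
    with paddle.static.program_guard(main):
        label = paddle.rand([1])
        limit = paddle.ones([1]) * 0.5
        pred = paddle.less_than(x=label, y=limit)
        if_op = If(pred)
        with if_op.true_block():
            paddle.assign(paddle.full([1], 1.0), label)
        with if_op.false_block():
            paddle.assign(paddle.full([1], 0.0), label)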
diff --git a/test/dygraph_to_static/test_loop.py b/test/dygraph_to_static/test_loop.py
index 8414c488aba23c..fb2600b8ac2dc0 100644
--- a/test/dygraph_to_static/test_loop.py
+++ b/test/dygraph_to_static/test_loop.py
@@ -19,7 +19,6 @@
 import numpy as np
 from dygraph_to_static_utils import (
     Dy2StTestBase,
-    test_legacy_and_pt_and_pir,
 )
 
 import paddle
@@ -252,7 +251,6 @@ def setUp(self):
 
         self.nested_for_loop_func = nested_for_loop_dyfunc
 
-    @test_legacy_and_pt_and_pir
     def test_loop_vars(self):
         for i in range(len(self.loop_funcs)):
             func = self.loop_funcs[i]
@@ -268,7 +266,6 @@ def test_loop_vars(self):
             self.assertEqual(loop_var_names, self.loop_var_names[i])
             self.assertEqual(create_var_names, self.create_var_names[i])
 
-    @test_legacy_and_pt_and_pir
     def test_nested_loop_vars(self):
         func = self.nested_for_loop_func
         test_func = inspect.getsource(func)
diff --git a/test/legacy_test/test_conditional_block.py b/test/legacy_test/test_conditional_block.py
index 90a8200375c65a..b5f5df9205ae8a 100644
--- a/test/legacy_test/test_conditional_block.py
+++ b/test/legacy_test/test_conditional_block.py
@@ -31,7 +31,10 @@ def test_forward(self):
             data = paddle.static.data(name='X', shape=[-1, 1], dtype='float32')
             data.stop_gradient = False
             cond = ConditionalBlock(inputs=[data])
-            out = paddle.tensor.create_tensor(dtype='float32')
+            out = paddle.tensor.fill_constant(
+                [10, 10], dtype='float32', value=0.0
+            )
+            out.stop_gradient = False
             with cond.block():
                 hidden = paddle.static.nn.fc(x=data, size=10)
                 paddle.assign(hidden, out)
@@ -43,7 +46,6 @@ def test_forward(self):
 
         x = np.random.random(size=(10, 1)).astype('float32')
         outs = exe.run(main_program, feed={'X': x}, fetch_list=[out])[0]
-        print(outs)
         loss = paddle.mean(out)
         append_backward(loss=loss)
         outs = exe.run(
@@ -51,7 +53,6 @@ def test_forward(self):
             feed={'X': x},
             fetch_list=[main_program.block(0).var(data.name + "@GRAD")],
         )[0]
-        print(outs)
 
 
 class TestConditionalBlockOpInferShape(unittest.TestCase):