From 6d1d3b342dff1169832c0ed4120b37d1c923d63d Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Tue, 5 Dec 2023 13:13:37 +0000 Subject: [PATCH 01/19] refine --- .../translator/program_translator.cc | 162 +++++------------- .../translator/program_translator.h | 23 +-- 2 files changed, 56 insertions(+), 129 deletions(-) diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.cc b/paddle/fluid/ir_adaptor/translator/program_translator.cc index ca3d33f29c0510..ebd2fadf12c42b 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/program_translator.cc @@ -389,15 +389,11 @@ void ProgramTranslator::Translate() { } } -void ProgramTranslator::TranslateBlock( - const BlockDesc& src_block, - uint64_t start_id, - uint64_t end_id, - TranslationContext* translation_ctx, - pir::Block* dst_block, - bool for_cond_block, - const std::vector& cond_sub_block_outputs, - const std::vector<::paddle::framework::OpDesc*>& cond_init_ops) { +void ProgramTranslator::TranslateBlock(const BlockDesc& src_block, + uint64_t start_id, + uint64_t end_id, + TranslationContext* translation_ctx, + pir::Block* dst_block) { VLOG(8) << "=============>start to translate a block"; PADDLE_ENFORCE( (src_block.OpSize() >= end_id) && (start_id <= end_id), @@ -408,13 +404,8 @@ void ProgramTranslator::TranslateBlock( end_id, src_block.OpSize())); - std::unordered_map translate_completed; std::map assign_output_2_input; for (uint64_t op_id = start_id; op_id < end_id; op_id++) { - if (translate_completed.count(op_id) && translate_completed.at(op_id)) { - continue; - } - auto op = src_block.Op(static_cast(op_id)); VLOG(8) << "=============>start to translate a op: " << op->Type(); @@ -424,144 +415,85 @@ void ProgramTranslator::TranslateBlock( "Not support translated %s op", op->Type())); if (op->Type() == "conditional_block") { - std::vector cond_op_ids = GetCondOpIds(src_block, op_id); - ConditionBlockCombination cond_op_combination(src_block, cond_op_ids); - pir::Operation* if_op = TranslateCondIfOperation( - cond_op_combination, translation_ctx, dst_block); - for (auto cond_id : cond_op_ids) { - translate_completed[cond_id] = true; - } - VLOG(10) << "[op translated][conditional_block]" << if_op; + TranslateIfOperation(op, translation_ctx, dst_block); } else if (op->Type() == "while") { TranslateWhileOperation(op, translation_ctx, dst_block); } else { - if (for_cond_block && op->Type() == "assign" && - std::count(cond_sub_block_outputs.begin(), - cond_sub_block_outputs.end(), - op->Output("Out")[0])) { - assign_output_2_input[op->Output("Out")[0]] = op->Input("X")[0]; - translate_completed[op_id] = true; - } else { - TranslateGeneralOperation(op, translation_ctx, dst_block); - translate_completed[op_id] = true; - } + TranslateGeneralOperation(op, translation_ctx, dst_block); } } - - // NOTE(zhangbo): If conditional_block operator has output, the cf.yeild - // operator needs to be inserted - if (for_cond_block) { - // insert init ops - for (::paddle::framework::OpDesc* init_op : cond_init_ops) { - TranslateGeneralOperation(init_op, translation_ctx, dst_block); - } - // insert yeild op - std::vector yeild_inputs; - for (auto output_name : cond_sub_block_outputs) { - if (assign_output_2_input.count(output_name) != 0) { - if (translation_ctx->count(assign_output_2_input[output_name]) == 0) { - CreateUndefinedVariable(assign_output_2_input[output_name], - src_block); - } - yeild_inputs.emplace_back( - (*translation_ctx)[assign_output_2_input[output_name]].value); - } else { - if 
(translation_ctx->count(output_name) == 0) { - CreateUndefinedVariable(output_name, src_block); - } - yeild_inputs.emplace_back((*translation_ctx)[output_name].value); - } - } - pir::AttributeMap attribute_map; - auto yeild_info = ctx_->GetRegisteredOpInfo(pir::YieldOp::name()); - pir::Operation* yeild_op = - pir::Operation::Create(yeild_inputs, attribute_map, {}, yeild_info); - dst_block->push_back(yeild_op); - } } -pir::Operation* ProgramTranslator::TranslateCondIfOperation( - const ConditionBlockCombination& cond_ops, +// NOTE(zhangbo): All condition_block_op will be translated as an if_op with +// only a true branch. +void ProgramTranslator::TranslateIfOperation( + const OpDesc* op, TranslationContext* translation_ctx, pir::Block* dst_block) { + VLOG(8) << "=============>Start to translate if op:" << op; auto& type_translator = TypeTranslator::instance(); - auto op_info = ctx_->GetRegisteredOpInfo(paddle::dialect::IfOp::name()); - std::vector op_inputs = { - (*translation_ctx)[cond_ops.CondVarName()].value}; - auto input_names = cond_ops.GetInputNamesForIfOp(); - for (auto input_name : input_names) { + auto cond_op_cond = op->Input("Cond")[0]; + auto& cond_op_inputs = op->Input("Input"); + for (auto input_name : cond_op_inputs) { VLOG(6) << "[general op][conditional_block][inputs: " << input_name << "]"; GetValueOrCreateInTop(input_name, translation_ctx); } + auto& cond_op_outputs = op->Output("Out"); + std::vector<::paddle::framework::VarDesc*> cond_op_output_vars; + for (auto out_name : cond_op_outputs) { + cond_op_output_vars.emplace_back(op->Block()->FindVarRecursive(out_name)); + } - // NOTE(zhangbo): Now paddle::dialect::IfOp has 0 attribute + std::vector if_op_inputs = { + (*translation_ctx)[cond_op_cond].value}; pir::AttributeMap attribute_map; - - std::vector op_output_types; - std::vector<::paddle::framework::VarDesc*> output_vardescs = - std::get<0>(cond_ops.CondOutputVars()); - for (auto var_desc : output_vardescs) { + std::vector if_op_output_types; + for (auto var_desc : cond_op_output_vars) { IR_ENFORCE(var_desc != nullptr, "[control flow] Output should not be null"); pir::Type translated_var_type = type_translator[var_desc->GetType()](ctx_, *var_desc); - op_output_types.emplace_back(translated_var_type); + if_op_output_types.emplace_back(translated_var_type); } - VLOG(4) << "[general op][conditional_block] IfOp preparation end."; - + auto if_op_info = ctx_->GetRegisteredOpInfo(paddle::dialect::IfOp::name()); pir::Operation* operation = pir::Operation::Create( - op_inputs, attribute_map, op_output_types, op_info, 2); - - for (size_t i = 0; i < output_vardescs.size(); i++) { - translation_ctx->PushValue(output_vardescs[i]->Name(), - VariableDefiningInfo(operation->result(i))); - VLOG(4) << "[general op][conditional_block] var " - << output_vardescs[i]->Name() << " was mapped to If's " << i - << "-th output."; - } + if_op_inputs, attribute_map, if_op_output_types, if_op_info, 1); dst_block->push_back(operation); VLOG(4) << "[general op][conditional_block] IfOp creation end."; - if (cond_ops.TrueBlockId() != -1) { - const BlockDesc& true_sub_block = - legacy_program_->Block(cond_ops.TrueBlockId()); + if (op->GetBlockAttrId("sub_block") != -1) { + auto& sub_block = legacy_program_->Block(op->GetBlockAttrId("sub_block")); pir::Region& true_region = operation->region(0); if (true_region.empty()) true_region.emplace_back(); - auto* true_block_context = translation_ctx->CreateInnerContext(); - - TranslateBlock(true_sub_block, + TranslateBlock(sub_block, 0, - 
true_sub_block.OpSize(), + sub_block.OpSize(), true_block_context, - &true_region.front(), - true, - std::get<1>(cond_ops.CondOutputVars()), - cond_ops.TrueBlockInitOps()); + &true_region.front()); + // insert yeild op to true block + auto yeild_info = ctx_->GetRegisteredOpInfo(pir::YieldOp::name()); + std::vector yeild_inputs{ + true_block_context->at(cond_op_cond).value}; + for (auto& out_name : cond_op_outputs) { + yeild_inputs.push_back(true_block_context->at(out_name).value); + } + true_region.front().push_back( + pir::Operation::Create(yeild_inputs, {}, {}, yeild_info)); } VLOG(4) << "[general op][conditional_block] IfOp true block translate end."; - if (cond_ops.FalseBlockId() != -1) { - const BlockDesc& false_sub_block = - legacy_program_->Block(cond_ops.FalseBlockId()); - pir::Region& false_region = operation->region(1); - if (false_region.empty()) false_region.emplace_back(); - auto* false_block_context = translation_ctx->CreateInnerContext(); - TranslateBlock(false_sub_block, - 0, - false_sub_block.OpSize(), - false_block_context, - &false_region.front(), - true, - std::get<2>(cond_ops.CondOutputVars()), - cond_ops.FalseBlockInitOps()); + for (size_t i = 0; i < cond_op_output_vars.size(); i++) { + translation_ctx->PushValue(cond_op_output_vars[i]->Name(), + VariableDefiningInfo(operation->result(i))); + VLOG(4) << "[general op][conditional_block] var " + << cond_op_output_vars[i]->Name() << " was mapped to If's " << i + << "-th output."; } - VLOG(4) << "[general op][conditional_block] IfOp false block translate end."; operation->Verify(); VLOG(4) << "[general op][conditional_block] IfOp translate end."; - return operation; } void ProgramTranslator::TranslateWhileOperation( diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.h b/paddle/fluid/ir_adaptor/translator/program_translator.h index 0dda3dc9b89219..a0e01bea1caf0c 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.h +++ b/paddle/fluid/ir_adaptor/translator/program_translator.h @@ -145,15 +145,11 @@ class ProgramTranslator { static const std::unordered_set unsupported_ops; - void TranslateBlock( - const BlockDesc& src_block, - uint64_t start_id, - uint64_t end_id, - TranslationContext* translation_ctx, - pir::Block* dst_block, - bool for_cond_block = false, - const std::vector& cond_sub_block_outputs = {}, - const std::vector<::paddle::framework::OpDesc*>& cond_init_ops = {}); + void TranslateBlock(const BlockDesc& src_block, + uint64_t start_id, + uint64_t end_id, + TranslationContext* translation_ctx, + pir::Block* dst_block); void TranslateGeneralOperation(const OpDesc* src_op, TranslationContext* translation_ctx, @@ -169,11 +165,10 @@ class ProgramTranslator { const VariableDefiningInfo& CreateUndefinedVariable( const std::string& var_name, const BlockDesc& block); - /// Translate methods for control flow ops. 
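  For intuition, the replacement entry point declared below lowers one legacy
  conditional_block directly into an if operation whose translated sub_block
  ends in a yield terminator; schematically (a sketch with approximate names,
  not literal IR text):

      conditional_block(Cond={%c}, Input={%x}, Out={%o}, sub_block=B)
        ==>
      %o = "pd_op.if"(%c) {
        ... ops translated from block B ...
        cf.yield(values of Out as defined inside B)
      }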
- pir::Operation* TranslateCondIfOperation( - const ConditionBlockCombination& cond_ops, - TranslationContext* translation_ctx, - pir::Block* dst_block); + void TranslateIfOperation(const OpDesc* op, + TranslationContext* translation_ctx, + pir::Block* dst_block); + void TranslateWhileOperation(const OpDesc* op, TranslationContext* translation_ctx, pir::Block* dst_block); From ed4eb36fa0562751be8babf653dc18c279595961 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Tue, 5 Dec 2023 13:29:14 +0000 Subject: [PATCH 02/19] fix --- .../dialect/operator/ir/control_flow_op.cc | 69 ++++++++++--------- 1 file changed, 35 insertions(+), 34 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc index 16070fac12ae8e..e0d39868579e87 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc @@ -178,11 +178,11 @@ void IfOp::VerifySig() { "bool DenseTensorType.")); } - PADDLE_ENFORCE_EQ((*this)->num_regions(), - 2u, - phi::errors::PreconditionNotMet( - "The size %d of regions must be equal to 2.", - (*this)->num_regions())); + // PADDLE_ENFORCE_EQ((*this)->num_regions(), + // 2u, + // phi::errors::PreconditionNotMet( + // "The size %d of regions must be equal to 2.", + // (*this)->num_regions())); } void IfOp::VerifyRegion() { @@ -193,36 +193,37 @@ void IfOp::VerifyRegion() { phi::errors::PreconditionNotMet("The size %d of true_region must be 1.", (*this)->region(0).size())); - if ((*this)->num_results() != 0) { - PADDLE_ENFORCE_EQ( - (*this)->region(0).size(), - (*this)->region(1).size(), - phi::errors::PreconditionNotMet("The size %d of true_region must be " - "equal to the size %d of false_region.", - (*this)->region(0).size(), - (*this)->region(1).size())); + // if ((*this)->num_results() != 0) { + // PADDLE_ENFORCE_EQ( + // (*this)->region(0).size(), + // (*this)->region(1).size(), + // phi::errors::PreconditionNotMet("The size %d of true_region must be " + // "equal to the size %d of + // false_region.", + // (*this)->region(0).size(), + // (*this)->region(1).size())); - auto &true_last_op = (*this)->region(0).front().back(); - auto &false_last_op = (*this)->region(1).front().back(); - PADDLE_ENFORCE_EQ(true, - true_last_op.isa(), - phi::errors::PreconditionNotMet( - "The last of true block must be YieldOp")); - PADDLE_ENFORCE_EQ(true_last_op.num_operands(), - (*this)->num_results(), - phi::errors::PreconditionNotMet( - "The size of last of true block op's input must be " - "equal to IfOp's outputs num.")); - PADDLE_ENFORCE_EQ(true, - false_last_op.isa(), - phi::errors::PreconditionNotMet( - "The last of false block must be YieldOp")); - PADDLE_ENFORCE_EQ(false_last_op.num_operands(), - (*this)->num_results(), - phi::errors::PreconditionNotMet( - "The size of last of false block op's input must be " - "equal to IfOp's outputs num.")); - } + // auto &true_last_op = (*this)->region(0).front().back(); + // auto &false_last_op = (*this)->region(1).front().back(); + // PADDLE_ENFORCE_EQ(true, + // true_last_op.isa(), + // phi::errors::PreconditionNotMet( + // "The last of true block must be YieldOp")); + // PADDLE_ENFORCE_EQ(true_last_op.num_operands(), + // (*this)->num_results(), + // phi::errors::PreconditionNotMet( + // "The size of last of true block op's input must be + // " "equal to IfOp's outputs num.")); + // PADDLE_ENFORCE_EQ(true, + // false_last_op.isa(), + // phi::errors::PreconditionNotMet( + // "The last of false block must be 
YieldOp")); + // PADDLE_ENFORCE_EQ(false_last_op.num_operands(), + // (*this)->num_results(), + // phi::errors::PreconditionNotMet( + // "The size of last of false block op's input must be + // " "equal to IfOp's outputs num.")); + // } } std::vector> IfOp::Vjp( From 41d3326af01f535384fb417c5360621e2edb9538 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Wed, 6 Dec 2023 02:44:03 +0000 Subject: [PATCH 03/19] add select_input --- .../pir/dialect/operator/ir/manual_op.cc | 40 +++++++++++++++++++ .../fluid/pir/dialect/operator/ir/manual_op.h | 11 +++++ .../pir/dialect/operator/ir/op_dialect.cc | 19 ++------- 3 files changed, 55 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc index cda564bedbb1df..dd0d7a91697976 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc @@ -11,6 +11,17 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +#ifdef GET_OP_LIST +#undef GET_OP_LIST +paddle::dialect::AddNOp, paddle::dialect::AddN_Op, + paddle::dialect::AddNWithKernelOp, paddle::dialect::FusedGemmEpilogueOp, + paddle::dialect::FusedGemmEpilogueGradOp, paddle::dialect::SplitGradOp, + paddle::dialect::ExpandOp, paddle::dialect::CreateArrayOp, + paddle::dialect::ArrayLengthOp, paddle::dialect::ArrayReadOp, + paddle::dialect::ArrayWrite_Op, paddle::dialect::SliceArrayOp, + paddle::dialect::SliceArrayDenseOp, paddle::dialect::AssignArray_Op, + paddle::dialect::ArrayToTensorOp, paddle::dialect::SelectInputOp +#else #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" #include "paddle/fluid/pir/dialect/operator/ir/ir_meta_tensor.h" @@ -2181,6 +2192,33 @@ phi::DataType ExpandOp::GetKernelTypeForVar( return expected_kernel_dtype; } +void SelectInputOp::VerifySig() { + VLOG(4) << "Verifying inputs, outputs and attributes for: SelectInputOp."; + VLOG(4) << "Verifying inputs:"; + { + auto in_size = num_operands(); + IR_ENFORCE(in_size == 3u, "Size %d of inputs must be >= 3.", in_size); + IR_ENFORCE((*this) + ->operand_source(0) + .type() + .isa(), + "Type validation failed for the 0th input, but got %s.", + (*this)->operand_source(0).type()); + IR_ENFORCE( + (*this)->operand_source(1).type() == (*this)->operand_source(2).type(), + "The 1st input type %s should be equal to 2ed input type %s.", + (*this)->operand_source(1).type(), + (*this)->operand_source(2).type()); + } + VLOG(4) << "Verifying outputs:"; + { + auto out_size = num_results(); + IR_ENFORCE( + out_size == 1u, "Size %d of outputs must be equal to 1.", out_size); + } + VLOG(4) << "End Verifying for: AssignArray_Op."; +} + } // namespace dialect } // namespace paddle @@ -2199,3 +2237,5 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SliceArrayDenseOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AssignArray_Op) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ArrayToTensorOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ExpandOp) + +#endif diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.h b/paddle/fluid/pir/dialect/operator/ir/manual_op.h index 6bdeac5bc04c9a..e3d515bb191d57 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.h @@ -401,6 +401,17 @@ class ExpandOp : public pir::Op> &stop_gradients); }; +class SelectInputOp : public pir::Op { + public: + using Op::Op; + static const char *name() { 
return "pd_op.select_input"; } + static constexpr const char **attributes_name = nullptr; + static constexpr uint32_t attributes_num = 0; + void VerifySig(); + pir::Value mask() { return operand_source(0); } + pir::OpResult out() { return result(0); } +} + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc index 4c44b91af35b72..002b8cb731ed7e 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc @@ -64,21 +64,10 @@ void OperatorDialect::initialize() { #include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc" // NOLINT >(); - RegisterOps(); + RegisterOps< +#define GET_OP_LIST +#include "paddle/fluid/pir/dialect/operator/ir/manual_op.cc" // NOLINT + >(); RegisterInterfaces(); } From 0ac12ccda9722c12441327ef10eccf3c8191443c Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Wed, 6 Dec 2023 07:09:22 +0000 Subject: [PATCH 04/19] fix --- .../new_executor/instruction/CMakeLists.txt | 3 +- ...{cond_instruction.cc => if_instruction.cc} | 20 +-- .../{cond_instruction.h => if_instruction.h} | 16 +- .../instruction/select_input_instruction.cc | 140 ++++++++++++++++++ .../instruction/select_input_instruction.h | 52 +++++++ .../pir_adaptor/pir_adaptor_util.h | 4 +- .../framework/new_executor/pir_interpreter.cc | 11 +- .../ir_adaptor/translator/op_translator.cc | 44 ++++++ .../pir/transforms/pd_op_to_kernel_pass.cc | 27 ++++ 9 files changed, 292 insertions(+), 25 deletions(-) rename paddle/fluid/framework/new_executor/instruction/{cond_instruction.cc => if_instruction.cc} (92%) rename paddle/fluid/framework/new_executor/instruction/{cond_instruction.h => if_instruction.h} (84%) create mode 100644 paddle/fluid/framework/new_executor/instruction/select_input_instruction.cc create mode 100644 paddle/fluid/framework/new_executor/instruction/select_input_instruction.h diff --git a/paddle/fluid/framework/new_executor/instruction/CMakeLists.txt b/paddle/fluid/framework/new_executor/instruction/CMakeLists.txt index abc8e86fb1663f..c85072b6a9c6cf 100644 --- a/paddle/fluid/framework/new_executor/instruction/CMakeLists.txt +++ b/paddle/fluid/framework/new_executor/instruction/CMakeLists.txt @@ -3,8 +3,9 @@ cc_library( SRCS instruction_base.cc phi_kernel_instruction.cc legacy_kernel_instruction.cc - cond_instruction.cc + if_instruction.cc while_instruction.cc + select_input_instruction.cc has_elements_instruction.cc tuple_push_instruction.cc tuple_pop_instruction.cc diff --git a/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc b/paddle/fluid/framework/new_executor/instruction/if_instruction.cc similarity index 92% rename from paddle/fluid/framework/new_executor/instruction/cond_instruction.cc rename to paddle/fluid/framework/new_executor/instruction/if_instruction.cc index a25d7d2a5a6df4..010e87e7fbca59 100644 --- a/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/if_instruction.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/fluid/framework/new_executor/instruction/cond_instruction.h" +#include "paddle/fluid/framework/new_executor/instruction/if_instruction.h" #include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" #include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h" @@ -39,11 +39,11 @@ namespace paddle { namespace framework { -CondInstruction::CondInstruction(size_t id, - const platform::Place& place, - pir::Operation* op, - ValueExecutionInfo* value_exec_info, - const std::set& skip_gc_vars) +IfInstruction::IfInstruction(size_t id, + const platform::Place& place, + pir::Operation* op, + ValueExecutionInfo* value_exec_info, + const std::set& skip_gc_vars) : InstructionBase(id, place) { PADDLE_ENFORCE( op->isa(), @@ -149,7 +149,7 @@ CondInstruction::CondInstruction(size_t id, VLOG(6) << "finish process false branch interpreter"; } -CondInstruction::~CondInstruction() { +IfInstruction::~IfInstruction() { if (true_branch_inter_ != nullptr) { delete true_branch_inter_; } @@ -158,8 +158,8 @@ CondInstruction::~CondInstruction() { } } -void CondInstruction::CopyBranchOutput( - const std::vector& var_names, const PirInterpreter* inter) { +void IfInstruction::CopyBranchOutput(const std::vector& var_names, + const PirInterpreter* inter) { for (size_t i = 0; i < var_names.size(); ++i) { auto* inner_var = inter->InnerScope()->GetVar(var_names[i]); @@ -179,7 +179,7 @@ void CondInstruction::CopyBranchOutput( } } -void CondInstruction::Run() { +void IfInstruction::Run() { DeviceContext().Wait(); if (cond_var_->Get().data()[0]) { true_branch_inter_->Run({}, false); diff --git a/paddle/fluid/framework/new_executor/instruction/cond_instruction.h b/paddle/fluid/framework/new_executor/instruction/if_instruction.h similarity index 84% rename from paddle/fluid/framework/new_executor/instruction/cond_instruction.h rename to paddle/fluid/framework/new_executor/instruction/if_instruction.h index 45f39ba338814f..b18c70094a33c3 100644 --- a/paddle/fluid/framework/new_executor/instruction/cond_instruction.h +++ b/paddle/fluid/framework/new_executor/instruction/if_instruction.h @@ -27,15 +27,15 @@ class Value; class PirInterpreter; class ValueExecutionInfo; -class CondInstruction : public InstructionBase { +class IfInstruction : public InstructionBase { public: - CondInstruction(size_t id, - const platform::Place& place, - ::pir::Operation* op, - ValueExecutionInfo* value_exe_info, - const std::set& skip_gc_vars); + IfInstruction(size_t id, + const platform::Place& place, + ::pir::Operation* op, + ValueExecutionInfo* value_exe_info, + const std::set& skip_gc_vars); - ~CondInstruction(); + ~IfInstruction(); void Run() override; @@ -53,7 +53,7 @@ class CondInstruction : public InstructionBase { ::pir::Operation* op_; - std::string cond_name_{"cond_instruction"}; + std::string cond_name_{"if_instruction"}; Variable* cond_var_; diff --git a/paddle/fluid/framework/new_executor/instruction/select_input_instruction.cc b/paddle/fluid/framework/new_executor/instruction/select_input_instruction.cc new file mode 100644 index 00000000000000..5a2001d6cda52f --- /dev/null +++ b/paddle/fluid/framework/new_executor/instruction/select_input_instruction.cc @@ -0,0 +1,140 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/new_executor/instruction/select_input_instruction.h" +#include "paddle/fluid/framework/new_executor/instruction/instruction_util.h" +#include "paddle/fluid/framework/new_executor/new_executor_defs.h" +#include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h" + +namespace paddle { +namespace framework { + +SelectInputInstruction::SelectInputInstruction( + size_t id, + const platform::Place &place, + ::pir::Operation *op, + ValueExecutionInfo *value_exe_info) + : op_(op), InstructionBase(id, place) { + VLOG(6) << "construct select_input instruction"; + + std::unordered_map> inputs; + mask_ = value_exe_info->GetVarByValue(op->operand_source(0)); + inputs.emplace(op->operand_source(0), + GetValueIds(op->operand_source(0), *value_exe_info)); + + for (size_t i = 1; i < op->num_operands(); ++i) { + inputs_.push_back(value_exe_info->GetVarByValue(op->operand_source(i))); + inputs.emplace(op->operand_source(i), + GetValueIds(op->operand_source(i), *value_exe_info)); + } + SetInputs(inputs); + + std::unordered_map> outputs; + out_ = value_exe_info->GetVarByValue(op->result(0)); + outputs.emplace(op->result(0), GetValueIds(op->result(0), *value_exe_info)); + SetOutputs(outputs); +} + +inline int GetBranchNumber(const phi::DenseTensor &mask) { + PADDLE_ENFORCE_EQ( + mask.numel(), + 1, + phi::errors::Fatal("The numel of Input(Mask) in SelectInputOp or " + "SelectOutputOp must be 1. " + "But received %d, and it's shape is [%s].", + mask.numel(), + mask.dims())); + if (platform::is_cpu_place(mask.place())) { + return mask.data()[0]; + } + // when platform::is_gpu_place(mask.place()) is true + std::unique_ptr cpu_mask{new phi::DenseTensor()}; +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \ + defined(PADDLE_WITH_CUSTOM_DEVICE) || defined(PADDLE_WITH_XPU) + framework::TensorCopySync(mask, platform::CPUPlace(), cpu_mask.get()); +#else + PADDLE_THROW(phi::errors::Fatal( + "This version of PaddlePaddle does NOT support GPU, " + "but got GPU tensor 'Mask' in SelectInputOp or SelectOutputOp. 
" + "Please compile PaddlePaddle WITH_GPU first.")); +#endif + return cpu_mask->data()[0]; +} + +class AssignFunctor { + public: + explicit AssignFunctor(Variable *out) : out_(out) {} + + void operator()(const phi::DenseTensor &lod_tensor) const { + auto &out_tensor = *out_->GetMutable(); + copy_tensor(lod_tensor, &out_tensor); + } + + void operator()(const phi::TensorArray &array) const { + auto &out_array = *out_->GetMutable(); + out_array.resize(array.size()); + for (size_t i = 0; i < array.size(); ++i) { + copy_tensor(array[i], &out_array[i]); + } + } + + void operator()(const phi::SelectedRows &rows) const { + phi::SelectedRows &out_rows = *out_->GetMutable(); + out_rows.set_rows(rows.rows()); + out_rows.set_height(rows.height()); + auto &t = rows.value(); + auto *m = out_rows.mutable_value(); + TensorCopy(t, t.place(), m); + } + + template + void operator()(const T &v UNUSED) const { + PADDLE_ENFORCE_EQ( + true, + false, + platform::errors::PermissionDenied( + "Not support type for assign op with type %s", typeid(T).name())); + } + + private: + void copy_tensor(const phi::DenseTensor &lod_tensor, + phi::DenseTensor *out) const { + if (!lod_tensor.IsInitialized()) return; + auto &out_tensor = *out; + TensorCopy(lod_tensor, lod_tensor.place(), &out_tensor); + out_tensor.set_lod(lod_tensor.lod()); + } + + Variable *out_; +}; + +void SelectInputInstruction::Run() { + VLOG(6) << "run select_input instruction"; + auto &mask = mask_->Get(); + size_t output_branch = static_cast(GetBranchNumber(mask)); + PADDLE_ENFORCE_LT( + output_branch, + inputs_.size(), + phi::errors::Fatal( + "Input 'Mask' in SelectInputOp is invalid. " + "'Mask' must be less than the size of input vector 'X'. " + "But received Mask = %d, X's size = %d.", + output_branch, + inputs_.size())); + Variable *selected = inputs_[output_branch]; + VisitVarType(*selected, AssignFunctor(out_)); +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_executor/instruction/select_input_instruction.h b/paddle/fluid/framework/new_executor/instruction/select_input_instruction.h new file mode 100644 index 00000000000000..16038e66152f69 --- /dev/null +++ b/paddle/fluid/framework/new_executor/instruction/select_input_instruction.h @@ -0,0 +1,52 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include "paddle/fluid/framework/new_executor/instruction/instruction_base.h" + +namespace paddle { +namespace framework { +class ValueExecutionInfo; + +class SelectInputInstruction : public InstructionBase { + public: + SelectInputInstruction(size_t id, + const platform::Place& place, + ::pir::Operation* op, + ValueExecutionInfo* value_exe_info); + + void Run() override; + + const std::string& Name() const override { return name_; } + + ::pir::Operation* Operation() const override { return op_; } + + private: + ::pir::Operation* op_; + + OpFuncType type_; + + std::string name_{"pd_op.select_input"}; + + Variable* mask_; // not owned + + std::vector inputs_; // not owned + + Variable* out_; // not owned +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h index 26fc26a2dd3713..2dfe34b298bbd1 100644 --- a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h +++ b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h @@ -43,11 +43,11 @@ namespace paddle { namespace framework { -class CondInstruction; +class IfInstruction; class WhileInstruction; class ValueExecutionInfo { public: - friend class CondInstruction; + friend class IfInstruction; friend class WhileInstruction; explicit ValueExecutionInfo(Scope* scope) : scope_(scope) {} diff --git a/paddle/fluid/framework/new_executor/pir_interpreter.cc b/paddle/fluid/framework/new_executor/pir_interpreter.cc index 4fc307ec660978..1fcc6fb0b6e857 100644 --- a/paddle/fluid/framework/new_executor/pir_interpreter.cc +++ b/paddle/fluid/framework/new_executor/pir_interpreter.cc @@ -45,10 +45,11 @@ #include "paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.h" #endif -#include "paddle/fluid/framework/new_executor/instruction/cond_instruction.h" #include "paddle/fluid/framework/new_executor/instruction/has_elements_instruction.h" +#include "paddle/fluid/framework/new_executor/instruction/if_instruction.h" #include "paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.h" #include "paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.h" +#include "paddle/fluid/framework/new_executor/instruction/select_input_instruction.h" #include "paddle/fluid/framework/new_executor/instruction/tuple_pop_instruction.h" #include "paddle/fluid/framework/new_executor/instruction/tuple_push_instruction.h" #include "paddle/fluid/framework/new_executor/instruction/while_instruction.h" @@ -666,15 +667,15 @@ void PirInterpreter::BuildInstruction() { } else if (op.dialect()->name() == "pd_op") { if (op.isa()) { auto skip_gc_vars = execution_config_.skip_gc_vars; - vec_instruction_base_.emplace_back(std::make_unique( + vec_instruction_base_.emplace_back(std::make_unique( op_idx++, place_, &op, value_exe_info_.get(), skip_gc_vars)); sub_blocks_.insert( {&op.dyn_cast().true_block(), - dynamic_cast(vec_instruction_base_.back().get()) + dynamic_cast(vec_instruction_base_.back().get()) ->TrueBranchInterpreter()}); sub_blocks_.insert( {&op.dyn_cast().false_block(), - dynamic_cast(vec_instruction_base_.back().get()) + dynamic_cast(vec_instruction_base_.back().get()) ->FalseBranchInterpreter()}); } else if (op.isa()) { auto skip_gc_vars = execution_config_.skip_gc_vars; @@ -686,6 +687,8 @@ void PirInterpreter::BuildInstruction() { ->BodyInterpreter()}); } else if (op.isa()) { CREATE_INSTR(HasElementsInstruction); + } else if 
(op.isa()) { + CREATE_INSTR(SelectInputInstruction); } else { PADDLE_THROW(platform::errors::Unimplemented( "Now only support pd_kernel and cinn dialect.")); diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index 8db987eb20fd70..70afd57a964292 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -1863,6 +1863,49 @@ struct FillConstantTranscriber : public OpTranscriber { } }; +struct SelectInputOpTranscriber : public OpTranscriber { + pir::Operation* operator()(pir::IrContext* ctx, + TranslationContext* param_map, + const OpDesc& op_desc, + pir::Block* block) override { + VLOG(10) << "[op select_input] start transcribing"; + auto op_info = this->LoopkUpOpInfo(ctx, op_desc); + + std::vector op_inputs = {}; + auto Mask_name = op_desc.Input("Mask")[0]; + auto& Input_name = op_desc.Input("X"); + IR_ENFORCE(param_map->count(Mask_name) > 0, + "Expected op[%s]'s input %s has been parsed", + op_desc.Type(), + Mask_name); + op_inputs.push_back(param_map->at(Mask_name).value); + for (auto in_name : Input_name) { + IR_ENFORCE(param_map->count(in_name) > 0, + "Expected op[%s]'s input %s has been parsed", + op_desc.Type(), + in_name); + op_inputs.push_back(param_map->at(in_name).value); + } + + pir::AttributeMap attribute_map; + + OpOutputMapping arg_to_idx; + OpOutputTypeList op_output_types; + auto Out_name = op_desc.Output("Out")[0]; + VarDesc* var = op_desc.Block()->FindVarRecursive(Out_name); + arg_to_idx[var->Name()] = {0, 0}; + op_output_types.push_back(op_inputs[1].type()); + + pir::Operation* operation = pir::Operation::Create( + op_inputs, attribute_map, op_output_types, op_info); + block->push_back(operation); + RecordOpResultMapping(ctx, param_map, op_desc, operation, arg_to_idx); + + VLOG(10) << "[op assign_value] translation finished"; + return operation; + } +}; + pir::OpResult TranslateNumClassesForOneHot( pir::IrContext* ctx, TranslationContext* param_map, @@ -2726,6 +2769,7 @@ OpTranslator::OpTranslator() { special_handlers["tril_triu"] = TrilAndTriuOpTranscriber(); special_handlers["mul"] = MulOpTranscriber(); special_handlers["mul_grad"] = MulGradOpTranscriber(); + special_handlers["select_input"] = SelectInputOpTranscriber(); // To adapt LodTensorArray special_handlers["lod_array_length"] = LodArrayLengthOpTranscriber(); diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc index 40ca238d397b44..04c4d68933140d 100644 --- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc @@ -105,6 +105,7 @@ const std::unordered_set SpecialLowerOps = { pir::TuplePushOp::name(), pir::TuplePopOp::name(), HasElementsOp::name(), + SelectInputOp::name(), "cinn_runtime.jit_kernel"}; static bool NeedFallBackCpu(const pir::Operation* op, @@ -997,6 +998,20 @@ pir::Value GetNewInput( return new_in; } +phi::Place ParsePhiPlace(pir::Type type) { + if (type.isa()) { + return type.dyn_cast().place(); + } else if (type.isa()) { + return type.dyn_cast().place(); + } else if (type.isa()) { + return type.dyn_cast().place(); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "ParsePhiPlace only support AllocatedDenseTensorType or " + "AllocatedSelectedRowsType or AllocatedDenseTensorArrayType")); + } +} + void HandleForSpecialOp( const phi::Place& place, pir::Operation* op_item, @@ -1222,6 +1237,18 @@ void HandleForSpecialOp( } } + if 
(op_item->isa()) { + for (size_t i = 0; i < op_item->num_operands(); ++i) { + auto cur_in = op_item->operand_source(i); + auto new_in = GetNewInput( + cur_in, *map_value_pair, static_cast(i), op_item->name()); + vec_inputs.push_back(new_in); + } + for (size_t i = 0; i < op_item->num_results(); ++i) { + op_output_types.push_back(vec_inputs[1].type()); + } + } + if (op_item->name() == "cinn_runtime.jit_kernel") { if (op_item->num_operands() > 0) { for (size_t i = 0; i < op_item->num_operands(); ++i) { From 639b601a142303c89a2017b4347cbac32f71d678 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Wed, 6 Dec 2023 08:09:29 +0000 Subject: [PATCH 05/19] fix --- paddle/fluid/ir_adaptor/translator/program_translator.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.cc b/paddle/fluid/ir_adaptor/translator/program_translator.cc index ebd2fadf12c42b..eec45aa6e9d251 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/program_translator.cc @@ -457,7 +457,7 @@ void ProgramTranslator::TranslateIfOperation( } auto if_op_info = ctx_->GetRegisteredOpInfo(paddle::dialect::IfOp::name()); pir::Operation* operation = pir::Operation::Create( - if_op_inputs, attribute_map, if_op_output_types, if_op_info, 1); + if_op_inputs, attribute_map, if_op_output_types, if_op_info, 2); dst_block->push_back(operation); VLOG(4) << "[general op][conditional_block] IfOp creation end."; From 368edfb223bf45bd3a4639656e226c0e3a7b72b3 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Wed, 6 Dec 2023 08:35:10 +0000 Subject: [PATCH 06/19] refine if_op without falseblock --- .../instruction/if_instruction.cc | 76 ++++++++++------- .../new_executor/instruction/if_instruction.h | 4 +- .../framework/new_executor/pir_interpreter.cc | 10 ++- .../dialect/operator/ir/control_flow_op.cc | 81 +++++++++---------- paddle/fluid/pir/transforms/inplace_pass.cc | 6 +- .../pir/transforms/pd_op_to_kernel_pass.cc | 16 ++-- 6 files changed, 104 insertions(+), 89 deletions(-) diff --git a/paddle/fluid/framework/new_executor/instruction/if_instruction.cc b/paddle/fluid/framework/new_executor/instruction/if_instruction.cc index 010e87e7fbca59..97c017d725836c 100644 --- a/paddle/fluid/framework/new_executor/instruction/if_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/if_instruction.cc @@ -66,13 +66,17 @@ IfInstruction::IfInstruction(size_t id, // OpOperand of IfOp, and the other is external Values used in true_block or // false_block. 
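  // For example (an illustrative sketch, not drawn from this patch): if the
  // true branch contains
  //   %y = pd_op.add(%x, %x)
  // where %x is defined in the parent block rather than passed as an operand
  // of the if op, then %x is one of these external Values; it is collected
  // below so the branch interpreter keeps it alive (added to the skip-gc set).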
auto& true_branch_block = if_op.true_block(); - auto& false_branch_block = if_op.false_block(); + std::unordered_map> inputs; GetInputIds(op, *value_exec_info, &inputs); auto true_outside_inputs = GetExternalInputs(&true_branch_block, *value_exec_info, &inputs); - auto false_outside_inputs = - GetExternalInputs(&false_branch_block, *value_exec_info, &inputs); + std::vector false_outside_inputs; + if (if_op.false_region().size() != 0) { + auto& false_branch_block = if_op.false_block(); + false_outside_inputs = + GetExternalInputs(&false_branch_block, *value_exec_info, &inputs); + } SetInputs(inputs); std::unordered_map> outputs; @@ -90,8 +94,10 @@ IfInstruction::IfInstruction(size_t id, } } InsertTuplePushContinerToOuts(&true_branch_block, *value_exec_info, &outputs); - InsertTuplePushContinerToOuts( - &false_branch_block, *value_exec_info, &outputs); + if (if_op.false_region().size() != 0) { + InsertTuplePushContinerToOuts( + &if_op.false_block(), *value_exec_info, &outputs); + } SetOutputs(outputs); VLOG(6) << "finish process inputs outputs index"; @@ -122,30 +128,37 @@ IfInstruction::IfInstruction(size_t id, true_branch_inter_->SetSkipGcVars(true_skip_gc_names_set); VLOG(6) << "finish process true branch interpreter"; - Scope* false_scope = &(value_exec_info->GetScope()->NewScope()); - false_branch_inter_ = - new PirInterpreter(place, - {}, - &false_branch_block, - false_scope, - value_exec_info->NewChild(false_scope), - {}); - - std::set false_skip_gc_names_set; - for (auto value : GetYiedOpInputs(&false_branch_block)) { - false_branch_outputs_.push_back(false_branch_inter_->GetNameByValue(value)); - false_skip_gc_names_.push_back(false_branch_inter_->GetNameByValue(value)); - false_skip_gc_names_set.insert(false_branch_inter_->GetNameByValue(value)); - } - for (auto value : false_outside_inputs) { - false_skip_gc_names_.push_back(false_branch_inter_->GetNameByValue(value)); - false_skip_gc_names_set.insert(false_branch_inter_->GetNameByValue(value)); - } - for (auto var_name : skip_gc_vars) { - false_skip_gc_names_.push_back(var_name); - false_skip_gc_names_set.insert(var_name); + if (if_op.false_region().size() != 0) { + auto& false_branch_block = if_op.false_block(); + Scope* false_scope = &(value_exec_info->GetScope()->NewScope()); + false_branch_inter_ = + new PirInterpreter(place, + {}, + &if_op.false_block(), + false_scope, + value_exec_info->NewChild(false_scope), + {}); + std::set false_skip_gc_names_set; + for (auto value : GetYiedOpInputs(&false_branch_block)) { + false_branch_outputs_.push_back( + false_branch_inter_->GetNameByValue(value)); + false_skip_gc_names_.push_back( + false_branch_inter_->GetNameByValue(value)); + false_skip_gc_names_set.insert( + false_branch_inter_->GetNameByValue(value)); + } + for (auto value : false_outside_inputs) { + false_skip_gc_names_.push_back( + false_branch_inter_->GetNameByValue(value)); + false_skip_gc_names_set.insert( + false_branch_inter_->GetNameByValue(value)); + } + for (auto var_name : skip_gc_vars) { + false_skip_gc_names_.push_back(var_name); + false_skip_gc_names_set.insert(var_name); + } + false_branch_inter_->SetSkipGcVars(false_skip_gc_names_set); } - false_branch_inter_->SetSkipGcVars(false_skip_gc_names_set); VLOG(6) << "finish process false branch interpreter"; } @@ -185,10 +198,11 @@ void IfInstruction::Run() { true_branch_inter_->Run({}, false); CopyBranchOutput(true_branch_outputs_, true_branch_inter_); } else { - false_branch_inter_->Run({}, false); - CopyBranchOutput(false_branch_outputs_, false_branch_inter_); + 
if (false_branch_inter_) { + false_branch_inter_->Run({}, false); + CopyBranchOutput(false_branch_outputs_, false_branch_inter_); + } } - // copy ouptut } diff --git a/paddle/fluid/framework/new_executor/instruction/if_instruction.h b/paddle/fluid/framework/new_executor/instruction/if_instruction.h index b18c70094a33c3..e6d1fc4723c5d6 100644 --- a/paddle/fluid/framework/new_executor/instruction/if_instruction.h +++ b/paddle/fluid/framework/new_executor/instruction/if_instruction.h @@ -59,9 +59,9 @@ class IfInstruction : public InstructionBase { std::vector output_vars_; - PirInterpreter* true_branch_inter_; + PirInterpreter* true_branch_inter_ = nullptr; - PirInterpreter* false_branch_inter_; + PirInterpreter* false_branch_inter_ = nullptr; std::vector true_branch_outputs_; diff --git a/paddle/fluid/framework/new_executor/pir_interpreter.cc b/paddle/fluid/framework/new_executor/pir_interpreter.cc index 45674498b179fb..ac78b1ee4184c0 100644 --- a/paddle/fluid/framework/new_executor/pir_interpreter.cc +++ b/paddle/fluid/framework/new_executor/pir_interpreter.cc @@ -678,10 +678,12 @@ void PirInterpreter::BuildInstruction() { {&op.dyn_cast().true_block(), dynamic_cast(vec_instruction_base_.back().get()) ->TrueBranchInterpreter()}); - sub_blocks_.insert( - {&op.dyn_cast().false_block(), - dynamic_cast(vec_instruction_base_.back().get()) - ->FalseBranchInterpreter()}); + if (op.dyn_cast().false_region().size() != 0) { + sub_blocks_.insert( + {&op.dyn_cast().false_block(), + dynamic_cast(vec_instruction_base_.back().get()) + ->FalseBranchInterpreter()}); + } } else if (op.isa()) { auto skip_gc_vars = execution_config_.skip_gc_vars; vec_instruction_base_.emplace_back(std::make_unique( diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc index 26a0a6d559d3e8..f9c740e7c120d2 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc @@ -156,53 +156,48 @@ void IfOp::VerifySig() { "bool DenseTensorType.")); } - // PADDLE_ENFORCE_EQ((*this)->num_regions(), - // 2u, - // phi::errors::PreconditionNotMet( - // "The size %d of regions must be equal to 2.", - // (*this)->num_regions())); + PADDLE_ENFORCE_EQ((*this)->num_regions(), + 2u, + phi::errors::PreconditionNotMet( + "The size %d of regions must be equal to 2.", + (*this)->num_regions())); } void IfOp::VerifyRegion() { - // VLOG(4) << "Start Verifying sub regions for: IfOp."; - // PADDLE_ENFORCE_EQ( - // (*this)->region(0).size(), - // 1u, - // phi::errors::PreconditionNotMet("The size %d of true_region must - // be 1.", - // (*this)->region(0).size())); - - // if ((*this)->num_results() != 0) { - // PADDLE_ENFORCE_EQ( - // (*this)->region(0).size(), - // (*this)->region(1).size(), - // phi::errors::PreconditionNotMet("The size %d of true_region must be " - // "equal to the size %d of - // false_region.", - // (*this)->region(0).size(), - // (*this)->region(1).size())); + VLOG(4) << "Start Verifying sub regions for: IfOp."; + PADDLE_ENFORCE_EQ( + (*this)->region(0).size(), + 1u, + phi::errors::PreconditionNotMet("The size %d of true_region must be 1.", + (*this)->region(0).size())); + auto &true_last_op = (*this)->region(0).front().back(); + PADDLE_ENFORCE_EQ(true, + true_last_op.isa(), + phi::errors::PreconditionNotMet( + "The last of true block must be YieldOp")); + PADDLE_ENFORCE_EQ(true_last_op.num_operands(), + (*this)->num_results(), + phi::errors::PreconditionNotMet( + "The size of last of true block 
op's input must be " + "equal to IfOp's outputs num.")); - // auto &true_last_op = (*this)->region(0).front().back(); - // auto &false_last_op = (*this)->region(1).front().back(); - // PADDLE_ENFORCE_EQ(true, - // true_last_op.isa(), - // phi::errors::PreconditionNotMet( - // "The last of true block must be YieldOp")); - // PADDLE_ENFORCE_EQ(true_last_op.num_operands(), - // (*this)->num_results(), - // phi::errors::PreconditionNotMet( - // "The size of last of true block op's input must be - // " "equal to IfOp's outputs num.")); - // PADDLE_ENFORCE_EQ(true, - // false_last_op.isa(), - // phi::errors::PreconditionNotMet( - // "The last of false block must be YieldOp")); - // PADDLE_ENFORCE_EQ(false_last_op.num_operands(), - // (*this)->num_results(), - // phi::errors::PreconditionNotMet( - // "The size of last of false block op's input must be - // " "equal to IfOp's outputs num.")); - // } + if ((*this)->region(1).size() != 0) { + PADDLE_ENFORCE_EQ((*this)->region(1).size(), + 1u, + phi::errors::PreconditionNotMet( + "The size %d of false_region must be 1.", + (*this)->region(0).size())); + auto &false_last_op = (*this)->region(1).front().back(); + PADDLE_ENFORCE_EQ(true, + false_last_op.isa(), + phi::errors::PreconditionNotMet( + "The last of false block must be YieldOp")); + PADDLE_ENFORCE_EQ(false_last_op.num_operands(), + (*this)->num_results(), + phi::errors::PreconditionNotMet( + "The size of last of false block op's input must be " + "equal to IfOp's outputs num.")); + } } std::vector> IfOp::Vjp( diff --git a/paddle/fluid/pir/transforms/inplace_pass.cc b/paddle/fluid/pir/transforms/inplace_pass.cc index 3088f41f240993..f4269d5d857eb6 100644 --- a/paddle/fluid/pir/transforms/inplace_pass.cc +++ b/paddle/fluid/pir/transforms/inplace_pass.cc @@ -250,8 +250,10 @@ static void GetEagerDelValueOfOp( auto if_op = op.dyn_cast(); GetEagerDelValueOfOp(&if_op.true_block(), skip_dels, del_value_2_op); VLOG(8) << "GetEagerDelValueOfOp for IfOp true block"; - GetEagerDelValueOfOp(&if_op.false_block(), skip_dels, del_value_2_op); - VLOG(8) << "GetEagerDelValueOfOp for IfOp false block"; + if (if_op.false_region().size() != 0) { + GetEagerDelValueOfOp(&if_op.false_block(), skip_dels, del_value_2_op); + VLOG(8) << "GetEagerDelValueOfOp for IfOp false block"; + } } } } diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc index 04c4d68933140d..05f073b9b31e41 100644 --- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc @@ -917,13 +917,15 @@ void HandleForIfOp( map_value_pair); // process false block - auto& false_block = new_ifop.false_block(); - ProcessBlock(place, - &old_ifop.false_block(), - &false_block, - ctx, - map_op_pair, - map_value_pair); + if (old_ifop.false_region().size() != 0) { + auto& false_block = new_ifop.false_block(); + ProcessBlock(place, + &old_ifop.false_block(), + &false_block, + ctx, + map_op_pair, + map_value_pair); + } // update map (*map_op_pair)[op_item] = new_ifop; From 0346fdd2c404dd443191c335700d915a4ee33d07 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Wed, 6 Dec 2023 09:12:36 +0000 Subject: [PATCH 07/19] fix bug --- paddle/fluid/pir/dialect/operator/ir/manual_op.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.h b/paddle/fluid/pir/dialect/operator/ir/manual_op.h index e3d515bb191d57..555f87f1dda139 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.h +++ 
b/paddle/fluid/pir/dialect/operator/ir/manual_op.h @@ -410,7 +410,7 @@ class SelectInputOp : public pir::Op { void VerifySig(); pir::Value mask() { return operand_source(0); } pir::OpResult out() { return result(0); } -} +}; } // namespace dialect } // namespace paddle From 2ecd26982f516d59e4f5e47c4a3cb682724a6a11 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Wed, 6 Dec 2023 10:33:14 +0000 Subject: [PATCH 08/19] fix --- paddle/fluid/pir/dialect/operator/ir/manual_op.cc | 2 +- paddle/fluid/pir/dialect/operator/ir/manual_op.h | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc index dd0d7a91697976..81532960af378b 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc @@ -2237,5 +2237,5 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SliceArrayDenseOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AssignArray_Op) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ArrayToTensorOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ExpandOp) - +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SelectInputOp) #endif diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.h b/paddle/fluid/pir/dialect/operator/ir/manual_op.h index 555f87f1dda139..460356039d84ab 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.h @@ -430,3 +430,4 @@ IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SliceArrayDenseOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AssignArray_Op) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ArrayToTensorOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ExpandOp) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SelectInputOp) From 14a711410ba6a585a590e143d84aa66a86fc5e99 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Wed, 6 Dec 2023 11:44:58 +0000 Subject: [PATCH 09/19] fix --- .../instruction/select_input_instruction.cc | 2 +- .../instruction/while_instruction.cc | 22 +++++++------------ 2 files changed, 9 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/framework/new_executor/instruction/select_input_instruction.cc b/paddle/fluid/framework/new_executor/instruction/select_input_instruction.cc index 5a2001d6cda52f..893915f841d7fc 100644 --- a/paddle/fluid/framework/new_executor/instruction/select_input_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/select_input_instruction.cc @@ -25,7 +25,7 @@ SelectInputInstruction::SelectInputInstruction( const platform::Place &place, ::pir::Operation *op, ValueExecutionInfo *value_exe_info) - : op_(op), InstructionBase(id, place) { + : InstructionBase(id, place), op_(op) { VLOG(6) << "construct select_input instruction"; std::unordered_map> inputs; diff --git a/paddle/fluid/framework/new_executor/instruction/while_instruction.cc b/paddle/fluid/framework/new_executor/instruction/while_instruction.cc index aee25e8d816843..2f3787118d2e48 100644 --- a/paddle/fluid/framework/new_executor/instruction/while_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/while_instruction.cc @@ -175,20 +175,14 @@ void WhileInstruction::CopyOutputsToBlockArgs() { auto* dst_tensor_array = inner_var->GetMutable(); dst_tensor_array->set_type(src_tensor_array.dtype()); dst_tensor_array->set_layout(src_tensor_array.layout()); - if (dst_tensor_array->empty()) { - for (auto src_tensor : src_tensor_array) { - phi::DenseTensor* tmp_dst_tensor = new phi::DenseTensor(); - tmp_dst_tensor->set_meta(src_tensor.meta()); 
- framework::TensorCopy(src_tensor, src_tensor.place(), tmp_dst_tensor); - dst_tensor_array->push_back(*tmp_dst_tensor); - } - } else { - for (size_t id = 0; id < dst_tensor_array->size(); id++) { - auto& src_tensor = src_tensor_array[id]; - phi::DenseTensor* tmp_dst_tensor = &dst_tensor_array->at(id); - tmp_dst_tensor->set_meta(src_tensor.meta()); - framework::TensorCopy(src_tensor, src_tensor.place(), tmp_dst_tensor); - } + while (dst_tensor_array->size() < src_tensor_array.size()) { + dst_tensor_array->emplace_back(); + } + for (size_t id = 0; id < dst_tensor_array->size(); id++) { + auto& src_tensor = src_tensor_array[id]; + phi::DenseTensor* tmp_dst_tensor = &dst_tensor_array->at(id); + tmp_dst_tensor->set_meta(src_tensor.meta()); + framework::TensorCopy(src_tensor, src_tensor.place(), tmp_dst_tensor); } } else { PADDLE_THROW( From 29b9d2516fc52837c6006944c6a3e12b2cf6eeda Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Thu, 7 Dec 2023 03:29:31 +0000 Subject: [PATCH 10/19] fix --- .../translator/program_translator.cc | 3 +- test/cpp/pir/core/program_translator_test.cc | 76 +++++-------------- 2 files changed, 20 insertions(+), 59 deletions(-) diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.cc b/paddle/fluid/ir_adaptor/translator/program_translator.cc index eec45aa6e9d251..b2cea44aeb0a86 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/program_translator.cc @@ -474,8 +474,7 @@ void ProgramTranslator::TranslateIfOperation( &true_region.front()); // insert yeild op to true block auto yeild_info = ctx_->GetRegisteredOpInfo(pir::YieldOp::name()); - std::vector yeild_inputs{ - true_block_context->at(cond_op_cond).value}; + std::vector yeild_inputs; for (auto& out_name : cond_op_outputs) { yeild_inputs.push_back(true_block_context->at(out_name).value); } diff --git a/test/cpp/pir/core/program_translator_test.cc b/test/cpp/pir/core/program_translator_test.cc index 0c199f0481e710..68baecdc7f73ba 100644 --- a/test/cpp/pir/core/program_translator_test.cc +++ b/test/cpp/pir/core/program_translator_test.cc @@ -85,7 +85,10 @@ TEST(OperatorDialectTest, ConditionBlock) { ctx->GetOrRegisterDialect(); auto program = paddle::TranslateLegacyProgramToProgram(p); - EXPECT_EQ(program->block()->size(), 4u); + program->Print(std::cout); + std::cout << std::endl; + + EXPECT_EQ(program->block()->size(), 9u); size_t id = 0; for (auto &op : *program->block()) { if (id == 0 || id == 1) { @@ -117,77 +120,36 @@ TEST(OperatorDialectTest, ConditionBlock) { EXPECT_EQ(op2.isa(), true); } if (true_true_id == 1) { - EXPECT_EQ(op2.isa(), true); - } - true_true_id++; - } - auto &false_false_block = - op1.dyn_cast().false_block(); - size_t false_false_id = 0; - for (auto &op2 : false_false_block) { - if (false_false_id == 0) { - EXPECT_EQ(op2.isa(), true); + EXPECT_EQ(op2.isa(), true); } - if (false_false_id == 1) { + if (true_true_id == 2) { EXPECT_EQ(op2.isa(), true); } - false_false_id++; + true_true_id++; } } if (true_id == 4) { - EXPECT_EQ(op1.isa(), true); + EXPECT_EQ(op1.isa(), true); } if (true_id == 5) { - EXPECT_EQ(op1.isa(), true); - } - true_id++; - } - // false block - auto &false_block = op.dyn_cast().false_block(); - size_t false_id = 0; - for (auto &op1 : false_block) { - if (false_id == 0 || false_id == 1) { - EXPECT_EQ(op1.isa(), true); + EXPECT_EQ(op1.isa(), true); } - if (false_id == 2) { - EXPECT_EQ(op1.isa(), true); + if (true_id == 6) { + EXPECT_EQ(op1.isa(), true); } - if (false_id == 3) { - EXPECT_EQ(op1.isa(), true); - 
// true block - auto &false_true_block = - op1.dyn_cast().true_block(); - size_t false_true_id = 0; - for (auto &op2 : false_true_block) { - if (false_true_id == 0) { - EXPECT_EQ(op2.isa(), true); - } - if (false_true_id == 1) { - EXPECT_EQ(op2.isa(), true); - } - false_true_id++; - } - // false block - auto &false_false_block = - op1.dyn_cast().true_block(); - size_t false_false_id = 0; - for (auto &op2 : false_false_block) { - if (false_false_id == 0) { - EXPECT_EQ(op2.isa(), true); - } - if (false_false_id == 1) { - EXPECT_EQ(op2.isa(), true); - } - false_false_id++; - } + if (true_id == 7) { + EXPECT_EQ(op1.isa(), true); } - if (false_id == 4) { + if (true_id == 8) { EXPECT_EQ(op1.isa(), true); } - if (false_id == 5) { + if (true_id == 9 || true_id == 10) { + EXPECT_EQ(op1.isa(), true); + } + if (true_id == 11) { EXPECT_EQ(op1.isa(), true); } - false_id++; + true_id++; } } id++; From 7b61ae03ed3aeeab679c4348451b8bd02476a435 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Thu, 7 Dec 2023 11:58:14 +0000 Subject: [PATCH 11/19] fix --- .../instruction/if_instruction.cc | 75 ++++++++----------- .../framework/new_executor/pir_interpreter.cc | 10 +-- .../ir_adaptor/translator/op_translator.cc | 53 ++++++++++++- .../translator/program_translator.cc | 52 ++++++++++++- .../translator/program_translator.h | 3 + .../dialect/operator/ir/control_flow_op.cc | 14 +++- .../pir/dialect/operator/ir/manual_op.cc | 72 +++++++++++++++--- paddle/fluid/pir/transforms/inplace_pass.cc | 6 +- .../pir/transforms/pd_op_to_kernel_pass.cc | 16 ++-- 9 files changed, 223 insertions(+), 78 deletions(-) diff --git a/paddle/fluid/framework/new_executor/instruction/if_instruction.cc b/paddle/fluid/framework/new_executor/instruction/if_instruction.cc index 97c017d725836c..3ac3a9e4780be3 100644 --- a/paddle/fluid/framework/new_executor/instruction/if_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/if_instruction.cc @@ -72,11 +72,9 @@ IfInstruction::IfInstruction(size_t id, auto true_outside_inputs = GetExternalInputs(&true_branch_block, *value_exec_info, &inputs); std::vector false_outside_inputs; - if (if_op.false_region().size() != 0) { - auto& false_branch_block = if_op.false_block(); - false_outside_inputs = - GetExternalInputs(&false_branch_block, *value_exec_info, &inputs); - } + auto& false_branch_block = if_op.false_block(); + false_outside_inputs = + GetExternalInputs(&false_branch_block, *value_exec_info, &inputs); SetInputs(inputs); std::unordered_map> outputs; @@ -94,10 +92,10 @@ IfInstruction::IfInstruction(size_t id, } } InsertTuplePushContinerToOuts(&true_branch_block, *value_exec_info, &outputs); - if (if_op.false_region().size() != 0) { - InsertTuplePushContinerToOuts( - &if_op.false_block(), *value_exec_info, &outputs); - } + + InsertTuplePushContinerToOuts( + &if_op.false_block(), *value_exec_info, &outputs); + SetOutputs(outputs); VLOG(6) << "finish process inputs outputs index"; @@ -128,37 +126,30 @@ IfInstruction::IfInstruction(size_t id, true_branch_inter_->SetSkipGcVars(true_skip_gc_names_set); VLOG(6) << "finish process true branch interpreter"; - if (if_op.false_region().size() != 0) { - auto& false_branch_block = if_op.false_block(); - Scope* false_scope = &(value_exec_info->GetScope()->NewScope()); - false_branch_inter_ = - new PirInterpreter(place, - {}, - &if_op.false_block(), - false_scope, - value_exec_info->NewChild(false_scope), - {}); - std::set false_skip_gc_names_set; - for (auto value : GetYiedOpInputs(&false_branch_block)) { - false_branch_outputs_.push_back( - 
false_branch_inter_->GetNameByValue(value)); - false_skip_gc_names_.push_back( - false_branch_inter_->GetNameByValue(value)); - false_skip_gc_names_set.insert( - false_branch_inter_->GetNameByValue(value)); - } - for (auto value : false_outside_inputs) { - false_skip_gc_names_.push_back( - false_branch_inter_->GetNameByValue(value)); - false_skip_gc_names_set.insert( - false_branch_inter_->GetNameByValue(value)); - } - for (auto var_name : skip_gc_vars) { - false_skip_gc_names_.push_back(var_name); - false_skip_gc_names_set.insert(var_name); - } - false_branch_inter_->SetSkipGcVars(false_skip_gc_names_set); + Scope* false_scope = &(value_exec_info->GetScope()->NewScope()); + false_branch_inter_ = + new PirInterpreter(place, + {}, + &if_op.false_block(), + false_scope, + value_exec_info->NewChild(false_scope), + {}); + std::set false_skip_gc_names_set; + for (auto value : GetYiedOpInputs(&false_branch_block)) { + false_branch_outputs_.push_back(false_branch_inter_->GetNameByValue(value)); + false_skip_gc_names_.push_back(false_branch_inter_->GetNameByValue(value)); + false_skip_gc_names_set.insert(false_branch_inter_->GetNameByValue(value)); + } + for (auto value : false_outside_inputs) { + false_skip_gc_names_.push_back(false_branch_inter_->GetNameByValue(value)); + false_skip_gc_names_set.insert(false_branch_inter_->GetNameByValue(value)); + } + for (auto var_name : skip_gc_vars) { + false_skip_gc_names_.push_back(var_name); + false_skip_gc_names_set.insert(var_name); } + false_branch_inter_->SetSkipGcVars(false_skip_gc_names_set); + VLOG(6) << "finish process false branch interpreter"; } @@ -198,10 +189,8 @@ void IfInstruction::Run() { true_branch_inter_->Run({}, false); CopyBranchOutput(true_branch_outputs_, true_branch_inter_); } else { - if (false_branch_inter_) { - false_branch_inter_->Run({}, false); - CopyBranchOutput(false_branch_outputs_, false_branch_inter_); - } + false_branch_inter_->Run({}, false); + CopyBranchOutput(false_branch_outputs_, false_branch_inter_); } // copy ouptut } diff --git a/paddle/fluid/framework/new_executor/pir_interpreter.cc b/paddle/fluid/framework/new_executor/pir_interpreter.cc index ac78b1ee4184c0..45674498b179fb 100644 --- a/paddle/fluid/framework/new_executor/pir_interpreter.cc +++ b/paddle/fluid/framework/new_executor/pir_interpreter.cc @@ -678,12 +678,10 @@ void PirInterpreter::BuildInstruction() { {&op.dyn_cast().true_block(), dynamic_cast(vec_instruction_base_.back().get()) ->TrueBranchInterpreter()}); - if (op.dyn_cast().false_region().size() != 0) { - sub_blocks_.insert( - {&op.dyn_cast().false_block(), - dynamic_cast(vec_instruction_base_.back().get()) - ->FalseBranchInterpreter()}); - } + sub_blocks_.insert( + {&op.dyn_cast().false_block(), + dynamic_cast(vec_instruction_base_.back().get()) + ->FalseBranchInterpreter()}); } else if (op.isa()) { auto skip_gc_vars = execution_config_.skip_gc_vars; vec_instruction_base_.emplace_back(std::make_unique( diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index 93036478d74850..eefe7ebe79db65 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -1863,6 +1863,23 @@ struct FillConstantTranscriber : public OpTranscriber { } }; +static std::vector ParseCompatibleShapes( + const std::vector& dim1, const std::vector& dim2) { + IR_ENFORCE(dim1.size() == dim2.size(), + "Does not support rank inconsistency: dim1=%d, dim2=%d", + dim1.size(), + dim2.size()); + std::vector 
result;
+ for (size_t i = 0; i < dim1.size(); ++i) {
+ if (dim1[i] != dim2[i]) {
+ result.push_back(-1);
+ } else {
+ result.push_back(dim1[i]);
+ }
+ }
+ return result;
+}
+
 struct SelectInputOpTranscriber : public OpTranscriber {
 pir::Operation* operator()(pir::IrContext* ctx,
 TranslationContext* param_map,
@@ -1894,7 +1911,41 @@ struct SelectInputOpTranscriber : public OpTranscriber {
 auto Out_name = op_desc.Output("Out")[0];
 VarDesc* var = op_desc.Block()->FindVarRecursive(Out_name);
 arg_to_idx[var->Name()] = {0, 0};
- op_output_types.push_back(op_inputs[1].type());
+
+ // NOTE(zhangbo): Only supports identical input types, or DenseTensorTypes that differ only in their dims.
+ auto input1 = op_inputs[1].type();
+ auto input2 = op_inputs[2].type();
+ if (input1 == input2) {
+ op_output_types.push_back(op_inputs[1].type());
+ } else {
+ if (input1.isa() &&
+ input2.isa()) {
+ auto tensor1 = input1.dyn_cast();
+ auto tensor2 = input2.dyn_cast();
+ if (tensor1.dtype() != tensor2.dtype() ||
+ tensor1.data_layout() != tensor2.data_layout() ||
+ tensor1.lod() != tensor2.lod() ||
+ tensor1.offset() != tensor2.offset()) {
+ IR_THROW(
+ "select_input only support same type or DenseTensorType with "
+ "only different dim.");
+ }
+ auto dim1 = input1.dyn_cast().dims();
+ auto dim2 = input2.dyn_cast().dims();
+ std::vector compat_shape = ParseCompatibleShapes(
+ common::vectorize(dim1), common::vectorize(dim2));
+ op_output_types.push_back(paddle::dialect::DenseTensorType::get(
+ ctx,
+ tensor1.dtype(),
+ common::make_ddim(compat_shape),
+ tensor1.data_layout(),
+ tensor1.lod(),
+ tensor1.offset()));
+ }
+ IR_THROW(
+ "select_input only support same type or DenseTensorType with only "
+ "different dim.");
+ }
 pir::Operation* operation = pir::Operation::Create(
 op_inputs, attribute_map, op_output_types, op_info);
diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.cc b/paddle/fluid/ir_adaptor/translator/program_translator.cc
index b2cea44aeb0a86..e3e5c4196e31fd 100644
--- a/paddle/fluid/ir_adaptor/translator/program_translator.cc
+++ b/paddle/fluid/ir_adaptor/translator/program_translator.cc
@@ -463,6 +463,7 @@ void ProgramTranslator::TranslateIfOperation(
 VLOG(4) << "[general op][conditional_block] IfOp creation end.";
 if (op->GetBlockAttrId("sub_block") != -1) {
+ // Translate the true branch from the legacy sub_block.
 auto& sub_block = legacy_program_->Block(op->GetBlockAttrId("sub_block"));
 pir::Region& true_region = operation->region(0);
 if (true_region.empty()) true_region.emplace_back();
@@ -474,12 +475,36 @@ void ProgramTranslator::TranslateIfOperation(
 &true_region.front());
 // insert yeild op to true block
 auto yeild_info = ctx_->GetRegisteredOpInfo(pir::YieldOp::name());
- std::vector yeild_inputs;
+ std::vector true_yeild_inputs;
 for (auto& out_name : cond_op_outputs) {
- yeild_inputs.push_back(true_block_context->at(out_name).value);
+ true_yeild_inputs.push_back(true_block_context->at(out_name).value);
 }
 true_region.front().push_back(
- pir::Operation::Create(yeild_inputs, {}, {}, yeild_info));
+ pir::Operation::Create(true_yeild_inputs, {}, {}, yeild_info));
+
+ // NOTE(zhangbo): PIR's if_op requires that both the true and the false
+ // branch exist and yield the same number of outputs with matching dtypes;
+ // only the shapes may differ. The legacy conditional_block op carries a
+ // true branch only, so the false branch built below may need to yield
+ // some fake placeholder variables (sketched schematically after this hunk).
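// Hedged schematic of what the code above and below produces for a legacy
// conditional_block with a single output (illustrative pseudo textual IR,
// not Paddle's exact printer output):
//
//   %out = "pd_op.if"(%cond) {
//     ...                         // ops translated from the true sub_block
//     "cf.yield"(%true_val)
//   } else {
//     %fake = "pd_op.full"(...)   // zero-filled placeholder of matching
//     "cf.yield"(%fake)           // dtype, so both branches agree
//   }
//
// A later patch in this series extends the placeholder to a create_array
// op when the yielded value is a DenseTensorArray.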
+ pir::Region& false_region = operation->region(1); + if (false_region.empty()) false_region.emplace_back(); + auto* false_block_context = translation_ctx->CreateInnerContext(); + std::vector false_yeild_inputs; + for (size_t id = 0; id < cond_op_outputs.size(); id++) { + if (false_block_context->count(cond_op_outputs[id]) == 0) { + auto true_type = true_yeild_inputs[id].type(); + if (true_type.isa()) { + InsertFullOpToBlock(&false_region.front(), true_type); + } else { + CreateUndefinedVariable(cond_op_outputs[id], sub_block); + } + } + false_yeild_inputs.push_back( + false_block_context->at(cond_op_outputs[id]).value); + } + false_region.front().push_back( + pir::Operation::Create(false_yeild_inputs, {}, {}, yeild_info)); } VLOG(4) << "[general op][conditional_block] IfOp true block translate end."; @@ -744,11 +769,13 @@ void ProgramTranslator::SetStopGradientAttributeForAllValue( } } } + const VariableDefiningInfo& ProgramTranslator::GetValueOrCreateInTop( const std::string& var_name, TranslationContext* translation_ctx) { if (translation_ctx->Has(var_name)) return translation_ctx->at(var_name); return CreateUndefinedVariable(var_name, legacy_program_->Block(0)); } + const VariableDefiningInfo& ProgramTranslator::CreateUndefinedVariable( const std::string& var_name, const BlockDesc& block) { VLOG(10) << "[undefined variable]" << var_name; @@ -771,6 +798,25 @@ const VariableDefiningInfo& ProgramTranslator::CreateUndefinedVariable( param_map_.PushValue(var_name, val); return param_map_.at(var_name); } + +const VariableDefiningInfo& ProgramTranslator::InsertFullOpToBlock( + pir::Block* insert_block, pir::Type type) { + PADDLE_ENFORCE_EQ( + type.isa(), + true, + platform::errors::InvalidArgument( + "only support insert FullOp for DenseTensorType, but now is %s", + type)); + pir::Builder builder(ctx_, insert_block, insert_block->begin()); + auto tensor_type = type.dyn_cast(); + std::vector shape = common::vectorize(tensor_type.dims()); + paddle::dialect::FullOp full_op = builder.Build( + shape, + 0, + paddle::dialect::TransToPhiDataType(tensor_type.dtype()), + phi::CPUPlace()); +} + void ProgramTranslator::SetIsPersisableAttributeForAllValue( const BlockDesc& block) { // Currently we set is persisable for operation that generated a value diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.h b/paddle/fluid/ir_adaptor/translator/program_translator.h index a0e01bea1caf0c..178892b84173e2 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.h +++ b/paddle/fluid/ir_adaptor/translator/program_translator.h @@ -165,6 +165,9 @@ class ProgramTranslator { const VariableDefiningInfo& CreateUndefinedVariable( const std::string& var_name, const BlockDesc& block); + const VariableDefiningInfo& InsertFullOpToBlock(pir::Block* insert_block, + pir::Type type); + void TranslateIfOperation(const OpDesc* op, TranslationContext* translation_ctx, pir::Block* dst_block); diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc index f9c740e7c120d2..6527a509821731 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc @@ -165,33 +165,45 @@ void IfOp::VerifySig() { void IfOp::VerifyRegion() { VLOG(4) << "Start Verifying sub regions for: IfOp."; + (*this)->Print(std::cout); + std::cout << std::endl; + VLOG(4) << "Start Verifying true branch."; PADDLE_ENFORCE_EQ( (*this)->region(0).size(), 1u, phi::errors::PreconditionNotMet("The size %d of 
true_region must be 1.",
 (*this)->region(0).size()));
+ VLOG(0) << "Start Verifying true branch 1.";
 auto &true_last_op = (*this)->region(0).front().back();
+ VLOG(0) << "Start Verifying true branch 2.";
 PADDLE_ENFORCE_EQ(true,
 true_last_op.isa(),
 phi::errors::PreconditionNotMet(
 "The last of true block must be YieldOp"));
+ VLOG(0) << "Start Verifying true branch 3.";
 PADDLE_ENFORCE_EQ(true_last_op.num_operands(),
 (*this)->num_results(),
 phi::errors::PreconditionNotMet(
 "The size of last of true block op's input must be "
 "equal to IfOp's outputs num."));
-
+ VLOG(0) << "Start Verifying true branch 4.";
 if ((*this)->region(1).size() != 0) {
+ VLOG(0) << "size block: " << (*this)->region(1).size();
+ VLOG(0) << "size op: " << (*this)->region(1).front().size();
+
 VLOG(4) << "Start Verifying false branch.";
 PADDLE_ENFORCE_EQ((*this)->region(1).size(),
 1u,
 phi::errors::PreconditionNotMet(
 "The size %d of false_region must be 1.",
 (*this)->region(0).size()));
+ VLOG(0) << "Start Verifying true branch 5.";
 auto &false_last_op = (*this)->region(1).front().back();
+ VLOG(0) << "Start Verifying true branch 6.";
 PADDLE_ENFORCE_EQ(true,
 false_last_op.isa(),
 phi::errors::PreconditionNotMet(
 "The last of false block must be YieldOp"));
+ VLOG(0) << "Start Verifying true branch 7.";
 PADDLE_ENFORCE_EQ(false_last_op.num_operands(),
 (*this)->num_results(),
 phi::errors::PreconditionNotMet(
 "The size of last of false block op's input must be "
 "equal to IfOp's outputs num."));
 }
 }
diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc
index 81532960af378b..7e8b1455a62825 100644
--- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc
+++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc
@@ -2198,17 +2198,67 @@ void SelectInputOp::VerifySig() {
 {
 auto in_size = num_operands();
 IR_ENFORCE(in_size == 3u, "Size %d of inputs must be >= 3.", in_size);
- IR_ENFORCE((*this)
- ->operand_source(0)
- .type()
- .isa(),
- "Type validation failed for the 0th input, but got %s.",
- (*this)->operand_source(0).type());
- IR_ENFORCE(
- (*this)->operand_source(1).type() == (*this)->operand_source(2).type(),
- "The 1st input type %s should be equal to 2ed input type %s.",
- (*this)->operand_source(1).type(),
- (*this)->operand_source(2).type());
+ auto input1 = (*this)->operand_source(1).type();
+ auto input2 = (*this)->operand_source(2).type();
+ if (input1.isa() &&
+ input2.isa()) {
+ auto tensor1 = input1.dyn_cast();
+ auto tensor2 = input2.dyn_cast();
+ IR_ENFORCE(
+ tensor1.dtype() == tensor2.dtype(),
+ "The 1st input dtype %s should be equal to 2nd input dtype %s.",
+ tensor1.dtype(),
+ tensor2.dtype());
+ IR_ENFORCE(tensor1.data_layout() == tensor2.data_layout(),
+ "The 1st input data_layout %s should be equal to 2nd input "
+ "data_layout %s.",
+ tensor1.data_layout(),
+ tensor2.data_layout());
+ IR_ENFORCE(tensor1.lod() == tensor2.lod(),
+ "The 1st input lod %s should be equal to 2nd input lod %s.",
+ tensor1.lod(),
+ tensor2.lod());
+ IR_ENFORCE(
+ tensor1.offset() == tensor2.offset(),
+ "The 1st input offset %s should be equal to 2nd input offset %s.",
+ tensor1.offset(),
+ tensor2.offset());
+ } else if (input1.isa() &&
+ input2.isa()) {
+ auto tensor1 =
+ input1.dyn_cast();
+ auto tensor2 =
+ input2.dyn_cast();
+ IR_ENFORCE(
+ tensor1.dtype() == tensor2.dtype(),
+ "The 1st input dtype %s should be equal to 2nd input dtype %s.",
+ tensor1.dtype(),
+ tensor2.dtype());
+ IR_ENFORCE(tensor1.data_layout() == tensor2.data_layout(),
+ "The 1st input data_layout %s should be equal to 2nd input "
+ "data_layout %s.",
+ tensor1.data_layout(),
tensor2.data_layout());
+ IR_ENFORCE(tensor1.lod() == tensor2.lod(),
+ "The 1st input lod %s should be equal to 2nd input lod %s.",
+ tensor1.lod(),
+ tensor2.lod());
+ IR_ENFORCE(
+ tensor1.offset() == tensor2.offset(),
+ "The 1st input offset %s should be equal to 2nd input offset %s.",
+ tensor1.offset(),
+ tensor2.offset());
+ IR_ENFORCE(
+ tensor1.place() == tensor2.place(),
+ "The 1st input place %s should be equal to 2nd input place %s.",
+ tensor1.place(),
+ tensor2.place());
+ } else {
+ IR_ENFORCE(input1 == input2,
+ "The 1st input type %s should be equal to 2nd input type %s.",
+ input1,
+ input2);
+ }
 }
 VLOG(4) << "Verifying outputs:";
 {
diff --git a/paddle/fluid/pir/transforms/inplace_pass.cc b/paddle/fluid/pir/transforms/inplace_pass.cc
index f4269d5d857eb6..3088f41f240993 100644
--- a/paddle/fluid/pir/transforms/inplace_pass.cc
+++ b/paddle/fluid/pir/transforms/inplace_pass.cc
@@ -250,10 +250,8 @@ static void GetEagerDelValueOfOp(
 auto if_op = op.dyn_cast();
 GetEagerDelValueOfOp(&if_op.true_block(), skip_dels, del_value_2_op);
 VLOG(8) << "GetEagerDelValueOfOp for IfOp true block";
- if (if_op.false_region().size() != 0) {
- GetEagerDelValueOfOp(&if_op.false_block(), skip_dels, del_value_2_op);
- VLOG(8) << "GetEagerDelValueOfOp for IfOp false block";
- }
+ GetEagerDelValueOfOp(&if_op.false_block(), skip_dels, del_value_2_op);
+ VLOG(8) << "GetEagerDelValueOfOp for IfOp false block";
 }
 }
 }
diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
index 05f073b9b31e41..04c4d68933140d 100644
--- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
+++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
@@ -917,15 +917,13 @@ void HandleForIfOp(
 map_value_pair);
 // process false block
- if (old_ifop.false_region().size() != 0) {
- auto& false_block = new_ifop.false_block();
- ProcessBlock(place,
- &old_ifop.false_block(),
- &false_block,
- ctx,
- map_op_pair,
- map_value_pair);
- }
+ auto& false_block = new_ifop.false_block();
+ ProcessBlock(place,
+ &old_ifop.false_block(),
+ &false_block,
+ ctx,
+ map_op_pair,
+ map_value_pair);
 // update map
 (*map_op_pair)[op_item] = new_ifop;
From 6e1127fef5324964b3b846dd59123ddb9561e9fa Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Thu, 7 Dec 2023 12:01:07 +0000
Subject: [PATCH 12/19] fix
---
 .../dialect/operator/ir/control_flow_op.cc | 44 +++++++++----------
 1 file changed, 21 insertions(+), 23 deletions(-)
diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
index 6527a509821731..cbe4654c707d49 100644
--- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
+++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
@@ -187,29 +187,27 @@ void IfOp::VerifyRegion() {
 "The size of last of true block op's input must be "
 "equal to IfOp's outputs num."));
 VLOG(0) << "Start Verifying true branch 4.";
- if ((*this)->region(1).size() != 0) {
- VLOG(0) << "size block: " << (*this)->region(1).size();
- VLOG(0) << "size op: " << (*this)->region(1).front().size();
- VLOG(4) << "Start Verifying false branch.";
- PADDLE_ENFORCE_EQ((*this)->region(1).size(),
- 1u,
- phi::errors::PreconditionNotMet(
- "The size %d of false_region must be 1.",
- (*this)->region(0).size()));
- VLOG(0) << "Start Verifying true branch 5.";
- auto &false_last_op = (*this)->region(1).front().back();
- VLOG(0) << "Start Verifying true branch 6.";
- PADDLE_ENFORCE_EQ(true,
- false_last_op.isa(),
- phi::errors::PreconditionNotMet(
"The last of false block must be YieldOp")); - VLOG(0) << "Start Verifying true branch 7."; - PADDLE_ENFORCE_EQ(false_last_op.num_operands(), - (*this)->num_results(), - phi::errors::PreconditionNotMet( - "The size of last of false block op's input must be " - "equal to IfOp's outputs num.")); - } + VLOG(0) << "size block: " << (*this)->region(1).size(); + VLOG(0) << "size op: " << (*this)->region(1).front().size(); + VLOG(4) << "Start Verifying false branch."; + PADDLE_ENFORCE_EQ( + (*this)->region(1).size(), + 1u, + phi::errors::PreconditionNotMet("The size %d of false_region must be 1.", + (*this)->region(0).size())); + VLOG(0) << "Start Verifying true branch 5."; + auto &false_last_op = (*this)->region(1).front().back(); + VLOG(0) << "Start Verifying true branch 6."; + PADDLE_ENFORCE_EQ(true, + false_last_op.isa(), + phi::errors::PreconditionNotMet( + "The last of false block must be YieldOp")); + VLOG(0) << "Start Verifying true branch 7."; + PADDLE_ENFORCE_EQ(false_last_op.num_operands(), + (*this)->num_results(), + phi::errors::PreconditionNotMet( + "The size of last of false block op's input must be " + "equal to IfOp's outputs num.")); } std::vector> IfOp::Vjp( From 48cb76611b2f75d45fbcf8da0405e9268b713ddd Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Thu, 7 Dec 2023 12:39:28 +0000 Subject: [PATCH 13/19] fix --- .../translator/program_translator.cc | 53 +++++++++++-------- .../translator/program_translator.h | 4 +- .../pir/dialect/operator/ir/manual_op.cc | 1 + 3 files changed, 33 insertions(+), 25 deletions(-) diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.cc b/paddle/fluid/ir_adaptor/translator/program_translator.cc index e3e5c4196e31fd..7555787aa47752 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/program_translator.cc @@ -424,6 +424,27 @@ void ProgramTranslator::TranslateBlock(const BlockDesc& src_block, } } +pir::Operation* ProgramTranslator::InsertFullOrDataOpToBlock( + pir::Block* insert_block, pir::Type type) { + pir::Builder builder(ctx_, insert_block, insert_block->begin()); + if (type.isa()) { + auto tensor_type = type.dyn_cast(); + std::vector shape = common::vectorize(tensor_type.dims()); + paddle::dialect::FullOp full_op = builder.Build( + shape, + 0, + paddle::dialect::TransToPhiDataType(tensor_type.dtype()), + phi::CPUPlace()); + return full_op.operation(); + } else if (type.isa()) { + auto array_type = type.dyn_cast(); + paddle::dialect::CreateArrayOp array_op = + builder.Build(array_type.dtype()); + return array_op.operation(); + } + return nullptr; +} + // NOTE(zhangbo): All condition_block_op will be translated as an if_op with // only a true branch. 
void ProgramTranslator::TranslateIfOperation( @@ -494,11 +515,15 @@ void ProgramTranslator::TranslateIfOperation( for (size_t id = 0; id < cond_op_outputs.size(); id++) { if (false_block_context->count(cond_op_outputs[id]) == 0) { auto true_type = true_yeild_inputs[id].type(); - if (true_type.isa()) { - InsertFullOpToBlock(&false_region.front(), true_type); - } else { - CreateUndefinedVariable(cond_op_outputs[id], sub_block); - } + pir::Operation* init_op = + InsertFullOrDataOpToBlock(&false_region.front(), true_type); + PADDLE_ENFORCE_NOT_NULL( + init_op, + phi::errors::PreconditionNotMet( + "Only support insert full or data op for DenseTensor or " + "DenseTensorArray to false block failed.")); + false_block_context->PushValue( + cond_op_outputs[id], VariableDefiningInfo(init_op->result(i))); } false_yeild_inputs.push_back( false_block_context->at(cond_op_outputs[id]).value); @@ -799,24 +824,6 @@ const VariableDefiningInfo& ProgramTranslator::CreateUndefinedVariable( return param_map_.at(var_name); } -const VariableDefiningInfo& ProgramTranslator::InsertFullOpToBlock( - pir::Block* insert_block, pir::Type type) { - PADDLE_ENFORCE_EQ( - type.isa(), - true, - platform::errors::InvalidArgument( - "only support insert FullOp for DenseTensorType, but now is %s", - type)); - pir::Builder builder(ctx_, insert_block, insert_block->begin()); - auto tensor_type = type.dyn_cast(); - std::vector shape = common::vectorize(tensor_type.dims()); - paddle::dialect::FullOp full_op = builder.Build( - shape, - 0, - paddle::dialect::TransToPhiDataType(tensor_type.dtype()), - phi::CPUPlace()); -} - void ProgramTranslator::SetIsPersisableAttributeForAllValue( const BlockDesc& block) { // Currently we set is persisable for operation that generated a value diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.h b/paddle/fluid/ir_adaptor/translator/program_translator.h index 178892b84173e2..052a8fa13cea41 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.h +++ b/paddle/fluid/ir_adaptor/translator/program_translator.h @@ -165,8 +165,8 @@ class ProgramTranslator { const VariableDefiningInfo& CreateUndefinedVariable( const std::string& var_name, const BlockDesc& block); - const VariableDefiningInfo& InsertFullOpToBlock(pir::Block* insert_block, - pir::Type type); + pir::Operation* InsertFullOrDataOpToBlock(pir::Block* insert_block, + pir::Type type); void TranslateIfOperation(const OpDesc* op, TranslationContext* translation_ctx, diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc index 7e8b1455a62825..2160e56442d465 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc @@ -24,6 +24,7 @@ paddle::dialect::AddNOp, paddle::dialect::AddN_Op, #else #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h" #include "paddle/fluid/pir/dialect/operator/ir/ir_meta_tensor.h" #include "paddle/fluid/pir/dialect/operator/ir/ir_selected_rows.h" #include "paddle/fluid/pir/dialect/operator/ir/ir_tensor.h" From 5a6c28f1772fc4d2cabaef16a3c59d043911ef34 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Thu, 7 Dec 2023 13:17:36 +0000 Subject: [PATCH 14/19] fix --- paddle/fluid/ir_adaptor/translator/program_translator.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.cc b/paddle/fluid/ir_adaptor/translator/program_translator.cc index 
7555787aa47752..ba296feca0344d 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/program_translator.cc @@ -435,11 +435,14 @@ pir::Operation* ProgramTranslator::InsertFullOrDataOpToBlock( 0, paddle::dialect::TransToPhiDataType(tensor_type.dtype()), phi::CPUPlace()); + full_op.out().set_type(type); return full_op.operation(); } else if (type.isa()) { auto array_type = type.dyn_cast(); paddle::dialect::CreateArrayOp array_op = - builder.Build(array_type.dtype()); + builder.Build( + paddle::dialect::TransToPhiDataType(array_type.dtype())); + array_op.out().set_type(type); return array_op.operation(); } return nullptr; @@ -523,7 +526,7 @@ void ProgramTranslator::TranslateIfOperation( "Only support insert full or data op for DenseTensor or " "DenseTensorArray to false block failed.")); false_block_context->PushValue( - cond_op_outputs[id], VariableDefiningInfo(init_op->result(i))); + cond_op_outputs[id], VariableDefiningInfo(init_op->result(0))); } false_yeild_inputs.push_back( false_block_context->at(cond_op_outputs[id]).value); From a3349dcbd32cdf47513eac35aaf80afd006b44e9 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Thu, 7 Dec 2023 14:10:03 +0000 Subject: [PATCH 15/19] fix --- paddle/fluid/ir_adaptor/translator/op_translator.cc | 8 ++++++-- paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc | 9 --------- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index eefe7ebe79db65..ddce2e0c5e39d3 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -1928,7 +1928,9 @@ struct SelectInputOpTranscriber : public OpTranscriber { tensor1.offset() != tensor2.offset()) { IR_THROW( "select_input only support same type or DenseTensorType with " - "only different dim."); + "only different dim, but get %s != %s.", + tensor1, + tensor2); } auto dim1 = input1.dyn_cast().dims(); auto dim2 = input2.dyn_cast().dims(); @@ -1944,7 +1946,9 @@ struct SelectInputOpTranscriber : public OpTranscriber { } IR_THROW( "select_input only support same type or DenseTensorType with only " - "different dim."); + "different dim, now is %s != %s.", + input1, + input2); } pir::Operation* operation = pir::Operation::Create( diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc index cbe4654c707d49..f19e55ed1519e1 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc @@ -173,36 +173,27 @@ void IfOp::VerifyRegion() { 1u, phi::errors::PreconditionNotMet("The size %d of true_region must be 1.", (*this)->region(0).size())); - VLOG(0) << "Start Verifying true branch 1."; auto &true_last_op = (*this)->region(0).front().back(); - VLOG(0) << "Start Verifying true branch 2."; PADDLE_ENFORCE_EQ(true, true_last_op.isa(), phi::errors::PreconditionNotMet( "The last of true block must be YieldOp")); - VLOG(0) << "Start Verifying true branch 3."; PADDLE_ENFORCE_EQ(true_last_op.num_operands(), (*this)->num_results(), phi::errors::PreconditionNotMet( "The size of last of true block op's input must be " "equal to IfOp's outputs num.")); - VLOG(0) << "Start Verifying true branch 4."; - VLOG(0) << "size block: " << (*this)->region(1).size(); - VLOG(0) << "size op: " << (*this)->region(1).front().size(); VLOG(4) << "Start Verifying false branch."; 
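// Hedged aside on the log levels in this cleanup: with glog, VLOG(n) fires
// only when the runtime verbosity is at least n, and the default is 0, so
// the VLOG(0) probes removed here printed on every run, while the VLOG(4)
// traces that remain stay silent unless verbosity is raised explicitly,
// e.g. (binary name illustrative):
//
//   GLOG_v=4 ./program_translator_test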
PADDLE_ENFORCE_EQ( (*this)->region(1).size(), 1u, phi::errors::PreconditionNotMet("The size %d of false_region must be 1.", (*this)->region(0).size())); - VLOG(0) << "Start Verifying true branch 5."; auto &false_last_op = (*this)->region(1).front().back(); - VLOG(0) << "Start Verifying true branch 6."; PADDLE_ENFORCE_EQ(true, false_last_op.isa(), phi::errors::PreconditionNotMet( "The last of false block must be YieldOp")); - VLOG(0) << "Start Verifying true branch 7."; PADDLE_ENFORCE_EQ(false_last_op.num_operands(), (*this)->num_results(), phi::errors::PreconditionNotMet( From b172157ff0462e4d9efaedd6c325ce8987e26d81 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Fri, 8 Dec 2023 03:37:41 +0000 Subject: [PATCH 16/19] fix --- .../ir_adaptor/translator/op_translator.cc | 52 ++++--- test/dygraph_to_static/test_return.py | 144 +++++++++++++++++- test/dygraph_to_static/test_warning.py | 8 +- test/legacy_test/test_cond.py | 6 +- 4 files changed, 180 insertions(+), 30 deletions(-) diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index ddce2e0c5e39d3..4736a3af20a3d4 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -1917,33 +1917,39 @@ struct SelectInputOpTranscriber : public OpTranscriber { auto input2 = op_inputs[2].type(); if (input1 == input2) { op_output_types.push_back(op_inputs[1].type()); - } else { - if (input1.isa() && - input2.isa()) { - auto tensor1 = input1.dyn_cast(); - auto tensor2 = input2.dyn_cast(); - if (tensor1.dtype() != tensor2.dtype() || - tensor1.data_layout() != tensor2.data_layout() || - tensor1.lod() != tensor2.lod() || - tensor1.offset() != tensor2.offset()) { - IR_THROW( - "select_input only support same type or DenseTensorType with " - "only different dim, but get %s != %s.", - tensor1, - tensor2); - } - auto dim1 = input1.dyn_cast().dims(); - auto dim2 = input2.dyn_cast().dims(); - std::vector compat_shape = ParseCompatibleShapes( - common::vectorize(dim1), common::vectorize(dim2)); - op_output_types.push_back(paddle::dialect::DenseTensorType::get( - ctx, + } else if (input1.isa() && + input2.isa()) { + auto tensor1 = input1.dyn_cast(); + auto tensor2 = input2.dyn_cast(); + if (tensor1.dtype() != tensor2.dtype() || + tensor1.data_layout() != tensor2.data_layout() || + tensor1.lod() != tensor2.lod() || + tensor1.offset() != tensor2.offset()) { + IR_THROW( + "select_input only support same type or DenseTensorType with " + "only different dim, but get dtype:[%s, %s], layout:[%s, %s], " + "lod:[%s, %s], offset:[%s, %s].", tensor1.dtype(), - common::make_ddim(compat_shape), + tensor2.dtype(), tensor1.data_layout(), + tensor2.data_layout(), tensor1.lod(), - tensor1.offset())); + tensor2.lod(), + tensor1.offset(), + tensor2.offset()); } + auto dim1 = input1.dyn_cast().dims(); + auto dim2 = input2.dyn_cast().dims(); + std::vector compat_shape = ParseCompatibleShapes( + common::vectorize(dim1), common::vectorize(dim2)); + op_output_types.push_back( + paddle::dialect::DenseTensorType::get(ctx, + tensor1.dtype(), + common::make_ddim(compat_shape), + tensor1.data_layout(), + tensor1.lod(), + tensor1.offset())); + } else { IR_THROW( "select_input only support same type or DenseTensorType with only " "different dim, now is %s != %s.", diff --git a/test/dygraph_to_static/test_return.py b/test/dygraph_to_static/test_return.py index ceab96855c3d43..5d7e02ac8181e9 100644 --- a/test/dygraph_to_static/test_return.py +++ 
b/test/dygraph_to_static/test_return.py @@ -15,7 +15,11 @@ import unittest import numpy as np -from dygraph_to_static_utils import Dy2StTestBase, test_ast_only +from dygraph_to_static_utils import ( + Dy2StTestBase, + test_ast_only, + test_legacy_only, +) from ifelse_simple_func import dyfunc_with_if_else import paddle @@ -316,10 +320,54 @@ def init_dygraph_func(self): self.dygraph_func = test_inside_func_base -class TestReturnIf(TestReturnBase): +class TestReturnIf(Dy2StTestBase): + def setUp(self): + self.input = np.ones(1).astype('int32') + self.place = ( + base.CUDAPlace(0) + if base.is_compiled_with_cuda() + else base.CPUPlace() + ) + self.init_dygraph_func() + def init_dygraph_func(self): self.dygraph_func = test_return_if + def _run(self, to_static=False): + paddle.jit.enable_to_static(to_static) + with base.dygraph.guard(): + res = self.dygraph_func(self.input) + if isinstance(res, (tuple, list)): + return tuple(r.numpy() for r in res) + elif isinstance(res, core.eager.Tensor): + return res.numpy() + return res + + def _test_value_impl(self): + dygraph_res = self._run(to_static=False) + static_res = self._run(to_static=True) + if isinstance(dygraph_res, tuple): + self.assertTrue(isinstance(static_res, tuple)) + self.assertEqual(len(dygraph_res), len(static_res)) + for i in range(len(dygraph_res)): + np.testing.assert_allclose( + dygraph_res[i], static_res[i], rtol=1e-05 + ) + elif isinstance(dygraph_res, np.ndarray): + np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05) + else: + self.assertEqual(dygraph_res, static_res) + + # Why add test_legacy_only? : PIR not support if true and false branch output with different dtype + @test_legacy_only + @test_ast_only + def test_transformed_static_result(self): + if hasattr(self, "error"): + with self.assertRaisesRegex(Dygraph2StaticException, self.error): + self._test_value_impl() + else: + self._test_value_impl() + class TestReturnOnlyIf(TestReturnBase): def init_dygraph_func(self): @@ -331,20 +379,108 @@ def init_dygraph_func(self): self.dygraph_func = test_return_in_for -class TestReturnInWhile(TestReturnBase): +class TestReturnInWhile(Dy2StTestBase): + def setUp(self): + self.input = np.ones(1).astype('int32') + self.place = ( + base.CUDAPlace(0) + if base.is_compiled_with_cuda() + else base.CPUPlace() + ) + self.init_dygraph_func() + def init_dygraph_func(self): self.dygraph_func = test_return_in_while + def _run(self, to_static=False): + paddle.jit.enable_to_static(to_static) + with base.dygraph.guard(): + res = self.dygraph_func(self.input) + if isinstance(res, (tuple, list)): + return tuple(r.numpy() for r in res) + elif isinstance(res, core.eager.Tensor): + return res.numpy() + return res + + def _test_value_impl(self): + dygraph_res = self._run(to_static=False) + static_res = self._run(to_static=True) + if isinstance(dygraph_res, tuple): + self.assertTrue(isinstance(static_res, tuple)) + self.assertEqual(len(dygraph_res), len(static_res)) + for i in range(len(dygraph_res)): + np.testing.assert_allclose( + dygraph_res[i], static_res[i], rtol=1e-05 + ) + elif isinstance(dygraph_res, np.ndarray): + np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05) + else: + self.assertEqual(dygraph_res, static_res) + + # Why add test_legacy_only? 
: PIR not support if true and false branch output with different dtype + @test_legacy_only + @test_ast_only + def test_transformed_static_result(self): + if hasattr(self, "error"): + with self.assertRaisesRegex(Dygraph2StaticException, self.error): + self._test_value_impl() + else: + self._test_value_impl() + class TestReturnIfDiff(TestReturnBase): def init_dygraph_func(self): self.dygraph_func = test_diff_return -class TestReturnIfElse(TestReturnBase): +class TestReturnIfElse(Dy2StTestBase): + def setUp(self): + self.input = np.ones(1).astype('int32') + self.place = ( + base.CUDAPlace(0) + if base.is_compiled_with_cuda() + else base.CPUPlace() + ) + self.init_dygraph_func() + def init_dygraph_func(self): self.dygraph_func = test_return_if_else + def _run(self, to_static=False): + paddle.jit.enable_to_static(to_static) + with base.dygraph.guard(): + res = self.dygraph_func(self.input) + if isinstance(res, (tuple, list)): + return tuple(r.numpy() for r in res) + elif isinstance(res, core.eager.Tensor): + return res.numpy() + return res + + def _test_value_impl(self): + dygraph_res = self._run(to_static=False) + static_res = self._run(to_static=True) + if isinstance(dygraph_res, tuple): + self.assertTrue(isinstance(static_res, tuple)) + self.assertEqual(len(dygraph_res), len(static_res)) + for i in range(len(dygraph_res)): + np.testing.assert_allclose( + dygraph_res[i], static_res[i], rtol=1e-05 + ) + elif isinstance(dygraph_res, np.ndarray): + np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05) + else: + self.assertEqual(dygraph_res, static_res) + + # Why add test_legacy_only? : PIR not support if true and false branch output with different dtype + @test_legacy_only + @test_ast_only + def test_transformed_static_result(self): + if hasattr(self, "error"): + with self.assertRaisesRegex(Dygraph2StaticException, self.error): + self._test_value_impl() + else: + self._test_value_impl() + class TestReturnInWhile2(TestReturnBase): def init_dygraph_func(self): diff --git a/test/dygraph_to_static/test_warning.py b/test/dygraph_to_static/test_warning.py index 9eac0f6a8902bb..e1b9a02b2851dd 100644 --- a/test/dygraph_to_static/test_warning.py +++ b/test/dygraph_to_static/test_warning.py @@ -15,7 +15,11 @@ import unittest import warnings -from dygraph_to_static_utils import Dy2StTestBase, test_ast_only +from dygraph_to_static_utils import ( + Dy2StTestBase, + test_ast_only, + test_legacy_only, +) import paddle from paddle.static.nn import cond @@ -39,6 +43,8 @@ def false_fn(): class TestReturnNoneInIfelse(Dy2StTestBase): + # Why add test_legacy_only? 
: PIR not support if true and false branch output with different dtype + @test_legacy_only @test_ast_only def test_dy2static_warning(self): paddle.disable_static() diff --git a/test/legacy_test/test_cond.py b/test/legacy_test/test_cond.py index 9eb3b575c34084..faf9d43c0a42be 100644 --- a/test/legacy_test/test_cond.py +++ b/test/legacy_test/test_cond.py @@ -236,8 +236,10 @@ def true_func(): def false_func(): return paddle.tensor.fill_constant( - shape=[3, 4], dtype='float32', value=3 - ), paddle.tensor.fill_constant(shape=[4, 5], dtype='int64', value=2) + shape=[3, 4], dtype='int32', value=3 + ), paddle.tensor.fill_constant( + shape=[4, 5], dtype='bool', value=False + ) main_program = Program() startup_program = Program() From 59cedeb4848bb27f6bf6acee892b8b9adb507231 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Fri, 8 Dec 2023 03:56:57 +0000 Subject: [PATCH 17/19] fix --- test/dygraph_to_static/test_ifelse.py | 57 ++++++++++++++++++++++++--- 1 file changed, 51 insertions(+), 6 deletions(-) diff --git a/test/dygraph_to_static/test_ifelse.py b/test/dygraph_to_static/test_ifelse.py index 88b775af83a1aa..95f35f9c22e6ed 100644 --- a/test/dygraph_to_static/test_ifelse.py +++ b/test/dygraph_to_static/test_ifelse.py @@ -22,6 +22,7 @@ disable_test_case, enable_to_static_guard, test_ast_only, + test_legacy_only, ) from ifelse_simple_func import ( NetWithControlFlowIf, @@ -109,11 +110,28 @@ def test_ast_to_func(self): self.assertTrue((self._run_dygraph() == self._run_static()).all()) -class TestDygraphIfElse3(TestDygraphIfElse): +class TestDygraphIfElse3(Dy2StTestBase): def setUp(self): self.x = np.random.random([10, 16]).astype('float32') self.dyfunc = dyfunc_with_if_else3 + def _run_static(self): + return self._run_dygraph(to_static=True) + + def _run_dygraph(self, to_static=False): + with base.dygraph.guard(place): + x_v = base.dygraph.to_variable(self.x) + if to_static: + ret = paddle.jit.to_static(self.dyfunc)(x_v) + else: + ret = self.dyfunc(x_v) + return ret.numpy() + + # Why add test_legacy_only? : PIR not support if true and false branch output with different rank + @test_legacy_only + def test_ast_to_func(self): + self.assertTrue((self._run_dygraph() == self._run_static()).all()) + class TestDygraphIfElse4(TestDygraphIfElse): def setUp(self): @@ -144,6 +162,8 @@ def _run_dygraph(self, to_static=False): ret = self.dyfunc(x_v) return ret.numpy() + # Why add test_legacy_only? : PIR not support if true and false branch output with different rank + @test_legacy_only def test_ast_to_func(self): self.assertTrue((self._run_dygraph() == self._run_static()).all()) @@ -154,11 +174,28 @@ def setUp(self): self.dyfunc = nested_if_else_2 -class TestDygraphNestedIfElse3(TestDygraphIfElse): +class TestDygraphNestedIfElse3(Dy2StTestBase): def setUp(self): self.x = np.random.random([10, 16]).astype('float32') self.dyfunc = nested_if_else_3 + def _run_static(self): + return self._run_dygraph(to_static=True) + + def _run_dygraph(self, to_static=False): + with base.dygraph.guard(place): + x_v = paddle.to_tensor(self.x) + if to_static: + ret = paddle.jit.to_static(self.dyfunc)(x_v) + else: + ret = self.dyfunc(x_v) + return ret.numpy() + + # Why add test_legacy_only? 
: PIR not support if true and false branch output with different rank + @test_legacy_only + def test_ast_to_func(self): + self.assertTrue((self._run_dygraph() == self._run_static()).all()) + def dyfunc_ifExp_with_while(x): y = [x] @@ -186,10 +223,10 @@ def body(i, ten, y): return y[0] -class TestDygraphIfElse6(TestDygraphIfElse): - def setUp(self): - self.x = np.random.random([10, 16]).astype('float32') - self.dyfunc = dyfunc_ifExp_with_while +# class TestDygraphIfElse6(TestDygraphIfElse): +# def setUp(self): +# self.x = np.random.random([10, 16]).astype('float32') +# self.dyfunc = dyfunc_ifExp_with_while def dyfunc_ifExp(x): @@ -269,6 +306,8 @@ def _run_dygraph(self, to_static=False): ret = self.dyfunc(x_v) return ret.numpy() + # Why add test_legacy_only? : PIR not support if true and false branch output with different dtype + @test_legacy_only def test_ast_to_func(self): self.assertTrue((self._run_dygraph() == self._run_static()).all()) @@ -297,6 +336,8 @@ def _run(self, to_static=False): ret = net(x_v) return ret.numpy() + # Why add test_legacy_only? : PIR not support if true and false branch output with different rank + @test_legacy_only def test_ast_to_func(self): self.assertTrue((self._run_dygraph() == self._run_static()).all()) @@ -461,6 +502,8 @@ def get_dy2stat_out(self): out = static_func(self.x) return out + # Why add test_legacy_only? : PIR not support if true and false branch output with different rank + @test_legacy_only @test_ast_only def test_ast_to_func(self): self.setUp() @@ -481,6 +524,8 @@ def setUp(self): self.dyfunc = paddle.jit.to_static(dyfunc_ifelse_ret_int3) self.out = self.get_dy2stat_out() + # Why add test_legacy_only? : PIR not support if true and false branch output with different rank + @test_legacy_only @test_ast_only def test_ast_to_func(self): self.setUp() From 1b374725a4bf3d0cead27aa9997ad1760787c4e1 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Fri, 8 Dec 2023 06:14:23 +0000 Subject: [PATCH 18/19] fix --- paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc | 2 -- test/cpp/pir/core/program_translator_test.cc | 3 --- 2 files changed, 5 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc index f19e55ed1519e1..7172324cf788ac 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc @@ -165,8 +165,6 @@ void IfOp::VerifySig() { void IfOp::VerifyRegion() { VLOG(4) << "Start Verifying sub regions for: IfOp."; - (*this)->Print(std::cout); - std::cout << std::endl; VLOG(4) << "Start Verifying true branch."; PADDLE_ENFORCE_EQ( (*this)->region(0).size(), diff --git a/test/cpp/pir/core/program_translator_test.cc b/test/cpp/pir/core/program_translator_test.cc index 68baecdc7f73ba..010ed757d2ab8b 100644 --- a/test/cpp/pir/core/program_translator_test.cc +++ b/test/cpp/pir/core/program_translator_test.cc @@ -85,9 +85,6 @@ TEST(OperatorDialectTest, ConditionBlock) { ctx->GetOrRegisterDialect(); auto program = paddle::TranslateLegacyProgramToProgram(p); - program->Print(std::cout); - std::cout << std::endl; - EXPECT_EQ(program->block()->size(), 9u); size_t id = 0; for (auto &op : *program->block()) { From 66dd4d3d7156365a5b4bb3e11c46fccbcabb0bdb Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Fri, 8 Dec 2023 08:54:27 +0000 Subject: [PATCH 19/19] fix --- .../dialect/operator/ir/control_flow_op.cc | 44 ++++++++++--------- .../test_program_translator.py | 3 ++ 2 files changed, 27 insertions(+), 20 
deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc index 7172324cf788ac..dbb7c7c248dd48 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc @@ -171,32 +171,36 @@ void IfOp::VerifyRegion() { 1u, phi::errors::PreconditionNotMet("The size %d of true_region must be 1.", (*this)->region(0).size())); - auto &true_last_op = (*this)->region(0).front().back(); - PADDLE_ENFORCE_EQ(true, - true_last_op.isa(), - phi::errors::PreconditionNotMet( - "The last of true block must be YieldOp")); - PADDLE_ENFORCE_EQ(true_last_op.num_operands(), - (*this)->num_results(), - phi::errors::PreconditionNotMet( - "The size of last of true block op's input must be " - "equal to IfOp's outputs num.")); + if ((*this)->region(0).front().size() > 0) { + auto &true_last_op = (*this)->region(0).front().back(); + PADDLE_ENFORCE_EQ(true, + true_last_op.isa(), + phi::errors::PreconditionNotMet( + "The last of true block must be YieldOp")); + PADDLE_ENFORCE_EQ(true_last_op.num_operands(), + (*this)->num_results(), + phi::errors::PreconditionNotMet( + "The size of last of true block op's input must be " + "equal to IfOp's outputs num.")); + } VLOG(4) << "Start Verifying false branch."; PADDLE_ENFORCE_EQ( (*this)->region(1).size(), 1u, phi::errors::PreconditionNotMet("The size %d of false_region must be 1.", (*this)->region(0).size())); - auto &false_last_op = (*this)->region(1).front().back(); - PADDLE_ENFORCE_EQ(true, - false_last_op.isa(), - phi::errors::PreconditionNotMet( - "The last of false block must be YieldOp")); - PADDLE_ENFORCE_EQ(false_last_op.num_operands(), - (*this)->num_results(), - phi::errors::PreconditionNotMet( - "The size of last of false block op's input must be " - "equal to IfOp's outputs num.")); + if ((*this)->region(1).front().size() > 0) { + auto &false_last_op = (*this)->region(1).front().back(); + PADDLE_ENFORCE_EQ(true, + false_last_op.isa(), + phi::errors::PreconditionNotMet( + "The last of false block must be YieldOp")); + PADDLE_ENFORCE_EQ(false_last_op.num_operands(), + (*this)->num_results(), + phi::errors::PreconditionNotMet( + "The size of last of false block op's input must be " + "equal to IfOp's outputs num.")); + } } std::vector> IfOp::Vjp( diff --git a/test/dygraph_to_static/test_program_translator.py b/test/dygraph_to_static/test_program_translator.py index 2e373e5a57b6bd..812dd9d040e747 100644 --- a/test/dygraph_to_static/test_program_translator.py +++ b/test/dygraph_to_static/test_program_translator.py @@ -24,6 +24,7 @@ ToStaticMode, disable_test_case, test_ast_only, + test_legacy_only, ) from ifelse_simple_func import ( dyfunc_with_if_else_early_return1, @@ -304,6 +305,8 @@ def test_raise_error(self): class TestIfElseEarlyReturn(Dy2StTestBase): + # Why add test_legacy_only? : PIR not support if true and false branch output with different rank + @test_legacy_only def test_ifelse_early_return1(self): answer = np.zeros([2, 2]) + 1 static_func = paddle.jit.to_static(dyfunc_with_if_else_early_return1)
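The net effect of this final patch on IfOp verification can be restated as a
small standalone predicate. This is a minimal sketch, assuming the pir::Block
empty()/back() accessors and the Operation::isa/num_operands API that the
hunks above already use; the helper name, the free-function form, and the
include paths are illustrative assumptions, not code from the series.

#include "paddle/pir/core/block.h"
#include "paddle/pir/dialect/control_flow/ir/cf_op.h"  // pir::YieldOp

// A branch region's single block now verifies when it is empty (the
// relaxation this patch introduces), or when its final op is a cf.yield
// whose operand count matches the IfOp's result count.
static bool BranchBlockOk(pir::Block &block, size_t num_results) {
  if (block.empty()) return true;  // empty branch: nothing more to check
  auto &last = block.back();
  return last.isa<pir::YieldOp>() && last.num_operands() == num_results;
}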