diff --git a/paddle/fluid/framework/new_executor/instruction/CMakeLists.txt b/paddle/fluid/framework/new_executor/instruction/CMakeLists.txt index bfdd4f2b50db48..46d907d60841b8 100644 --- a/paddle/fluid/framework/new_executor/instruction/CMakeLists.txt +++ b/paddle/fluid/framework/new_executor/instruction/CMakeLists.txt @@ -3,8 +3,9 @@ cc_library( SRCS instruction_base.cc phi_kernel_instruction.cc legacy_kernel_instruction.cc - cond_instruction.cc + if_instruction.cc while_instruction.cc + select_input_instruction.cc has_elements_instruction.cc tuple_push_instruction.cc tuple_pop_instruction.cc diff --git a/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc b/paddle/fluid/framework/new_executor/instruction/if_instruction.cc similarity index 90% rename from paddle/fluid/framework/new_executor/instruction/cond_instruction.cc rename to paddle/fluid/framework/new_executor/instruction/if_instruction.cc index a25d7d2a5a6df4..3ac3a9e4780be3 100644 --- a/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/if_instruction.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/fluid/framework/new_executor/instruction/cond_instruction.h" +#include "paddle/fluid/framework/new_executor/instruction/if_instruction.h" #include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" #include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h" @@ -39,11 +39,11 @@ namespace paddle { namespace framework { -CondInstruction::CondInstruction(size_t id, - const platform::Place& place, - pir::Operation* op, - ValueExecutionInfo* value_exec_info, - const std::set& skip_gc_vars) +IfInstruction::IfInstruction(size_t id, + const platform::Place& place, + pir::Operation* op, + ValueExecutionInfo* value_exec_info, + const std::set& skip_gc_vars) : InstructionBase(id, place) { PADDLE_ENFORCE( op->isa(), @@ -66,12 +66,14 @@ CondInstruction::CondInstruction(size_t id, // OpOperand of IfOp, and the other is external Values used in true_block or // false_block. auto& true_branch_block = if_op.true_block(); - auto& false_branch_block = if_op.false_block(); + std::unordered_map> inputs; GetInputIds(op, *value_exec_info, &inputs); auto true_outside_inputs = GetExternalInputs(&true_branch_block, *value_exec_info, &inputs); - auto false_outside_inputs = + std::vector false_outside_inputs; + auto& false_branch_block = if_op.false_block(); + false_outside_inputs = GetExternalInputs(&false_branch_block, *value_exec_info, &inputs); SetInputs(inputs); @@ -90,8 +92,10 @@ CondInstruction::CondInstruction(size_t id, } } InsertTuplePushContinerToOuts(&true_branch_block, *value_exec_info, &outputs); + InsertTuplePushContinerToOuts( - &false_branch_block, *value_exec_info, &outputs); + &if_op.false_block(), *value_exec_info, &outputs); + SetOutputs(outputs); VLOG(6) << "finish process inputs outputs index"; @@ -126,11 +130,10 @@ CondInstruction::CondInstruction(size_t id, false_branch_inter_ = new PirInterpreter(place, {}, - &false_branch_block, + &if_op.false_block(), false_scope, value_exec_info->NewChild(false_scope), {}); - 
std::set false_skip_gc_names_set; for (auto value : GetYiedOpInputs(&false_branch_block)) { false_branch_outputs_.push_back(false_branch_inter_->GetNameByValue(value)); @@ -146,10 +149,11 @@ CondInstruction::CondInstruction(size_t id, false_skip_gc_names_set.insert(var_name); } false_branch_inter_->SetSkipGcVars(false_skip_gc_names_set); + VLOG(6) << "finish process false branch interpreter"; } -CondInstruction::~CondInstruction() { +IfInstruction::~IfInstruction() { if (true_branch_inter_ != nullptr) { delete true_branch_inter_; } @@ -158,8 +162,8 @@ CondInstruction::~CondInstruction() { } } -void CondInstruction::CopyBranchOutput( - const std::vector& var_names, const PirInterpreter* inter) { +void IfInstruction::CopyBranchOutput(const std::vector& var_names, + const PirInterpreter* inter) { for (size_t i = 0; i < var_names.size(); ++i) { auto* inner_var = inter->InnerScope()->GetVar(var_names[i]); @@ -179,7 +183,7 @@ void CondInstruction::CopyBranchOutput( } } -void CondInstruction::Run() { +void IfInstruction::Run() { DeviceContext().Wait(); if (cond_var_->Get().data()[0]) { true_branch_inter_->Run({}, false); @@ -188,7 +192,6 @@ void CondInstruction::Run() { false_branch_inter_->Run({}, false); CopyBranchOutput(false_branch_outputs_, false_branch_inter_); } - // copy ouptut } diff --git a/paddle/fluid/framework/new_executor/instruction/cond_instruction.h b/paddle/fluid/framework/new_executor/instruction/if_instruction.h similarity index 80% rename from paddle/fluid/framework/new_executor/instruction/cond_instruction.h rename to paddle/fluid/framework/new_executor/instruction/if_instruction.h index 45f39ba338814f..e6d1fc4723c5d6 100644 --- a/paddle/fluid/framework/new_executor/instruction/cond_instruction.h +++ b/paddle/fluid/framework/new_executor/instruction/if_instruction.h @@ -27,15 +27,15 @@ class Value; class PirInterpreter; class ValueExecutionInfo; -class CondInstruction : public InstructionBase { +class IfInstruction : public InstructionBase { public: 
- CondInstruction(size_t id, - const platform::Place& place, - ::pir::Operation* op, - ValueExecutionInfo* value_exe_info, - const std::set& skip_gc_vars); + IfInstruction(size_t id, + const platform::Place& place, + ::pir::Operation* op, + ValueExecutionInfo* value_exe_info, + const std::set& skip_gc_vars); - ~CondInstruction(); + ~IfInstruction(); void Run() override; @@ -53,15 +53,15 @@ class CondInstruction : public InstructionBase { ::pir::Operation* op_; - std::string cond_name_{"cond_instruction"}; + std::string cond_name_{"if_instruction"}; Variable* cond_var_; std::vector output_vars_; - PirInterpreter* true_branch_inter_; + PirInterpreter* true_branch_inter_ = nullptr; - PirInterpreter* false_branch_inter_; + PirInterpreter* false_branch_inter_ = nullptr; std::vector true_branch_outputs_; diff --git a/paddle/fluid/framework/new_executor/instruction/select_input_instruction.cc b/paddle/fluid/framework/new_executor/instruction/select_input_instruction.cc new file mode 100644 index 00000000000000..893915f841d7fc --- /dev/null +++ b/paddle/fluid/framework/new_executor/instruction/select_input_instruction.cc @@ -0,0 +1,140 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/framework/new_executor/instruction/select_input_instruction.h" +#include "paddle/fluid/framework/new_executor/instruction/instruction_util.h" +#include "paddle/fluid/framework/new_executor/new_executor_defs.h" +#include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h" + +namespace paddle { +namespace framework { + +SelectInputInstruction::SelectInputInstruction( + size_t id, + const platform::Place &place, + ::pir::Operation *op, + ValueExecutionInfo *value_exe_info) + : InstructionBase(id, place), op_(op) { + VLOG(6) << "construct select_input instruction"; + + std::unordered_map> inputs; + mask_ = value_exe_info->GetVarByValue(op->operand_source(0)); + inputs.emplace(op->operand_source(0), + GetValueIds(op->operand_source(0), *value_exe_info)); + + for (size_t i = 1; i < op->num_operands(); ++i) { + inputs_.push_back(value_exe_info->GetVarByValue(op->operand_source(i))); + inputs.emplace(op->operand_source(i), + GetValueIds(op->operand_source(i), *value_exe_info)); + } + SetInputs(inputs); + + std::unordered_map> outputs; + out_ = value_exe_info->GetVarByValue(op->result(0)); + outputs.emplace(op->result(0), GetValueIds(op->result(0), *value_exe_info)); + SetOutputs(outputs); +} + +inline int GetBranchNumber(const phi::DenseTensor &mask) { + PADDLE_ENFORCE_EQ( + mask.numel(), + 1, + phi::errors::Fatal("The numel of Input(Mask) in SelectInputOp or " + "SelectOutputOp must be 1. 
" + "But received %d, and it's shape is [%s].", + mask.numel(), + mask.dims())); + if (platform::is_cpu_place(mask.place())) { + return mask.data()[0]; + } + // when platform::is_gpu_place(mask.place()) is true + std::unique_ptr cpu_mask{new phi::DenseTensor()}; +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \ + defined(PADDLE_WITH_CUSTOM_DEVICE) || defined(PADDLE_WITH_XPU) + framework::TensorCopySync(mask, platform::CPUPlace(), cpu_mask.get()); +#else + PADDLE_THROW(phi::errors::Fatal( + "This version of PaddlePaddle does NOT support GPU, " + "but got GPU tensor 'Mask' in SelectInputOp or SelectOutputOp. " + "Please compile PaddlePaddle WITH_GPU first.")); +#endif + return cpu_mask->data()[0]; +} + +class AssignFunctor { + public: + explicit AssignFunctor(Variable *out) : out_(out) {} + + void operator()(const phi::DenseTensor &lod_tensor) const { + auto &out_tensor = *out_->GetMutable(); + copy_tensor(lod_tensor, &out_tensor); + } + + void operator()(const phi::TensorArray &array) const { + auto &out_array = *out_->GetMutable(); + out_array.resize(array.size()); + for (size_t i = 0; i < array.size(); ++i) { + copy_tensor(array[i], &out_array[i]); + } + } + + void operator()(const phi::SelectedRows &rows) const { + phi::SelectedRows &out_rows = *out_->GetMutable(); + out_rows.set_rows(rows.rows()); + out_rows.set_height(rows.height()); + auto &t = rows.value(); + auto *m = out_rows.mutable_value(); + TensorCopy(t, t.place(), m); + } + + template + void operator()(const T &v UNUSED) const { + PADDLE_ENFORCE_EQ( + true, + false, + platform::errors::PermissionDenied( + "Not support type for assign op with type %s", typeid(T).name())); + } + + private: + void copy_tensor(const phi::DenseTensor &lod_tensor, + phi::DenseTensor *out) const { + if (!lod_tensor.IsInitialized()) return; + auto &out_tensor = *out; + TensorCopy(lod_tensor, lod_tensor.place(), &out_tensor); + out_tensor.set_lod(lod_tensor.lod()); + } + + Variable *out_; +}; + +void 
SelectInputInstruction::Run() { + VLOG(6) << "run select_input instruction"; + auto &mask = mask_->Get(); + size_t output_branch = static_cast(GetBranchNumber(mask)); + PADDLE_ENFORCE_LT( + output_branch, + inputs_.size(), + phi::errors::Fatal( + "Input 'Mask' in SelectInputOp is invalid. " + "'Mask' must be less than the size of input vector 'X'. " + "But received Mask = %d, X's size = %d.", + output_branch, + inputs_.size())); + Variable *selected = inputs_[output_branch]; + VisitVarType(*selected, AssignFunctor(out_)); +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_executor/instruction/select_input_instruction.h b/paddle/fluid/framework/new_executor/instruction/select_input_instruction.h new file mode 100644 index 00000000000000..16038e66152f69 --- /dev/null +++ b/paddle/fluid/framework/new_executor/instruction/select_input_instruction.h @@ -0,0 +1,52 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include "paddle/fluid/framework/new_executor/instruction/instruction_base.h" + +namespace paddle { +namespace framework { +class ValueExecutionInfo; + +class SelectInputInstruction : public InstructionBase { + public: + SelectInputInstruction(size_t id, + const platform::Place& place, + ::pir::Operation* op, + ValueExecutionInfo* value_exe_info); + + void Run() override; + + const std::string& Name() const override { return name_; } + + ::pir::Operation* Operation() const override { return op_; } + + private: + ::pir::Operation* op_; + + OpFuncType type_; + + std::string name_{"pd_op.select_input"}; + + Variable* mask_; // not owned + + std::vector inputs_; // not owned + + Variable* out_; // not owned +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_executor/instruction/while_instruction.cc b/paddle/fluid/framework/new_executor/instruction/while_instruction.cc index aee25e8d816843..2f3787118d2e48 100644 --- a/paddle/fluid/framework/new_executor/instruction/while_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/while_instruction.cc @@ -175,20 +175,14 @@ void WhileInstruction::CopyOutputsToBlockArgs() { auto* dst_tensor_array = inner_var->GetMutable(); dst_tensor_array->set_type(src_tensor_array.dtype()); dst_tensor_array->set_layout(src_tensor_array.layout()); - if (dst_tensor_array->empty()) { - for (auto src_tensor : src_tensor_array) { - phi::DenseTensor* tmp_dst_tensor = new phi::DenseTensor(); - tmp_dst_tensor->set_meta(src_tensor.meta()); - framework::TensorCopy(src_tensor, src_tensor.place(), tmp_dst_tensor); - dst_tensor_array->push_back(*tmp_dst_tensor); - } - } else { - for (size_t id = 0; id < dst_tensor_array->size(); id++) { - auto& src_tensor = src_tensor_array[id]; - phi::DenseTensor* tmp_dst_tensor = &dst_tensor_array->at(id); - tmp_dst_tensor->set_meta(src_tensor.meta()); - framework::TensorCopy(src_tensor, src_tensor.place(), tmp_dst_tensor); - } + 
while (dst_tensor_array->size() < src_tensor_array.size()) { + dst_tensor_array->emplace_back(); + } + for (size_t id = 0; id < dst_tensor_array->size(); id++) { + auto& src_tensor = src_tensor_array[id]; + phi::DenseTensor* tmp_dst_tensor = &dst_tensor_array->at(id); + tmp_dst_tensor->set_meta(src_tensor.meta()); + framework::TensorCopy(src_tensor, src_tensor.place(), tmp_dst_tensor); } } else { PADDLE_THROW( diff --git a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h index 26fc26a2dd3713..2dfe34b298bbd1 100644 --- a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h +++ b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h @@ -43,11 +43,11 @@ namespace paddle { namespace framework { -class CondInstruction; +class IfInstruction; class WhileInstruction; class ValueExecutionInfo { public: - friend class CondInstruction; + friend class IfInstruction; friend class WhileInstruction; explicit ValueExecutionInfo(Scope* scope) : scope_(scope) {} diff --git a/paddle/fluid/framework/new_executor/pir_interpreter.cc b/paddle/fluid/framework/new_executor/pir_interpreter.cc index 66de40585130b5..45674498b179fb 100644 --- a/paddle/fluid/framework/new_executor/pir_interpreter.cc +++ b/paddle/fluid/framework/new_executor/pir_interpreter.cc @@ -46,10 +46,11 @@ #endif #include "paddle/fluid/framework/new_executor/instruction/builtin_combine_instruction.h" -#include "paddle/fluid/framework/new_executor/instruction/cond_instruction.h" #include "paddle/fluid/framework/new_executor/instruction/has_elements_instruction.h" +#include "paddle/fluid/framework/new_executor/instruction/if_instruction.h" #include "paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.h" #include "paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.h" +#include "paddle/fluid/framework/new_executor/instruction/select_input_instruction.h" #include 
"paddle/fluid/framework/new_executor/instruction/tuple_pop_instruction.h" #include "paddle/fluid/framework/new_executor/instruction/tuple_push_instruction.h" #include "paddle/fluid/framework/new_executor/instruction/while_instruction.h" @@ -671,15 +672,15 @@ void PirInterpreter::BuildInstruction() { } else if (op.dialect()->name() == "pd_op") { if (op.isa()) { auto skip_gc_vars = execution_config_.skip_gc_vars; - vec_instruction_base_.emplace_back(std::make_unique( + vec_instruction_base_.emplace_back(std::make_unique( op_idx++, place_, &op, value_exe_info_.get(), skip_gc_vars)); sub_blocks_.insert( {&op.dyn_cast().true_block(), - dynamic_cast(vec_instruction_base_.back().get()) + dynamic_cast(vec_instruction_base_.back().get()) ->TrueBranchInterpreter()}); sub_blocks_.insert( {&op.dyn_cast().false_block(), - dynamic_cast(vec_instruction_base_.back().get()) + dynamic_cast(vec_instruction_base_.back().get()) ->FalseBranchInterpreter()}); } else if (op.isa()) { auto skip_gc_vars = execution_config_.skip_gc_vars; @@ -691,6 +692,8 @@ void PirInterpreter::BuildInstruction() { ->BodyInterpreter()}); } else if (op.isa()) { CREATE_INSTR(HasElementsInstruction); + } else if (op.isa()) { + CREATE_INSTR(SelectInputInstruction); } else { PADDLE_THROW(platform::errors::Unimplemented( "Now only support pd_kernel and cinn dialect.")); diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index d29781af492de1..4736a3af20a3d4 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -1863,6 +1863,110 @@ struct FillConstantTranscriber : public OpTranscriber { } }; +static std::vector ParseCompatibleShapes( + const std::vector& dim1, const std::vector& dim2) { + IR_ENFORCE(dim1.size() == dim2.size(), + "Does not support rank inconsistency: dim1=%d, dim2=%d", + dim1.size(), + dim2.size()); + std::vector result; + for (size_t i = 0; i < dim1.size(); ++i) 
{ + if (dim1[i] != dim2[i]) { + result.push_back(-1); + } else { + result.push_back(dim1[i]); + } + } + return result; +} + +struct SelectInputOpTranscriber : public OpTranscriber { + pir::Operation* operator()(pir::IrContext* ctx, + TranslationContext* param_map, + const OpDesc& op_desc, + pir::Block* block) override { + VLOG(10) << "[op select_input] start transcribing"; + auto op_info = this->LoopkUpOpInfo(ctx, op_desc); + + std::vector op_inputs = {}; + auto Mask_name = op_desc.Input("Mask")[0]; + auto& Input_name = op_desc.Input("X"); + IR_ENFORCE(param_map->count(Mask_name) > 0, + "Expected op[%s]'s input %s has been parsed", + op_desc.Type(), + Mask_name); + op_inputs.push_back(param_map->at(Mask_name).value); + for (auto in_name : Input_name) { + IR_ENFORCE(param_map->count(in_name) > 0, + "Expected op[%s]'s input %s has been parsed", + op_desc.Type(), + in_name); + op_inputs.push_back(param_map->at(in_name).value); + } + + pir::AttributeMap attribute_map; + + OpOutputMapping arg_to_idx; + OpOutputTypeList op_output_types; + auto Out_name = op_desc.Output("Out")[0]; + VarDesc* var = op_desc.Block()->FindVarRecursive(Out_name); + arg_to_idx[var->Name()] = {0, 0}; + + // NOTE(zhangbo): Only support + auto input1 = op_inputs[1].type(); + auto input2 = op_inputs[2].type(); + if (input1 == input2) { + op_output_types.push_back(op_inputs[1].type()); + } else if (input1.isa() && + input2.isa()) { + auto tensor1 = input1.dyn_cast(); + auto tensor2 = input2.dyn_cast(); + if (tensor1.dtype() != tensor2.dtype() || + tensor1.data_layout() != tensor2.data_layout() || + tensor1.lod() != tensor2.lod() || + tensor1.offset() != tensor2.offset()) { + IR_THROW( + "select_input only support same type or DenseTensorType with " + "only different dim, but get dtype:[%s, %s], layout:[%s, %s], " + "lod:[%s, %s], offset:[%s, %s].", + tensor1.dtype(), + tensor2.dtype(), + tensor1.data_layout(), + tensor2.data_layout(), + tensor1.lod(), + tensor2.lod(), + tensor1.offset(), + 
tensor2.offset()); + } + auto dim1 = input1.dyn_cast().dims(); + auto dim2 = input2.dyn_cast().dims(); + std::vector compat_shape = ParseCompatibleShapes( + common::vectorize(dim1), common::vectorize(dim2)); + op_output_types.push_back( + paddle::dialect::DenseTensorType::get(ctx, + tensor1.dtype(), + common::make_ddim(compat_shape), + tensor1.data_layout(), + tensor1.lod(), + tensor1.offset())); + } else { + IR_THROW( + "select_input only support same type or DenseTensorType with only " + "different dim, now is %s != %s.", + input1, + input2); + } + + pir::Operation* operation = pir::Operation::Create( + op_inputs, attribute_map, op_output_types, op_info); + block->push_back(operation); + RecordOpResultMapping(ctx, param_map, op_desc, operation, arg_to_idx); + + VLOG(10) << "[op assign_value] translation finished"; + return operation; + } +}; + pir::OpResult TranslateNumClassesForOneHot( pir::IrContext* ctx, TranslationContext* param_map, @@ -2736,6 +2840,7 @@ OpTranslator::OpTranslator() { special_handlers["tril_triu"] = TrilAndTriuOpTranscriber(); special_handlers["mul"] = MulOpTranscriber(); special_handlers["mul_grad"] = MulGradOpTranscriber(); + special_handlers["select_input"] = SelectInputOpTranscriber(); // To adapt LodTensorArray special_handlers["lod_array_length"] = LodArrayLengthOpTranscriber(); diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.cc b/paddle/fluid/ir_adaptor/translator/program_translator.cc index 47b8ac58c8a351..ba296feca0344d 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/program_translator.cc @@ -389,15 +389,11 @@ void ProgramTranslator::Translate() { } } -void ProgramTranslator::TranslateBlock( - const BlockDesc& src_block, - uint64_t start_id, - uint64_t end_id, - TranslationContext* translation_ctx, - pir::Block* dst_block, - bool for_cond_block, - const std::vector& cond_sub_block_outputs, - const std::vector<::paddle::framework::OpDesc*>& 
cond_init_ops) { +void ProgramTranslator::TranslateBlock(const BlockDesc& src_block, + uint64_t start_id, + uint64_t end_id, + TranslationContext* translation_ctx, + pir::Block* dst_block) { VLOG(8) << "=============>start to translate a block"; PADDLE_ENFORCE( (src_block.OpSize() >= end_id) && (start_id <= end_id), @@ -408,13 +404,8 @@ void ProgramTranslator::TranslateBlock( end_id, src_block.OpSize())); - std::unordered_map translate_completed; std::map assign_output_2_input; for (uint64_t op_id = start_id; op_id < end_id; op_id++) { - if (translate_completed.count(op_id) && translate_completed.at(op_id)) { - continue; - } - auto op = src_block.Op(static_cast(op_id)); VLOG(8) << "=============>start to translate a op: " << op->Type(); @@ -424,144 +415,137 @@ void ProgramTranslator::TranslateBlock( "Not support translated %s op", op->Type())); if (op->Type() == "conditional_block") { - std::vector cond_op_ids = GetCondOpIds(src_block, op_id); - ConditionBlockCombination cond_op_combination(src_block, cond_op_ids); - pir::Operation* if_op = TranslateCondIfOperation( - cond_op_combination, translation_ctx, dst_block); - for (auto cond_id : cond_op_ids) { - translate_completed[cond_id] = true; - } - VLOG(10) << "[op translated][conditional_block]" << if_op; + TranslateIfOperation(op, translation_ctx, dst_block); } else if (op->Type() == "while") { TranslateWhileOperation(op, translation_ctx, dst_block); } else { - if (for_cond_block && op->Type() == "assign" && - std::count(cond_sub_block_outputs.begin(), - cond_sub_block_outputs.end(), - op->Output("Out")[0])) { - assign_output_2_input[op->Output("Out")[0]] = op->Input("X")[0]; - translate_completed[op_id] = true; - } else { - TranslateGeneralOperation(op, translation_ctx, dst_block); - translate_completed[op_id] = true; - } - } - } - - // NOTE(zhangbo): If conditional_block operator has output, the cf.yeild - // operator needs to be inserted - if (for_cond_block) { - // insert init ops - for 
(::paddle::framework::OpDesc* init_op : cond_init_ops) { - TranslateGeneralOperation(init_op, translation_ctx, dst_block); + TranslateGeneralOperation(op, translation_ctx, dst_block); } - // insert yeild op - std::vector yeild_inputs; - for (auto output_name : cond_sub_block_outputs) { - if (assign_output_2_input.count(output_name) != 0) { - if (translation_ctx->count(assign_output_2_input[output_name]) == 0) { - CreateUndefinedVariable(assign_output_2_input[output_name], - src_block); - } - yeild_inputs.emplace_back( - (*translation_ctx)[assign_output_2_input[output_name]].value); - } else { - if (translation_ctx->count(output_name) == 0) { - CreateUndefinedVariable(output_name, src_block); - } - yeild_inputs.emplace_back((*translation_ctx)[output_name].value); - } - } - pir::AttributeMap attribute_map; - auto yeild_info = ctx_->GetRegisteredOpInfo(pir::YieldOp::name()); - pir::Operation* yeild_op = - pir::Operation::Create(yeild_inputs, attribute_map, {}, yeild_info); - dst_block->push_back(yeild_op); } } -pir::Operation* ProgramTranslator::TranslateCondIfOperation( - const ConditionBlockCombination& cond_ops, +pir::Operation* ProgramTranslator::InsertFullOrDataOpToBlock( + pir::Block* insert_block, pir::Type type) { + pir::Builder builder(ctx_, insert_block, insert_block->begin()); + if (type.isa()) { + auto tensor_type = type.dyn_cast(); + std::vector shape = common::vectorize(tensor_type.dims()); + paddle::dialect::FullOp full_op = builder.Build( + shape, + 0, + paddle::dialect::TransToPhiDataType(tensor_type.dtype()), + phi::CPUPlace()); + full_op.out().set_type(type); + return full_op.operation(); + } else if (type.isa()) { + auto array_type = type.dyn_cast(); + paddle::dialect::CreateArrayOp array_op = + builder.Build( + paddle::dialect::TransToPhiDataType(array_type.dtype())); + array_op.out().set_type(type); + return array_op.operation(); + } + return nullptr; +} + +// NOTE(zhangbo): All condition_block_op will be translated as an if_op with +// only a 
true branch. +void ProgramTranslator::TranslateIfOperation( + const OpDesc* op, TranslationContext* translation_ctx, pir::Block* dst_block) { + VLOG(8) << "=============>Start to translate if op:" << op; auto& type_translator = TypeTranslator::instance(); - auto op_info = ctx_->GetRegisteredOpInfo(paddle::dialect::IfOp::name()); - std::vector op_inputs = { - (*translation_ctx)[cond_ops.CondVarName()].value}; - auto input_names = cond_ops.GetInputNamesForIfOp(); - for (auto input_name : input_names) { + auto cond_op_cond = op->Input("Cond")[0]; + auto& cond_op_inputs = op->Input("Input"); + for (auto input_name : cond_op_inputs) { VLOG(6) << "[general op][conditional_block][inputs: " << input_name << "]"; GetValueOrCreateInTop(input_name, translation_ctx); } + auto& cond_op_outputs = op->Output("Out"); + std::vector<::paddle::framework::VarDesc*> cond_op_output_vars; + for (auto out_name : cond_op_outputs) { + cond_op_output_vars.emplace_back(op->Block()->FindVarRecursive(out_name)); + } - // NOTE(zhangbo): Now paddle::dialect::IfOp has 0 attribute + std::vector if_op_inputs = { + (*translation_ctx)[cond_op_cond].value}; pir::AttributeMap attribute_map; - - std::vector op_output_types; - std::vector<::paddle::framework::VarDesc*> output_vardescs = - std::get<0>(cond_ops.CondOutputVars()); - for (auto var_desc : output_vardescs) { + std::vector if_op_output_types; + for (auto var_desc : cond_op_output_vars) { IR_ENFORCE(var_desc != nullptr, "[control flow] Output should not be null"); pir::Type translated_var_type = type_translator[var_desc->GetType()](ctx_, *var_desc); - op_output_types.emplace_back(translated_var_type); + if_op_output_types.emplace_back(translated_var_type); } - VLOG(4) << "[general op][conditional_block] IfOp preparation end."; - + auto if_op_info = ctx_->GetRegisteredOpInfo(paddle::dialect::IfOp::name()); pir::Operation* operation = pir::Operation::Create( - op_inputs, attribute_map, op_output_types, op_info, 2); + if_op_inputs, attribute_map, 
if_op_output_types, if_op_info, 2); dst_block->push_back(operation); VLOG(4) << "[general op][conditional_block] IfOp creation end."; - if (cond_ops.TrueBlockId() != -1) { - const BlockDesc& true_sub_block = - legacy_program_->Block(cond_ops.TrueBlockId()); + if (op->GetBlockAttrId("sub_block") != -1) { + // Translate true branch by sub_block. + auto& sub_block = legacy_program_->Block(op->GetBlockAttrId("sub_block")); pir::Region& true_region = operation->region(0); if (true_region.empty()) true_region.emplace_back(); - auto* true_block_context = translation_ctx->CreateInnerContext(); - - TranslateBlock(true_sub_block, + TranslateBlock(sub_block, 0, - true_sub_block.OpSize(), + sub_block.OpSize(), true_block_context, - &true_region.front(), - true, - std::get<1>(cond_ops.CondOutputVars()), - cond_ops.TrueBlockInitOps()); - } - VLOG(4) << "[general op][conditional_block] IfOp true block translate end."; - - if (cond_ops.FalseBlockId() != -1) { - const BlockDesc& false_sub_block = - legacy_program_->Block(cond_ops.FalseBlockId()); + &true_region.front()); + // insert yeild op to true block + auto yeild_info = ctx_->GetRegisteredOpInfo(pir::YieldOp::name()); + std::vector true_yeild_inputs; + for (auto& out_name : cond_op_outputs) { + true_yeild_inputs.push_back(true_block_context->at(out_name).value); + } + true_region.front().push_back( + pir::Operation::Create(true_yeild_inputs, {}, {}, yeild_info)); + + // NOTE(zhangbo): The if_op of PIR requires that both true and false + // branches must exist, and the number of outputs and dtypes must be + // consistent. Only inconsistent shape is allowed. To be compatible with the + // old IR design, only true branches are allowed. The false branch may + // require yeild some fake variables. 
pir::Region& false_region = operation->region(1); if (false_region.empty()) false_region.emplace_back(); auto* false_block_context = translation_ctx->CreateInnerContext(); - TranslateBlock(false_sub_block, - 0, - false_sub_block.OpSize(), - false_block_context, - &false_region.front(), - true, - std::get<2>(cond_ops.CondOutputVars()), - cond_ops.FalseBlockInitOps()); + std::vector false_yeild_inputs; + for (size_t id = 0; id < cond_op_outputs.size(); id++) { + if (false_block_context->count(cond_op_outputs[id]) == 0) { + auto true_type = true_yeild_inputs[id].type(); + pir::Operation* init_op = + InsertFullOrDataOpToBlock(&false_region.front(), true_type); + PADDLE_ENFORCE_NOT_NULL( + init_op, + phi::errors::PreconditionNotMet( + "Only support insert full or data op for DenseTensor or " + "DenseTensorArray to false block failed.")); + false_block_context->PushValue( + cond_op_outputs[id], VariableDefiningInfo(init_op->result(0))); + } + false_yeild_inputs.push_back( + false_block_context->at(cond_op_outputs[id]).value); + } + false_region.front().push_back( + pir::Operation::Create(false_yeild_inputs, {}, {}, yeild_info)); } - VLOG(4) << "[general op][conditional_block] IfOp false block translate end."; + VLOG(4) << "[general op][conditional_block] IfOp true block translate end."; - for (size_t i = 0; i < output_vardescs.size(); i++) { - translation_ctx->PushValue(output_vardescs[i]->Name(), + for (size_t i = 0; i < cond_op_output_vars.size(); i++) { + translation_ctx->PushValue(cond_op_output_vars[i]->Name(), VariableDefiningInfo(operation->result(i))); VLOG(4) << "[general op][conditional_block] var " - << output_vardescs[i]->Name() << " was mapped to If's " << i + << cond_op_output_vars[i]->Name() << " was mapped to If's " << i << "-th output."; } operation->Verify(); VLOG(4) << "[general op][conditional_block] IfOp translate end."; - return operation; } void ProgramTranslator::TranslateWhileOperation( @@ -813,11 +797,13 @@ void 
ProgramTranslator::SetStopGradientAttributeForAllValue( } } } + const VariableDefiningInfo& ProgramTranslator::GetValueOrCreateInTop( const std::string& var_name, TranslationContext* translation_ctx) { if (translation_ctx->Has(var_name)) return translation_ctx->at(var_name); return CreateUndefinedVariable(var_name, legacy_program_->Block(0)); } + const VariableDefiningInfo& ProgramTranslator::CreateUndefinedVariable( const std::string& var_name, const BlockDesc& block) { VLOG(10) << "[undefined variable]" << var_name; @@ -840,6 +826,7 @@ const VariableDefiningInfo& ProgramTranslator::CreateUndefinedVariable( param_map_.PushValue(var_name, val); return param_map_.at(var_name); } + void ProgramTranslator::SetIsPersisableAttributeForAllValue( const BlockDesc& block) { // Currently we set is persisable for operation that generated a value diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.h b/paddle/fluid/ir_adaptor/translator/program_translator.h index 0dda3dc9b89219..052a8fa13cea41 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.h +++ b/paddle/fluid/ir_adaptor/translator/program_translator.h @@ -145,15 +145,11 @@ class ProgramTranslator { static const std::unordered_set unsupported_ops; - void TranslateBlock( - const BlockDesc& src_block, - uint64_t start_id, - uint64_t end_id, - TranslationContext* translation_ctx, - pir::Block* dst_block, - bool for_cond_block = false, - const std::vector& cond_sub_block_outputs = {}, - const std::vector<::paddle::framework::OpDesc*>& cond_init_ops = {}); + void TranslateBlock(const BlockDesc& src_block, + uint64_t start_id, + uint64_t end_id, + TranslationContext* translation_ctx, + pir::Block* dst_block); void TranslateGeneralOperation(const OpDesc* src_op, TranslationContext* translation_ctx, @@ -169,11 +165,13 @@ class ProgramTranslator { const VariableDefiningInfo& CreateUndefinedVariable( const std::string& var_name, const BlockDesc& block); - /// Translate methods for control flow ops. 
- pir::Operation* TranslateCondIfOperation( - const ConditionBlockCombination& cond_ops, - TranslationContext* translation_ctx, - pir::Block* dst_block); + pir::Operation* InsertFullOrDataOpToBlock(pir::Block* insert_block, + pir::Type type); + + void TranslateIfOperation(const OpDesc* op, + TranslationContext* translation_ctx, + pir::Block* dst_block); + void TranslateWhileOperation(const OpDesc* op, TranslationContext* translation_ctx, pir::Block* dst_block); diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc index 204a4c176d3ffc..dbb7c7c248dd48 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc @@ -164,45 +164,43 @@ void IfOp::VerifySig() { } void IfOp::VerifyRegion() { - // VLOG(4) << "Start Verifying sub regions for: IfOp."; - // PADDLE_ENFORCE_EQ( - // (*this)->region(0).size(), - // 1u, - // phi::errors::PreconditionNotMet("The size %d of true_region must - // be 1.", - // (*this)->region(0).size())); - - // if ((*this)->num_results() != 0) { - // PADDLE_ENFORCE_EQ( - // (*this)->region(0).size(), - // (*this)->region(1).size(), - // phi::errors::PreconditionNotMet("The size %d of true_region must be " - // "equal to the size %d of - // false_region.", - // (*this)->region(0).size(), - // (*this)->region(1).size())); - - // auto &true_last_op = (*this)->region(0).front().back(); - // auto &false_last_op = (*this)->region(1).front().back(); - // PADDLE_ENFORCE_EQ(true, - // true_last_op.isa(), - // phi::errors::PreconditionNotMet( - // "The last of true block must be YieldOp")); - // PADDLE_ENFORCE_EQ(true_last_op.num_operands(), - // (*this)->num_results(), - // phi::errors::PreconditionNotMet( - // "The size of last of true block op's input must be - // " "equal to IfOp's outputs num.")); - // PADDLE_ENFORCE_EQ(true, - // false_last_op.isa(), - // phi::errors::PreconditionNotMet( - // "The last of false 
block must be YieldOp")); - // PADDLE_ENFORCE_EQ(false_last_op.num_operands(), - // (*this)->num_results(), - // phi::errors::PreconditionNotMet( - // "The size of last of false block op's input must be - // " "equal to IfOp's outputs num.")); - // } + VLOG(4) << "Start Verifying sub regions for: IfOp."; + VLOG(4) << "Start Verifying true branch."; + PADDLE_ENFORCE_EQ( + (*this)->region(0).size(), + 1u, + phi::errors::PreconditionNotMet("The size %d of true_region must be 1.", + (*this)->region(0).size())); + if ((*this)->region(0).front().size() > 0) { + auto &true_last_op = (*this)->region(0).front().back(); + PADDLE_ENFORCE_EQ(true, + true_last_op.isa(), + phi::errors::PreconditionNotMet( + "The last of true block must be YieldOp")); + PADDLE_ENFORCE_EQ(true_last_op.num_operands(), + (*this)->num_results(), + phi::errors::PreconditionNotMet( + "The size of last of true block op's input must be " + "equal to IfOp's outputs num.")); + } + VLOG(4) << "Start Verifying false branch."; + PADDLE_ENFORCE_EQ( + (*this)->region(1).size(), + 1u, + phi::errors::PreconditionNotMet("The size %d of false_region must be 1.", + (*this)->region(1).size())); + if ((*this)->region(1).front().size() > 0) { + auto &false_last_op = (*this)->region(1).front().back(); + PADDLE_ENFORCE_EQ(true, + false_last_op.isa(), + phi::errors::PreconditionNotMet( + "The last of false block must be YieldOp")); + PADDLE_ENFORCE_EQ(false_last_op.num_operands(), + (*this)->num_results(), + phi::errors::PreconditionNotMet( + "The size of last of false block op's input must be " + "equal to IfOp's outputs num.")); + } } std::vector> IfOp::Vjp( diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc index cda564bedbb1df..2160e56442d465 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc @@ -11,8 +11,20 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. +#ifdef GET_OP_LIST +#undef GET_OP_LIST +paddle::dialect::AddNOp, paddle::dialect::AddN_Op, + paddle::dialect::AddNWithKernelOp, paddle::dialect::FusedGemmEpilogueOp, + paddle::dialect::FusedGemmEpilogueGradOp, paddle::dialect::SplitGradOp, + paddle::dialect::ExpandOp, paddle::dialect::CreateArrayOp, + paddle::dialect::ArrayLengthOp, paddle::dialect::ArrayReadOp, + paddle::dialect::ArrayWrite_Op, paddle::dialect::SliceArrayOp, + paddle::dialect::SliceArrayDenseOp, paddle::dialect::AssignArray_Op, + paddle::dialect::ArrayToTensorOp, paddle::dialect::SelectInputOp +#else #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h" #include "paddle/fluid/pir/dialect/operator/ir/ir_meta_tensor.h" #include "paddle/fluid/pir/dialect/operator/ir/ir_selected_rows.h" #include "paddle/fluid/pir/dialect/operator/ir/ir_tensor.h" @@ -2181,6 +2193,83 @@ phi::DataType ExpandOp::GetKernelTypeForVar( return expected_kernel_dtype; } +void SelectInputOp::VerifySig() { + VLOG(4) << "Verifying inputs, outputs and attributes for: SelectInputOp."; + VLOG(4) << "Verifying inputs:"; + { + auto in_size = num_operands(); + IR_ENFORCE(in_size == 3u, "Size %d of inputs must be equal to 3.", in_size); + auto input1 = (*this)->operand_source(1).type(); + auto input2 = (*this)->operand_source(2).type(); + if (input1.isa() && + input2.isa()) { + auto tensor1 = input1.dyn_cast(); + auto tensor2 = input2.dyn_cast(); + IR_ENFORCE( + tensor1.dtype() == tensor2.dtype(), + "The 1st input dtype %s should be equal to 2nd input dtype %s.", + tensor1.dtype(), + tensor2.dtype()); + IR_ENFORCE(tensor1.data_layout() == tensor2.data_layout(), + "The 1st input data_layout %s should be equal to 2nd input " + "data_layout %s.", + tensor1.data_layout(), + tensor2.data_layout()); + IR_ENFORCE(tensor1.lod() == tensor2.lod(), + "The 1st input lod %s should be 
equal to 2nd input lod %s.", + tensor1.lod(), + tensor2.lod()); + IR_ENFORCE( + tensor1.offset() == tensor2.offset(), + "The 1st input offset %s should be equal to 2nd input offset %s.", + tensor1.offset(), + tensor2.offset()); + } else if (input1.isa() && + input2.isa()) { + auto tensor1 = + input1.dyn_cast(); + auto tensor2 = + input2.dyn_cast(); + IR_ENFORCE( + tensor1.dtype() == tensor2.dtype(), + "The 1st input dtype %s should be equal to 2nd input dtype %s.", + tensor1.dtype(), + tensor2.dtype()); + IR_ENFORCE(tensor1.data_layout() == tensor2.data_layout(), + "The 1st input data_layout %s should be equal to 2nd input " + "data_layout %s.", + tensor1.data_layout(), + tensor2.data_layout()); + IR_ENFORCE(tensor1.lod() == tensor2.lod(), + "The 1st input lod %s should be equal to 2nd input lod %s.", + tensor1.lod(), + tensor2.lod()); + IR_ENFORCE( + tensor1.offset() == tensor2.offset(), + "The 1st input offset %s should be equal to 2nd input offset %s.", + tensor1.offset(), + tensor2.offset()); + IR_ENFORCE( + tensor1.place() == tensor2.place(), + "The 1st input place %s should be equal to 2nd input place %s.", + tensor1.place(), + tensor2.place()); + } else { + IR_ENFORCE(input1 == input2, + "The 1st input type %s should be equal to 2nd input type %s.", + input1, + input2); + } + } + VLOG(4) << "Verifying outputs:"; + { + auto out_size = num_results(); + IR_ENFORCE( + out_size == 1u, "Size %d of outputs must be equal to 1.", out_size); + } + VLOG(4) << "End Verifying for: SelectInputOp."; +} + } // namespace dialect } // namespace paddle @@ -2199,3 +2288,5 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SliceArrayDenseOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AssignArray_Op) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ArrayToTensorOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ExpandOp) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SelectInputOp) +#endif diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.h 
b/paddle/fluid/pir/dialect/operator/ir/manual_op.h index 6bdeac5bc04c9a..460356039d84ab 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.h @@ -401,6 +401,17 @@ class ExpandOp : public pir::Op> &stop_gradients); }; +class SelectInputOp : public pir::Op { + public: + using Op::Op; + static const char *name() { return "pd_op.select_input"; } + static constexpr const char **attributes_name = nullptr; + static constexpr uint32_t attributes_num = 0; + void VerifySig(); + pir::Value mask() { return operand_source(0); } + pir::OpResult out() { return result(0); } +}; + } // namespace dialect } // namespace paddle @@ -419,3 +430,4 @@ IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SliceArrayDenseOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AssignArray_Op) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ArrayToTensorOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ExpandOp) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SelectInputOp) diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc index 4c44b91af35b72..002b8cb731ed7e 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc @@ -64,21 +64,10 @@ void OperatorDialect::initialize() { #include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc" // NOLINT >(); - RegisterOps(); + RegisterOps< +#define GET_OP_LIST +#include "paddle/fluid/pir/dialect/operator/ir/manual_op.cc" // NOLINT + >(); RegisterInterfaces(); } diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc index 40ca238d397b44..04c4d68933140d 100644 --- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc @@ -105,6 +105,7 @@ const std::unordered_set SpecialLowerOps = { pir::TuplePushOp::name(), pir::TuplePopOp::name(), HasElementsOp::name(), + 
SelectInputOp::name(), "cinn_runtime.jit_kernel"}; static bool NeedFallBackCpu(const pir::Operation* op, @@ -997,6 +998,20 @@ pir::Value GetNewInput( return new_in; } +phi::Place ParsePhiPlace(pir::Type type) { + if (type.isa()) { + return type.dyn_cast().place(); + } else if (type.isa()) { + return type.dyn_cast().place(); + } else if (type.isa()) { + return type.dyn_cast().place(); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "ParsePhiPlace only support AllocatedDenseTensorType or " + "AllocatedSelectedRowsType or AllocatedDenseTensorArrayType")); + } +} + void HandleForSpecialOp( const phi::Place& place, pir::Operation* op_item, @@ -1222,6 +1237,18 @@ void HandleForSpecialOp( } } + if (op_item->isa()) { + for (size_t i = 0; i < op_item->num_operands(); ++i) { + auto cur_in = op_item->operand_source(i); + auto new_in = GetNewInput( + cur_in, *map_value_pair, static_cast(i), op_item->name()); + vec_inputs.push_back(new_in); + } + for (size_t i = 0; i < op_item->num_results(); ++i) { + op_output_types.push_back(vec_inputs[1].type()); + } + } + if (op_item->name() == "cinn_runtime.jit_kernel") { if (op_item->num_operands() > 0) { for (size_t i = 0; i < op_item->num_operands(); ++i) { diff --git a/test/cpp/pir/core/program_translator_test.cc b/test/cpp/pir/core/program_translator_test.cc index 0c199f0481e710..010ed757d2ab8b 100644 --- a/test/cpp/pir/core/program_translator_test.cc +++ b/test/cpp/pir/core/program_translator_test.cc @@ -85,7 +85,7 @@ TEST(OperatorDialectTest, ConditionBlock) { ctx->GetOrRegisterDialect(); auto program = paddle::TranslateLegacyProgramToProgram(p); - EXPECT_EQ(program->block()->size(), 4u); + EXPECT_EQ(program->block()->size(), 9u); size_t id = 0; for (auto &op : *program->block()) { if (id == 0 || id == 1) { @@ -117,77 +117,36 @@ TEST(OperatorDialectTest, ConditionBlock) { EXPECT_EQ(op2.isa(), true); } if (true_true_id == 1) { - EXPECT_EQ(op2.isa(), true); - } - true_true_id++; - } - auto &false_false_block = - 
op1.dyn_cast().false_block(); - size_t false_false_id = 0; - for (auto &op2 : false_false_block) { - if (false_false_id == 0) { - EXPECT_EQ(op2.isa(), true); + EXPECT_EQ(op2.isa(), true); } - if (false_false_id == 1) { + if (true_true_id == 2) { EXPECT_EQ(op2.isa(), true); } - false_false_id++; + true_true_id++; } } if (true_id == 4) { - EXPECT_EQ(op1.isa(), true); + EXPECT_EQ(op1.isa(), true); } if (true_id == 5) { - EXPECT_EQ(op1.isa(), true); + EXPECT_EQ(op1.isa(), true); } - true_id++; - } - // false block - auto &false_block = op.dyn_cast().false_block(); - size_t false_id = 0; - for (auto &op1 : false_block) { - if (false_id == 0 || false_id == 1) { - EXPECT_EQ(op1.isa(), true); + if (true_id == 6) { + EXPECT_EQ(op1.isa(), true); } - if (false_id == 2) { - EXPECT_EQ(op1.isa(), true); + if (true_id == 7) { + EXPECT_EQ(op1.isa(), true); } - if (false_id == 3) { - EXPECT_EQ(op1.isa(), true); - // true block - auto &false_true_block = - op1.dyn_cast().true_block(); - size_t false_true_id = 0; - for (auto &op2 : false_true_block) { - if (false_true_id == 0) { - EXPECT_EQ(op2.isa(), true); - } - if (false_true_id == 1) { - EXPECT_EQ(op2.isa(), true); - } - false_true_id++; - } - // false block - auto &false_false_block = - op1.dyn_cast().true_block(); - size_t false_false_id = 0; - for (auto &op2 : false_false_block) { - if (false_false_id == 0) { - EXPECT_EQ(op2.isa(), true); - } - if (false_false_id == 1) { - EXPECT_EQ(op2.isa(), true); - } - false_false_id++; - } - } - if (false_id == 4) { + if (true_id == 8) { EXPECT_EQ(op1.isa(), true); } - if (false_id == 5) { + if (true_id == 9 || true_id == 10) { + EXPECT_EQ(op1.isa(), true); + } + if (true_id == 11) { EXPECT_EQ(op1.isa(), true); } - false_id++; + true_id++; } } id++; diff --git a/test/dygraph_to_static/test_ifelse.py b/test/dygraph_to_static/test_ifelse.py index 8e94d8ac702084..dbb4dcbf067a75 100644 --- a/test/dygraph_to_static/test_ifelse.py +++ b/test/dygraph_to_static/test_ifelse.py @@ -22,6 +22,7 @@ 
disable_test_case, enable_to_static_guard, test_ast_only, + test_legacy_only, ) from ifelse_simple_func import ( NetWithControlFlowIf, @@ -109,11 +110,28 @@ def test_ast_to_func(self): self.assertTrue((self._run_dygraph() == self._run_static()).all()) -class TestDygraphIfElse3(TestDygraphIfElse): +class TestDygraphIfElse3(Dy2StTestBase): def setUp(self): self.x = np.random.random([10, 16]).astype('float32') self.dyfunc = dyfunc_with_if_else3 + def _run_static(self): + return self._run_dygraph(to_static=True) + + def _run_dygraph(self, to_static=False): + with base.dygraph.guard(place): + x_v = base.dygraph.to_variable(self.x) + if to_static: + ret = paddle.jit.to_static(self.dyfunc)(x_v) + else: + ret = self.dyfunc(x_v) + return ret.numpy() + + # Why add test_legacy_only? : PIR not support if true and false branch output with different rank + @test_legacy_only + def test_ast_to_func(self): + self.assertTrue((self._run_dygraph() == self._run_static()).all()) + class TestDygraphIfElse4(TestDygraphIfElse): def setUp(self): @@ -144,6 +162,8 @@ def _run_dygraph(self, to_static=False): ret = self.dyfunc(x_v) return ret.numpy() + # Why add test_legacy_only? : PIR not support if true and false branch output with different rank + @test_legacy_only def test_ast_to_func(self): self.assertTrue((self._run_dygraph() == self._run_static()).all()) @@ -154,11 +174,28 @@ def setUp(self): self.dyfunc = nested_if_else_2 -class TestDygraphNestedIfElse3(TestDygraphIfElse): +class TestDygraphNestedIfElse3(Dy2StTestBase): def setUp(self): self.x = np.random.random([10, 16]).astype('float32') self.dyfunc = nested_if_else_3 + def _run_static(self): + return self._run_dygraph(to_static=True) + + def _run_dygraph(self, to_static=False): + with base.dygraph.guard(place): + x_v = paddle.to_tensor(self.x) + if to_static: + ret = paddle.jit.to_static(self.dyfunc)(x_v) + else: + ret = self.dyfunc(x_v) + return ret.numpy() + + # Why add test_legacy_only? 
: PIR not support if true and false branch output with different rank + @test_legacy_only + def test_ast_to_func(self): + self.assertTrue((self._run_dygraph() == self._run_static()).all()) + def dyfunc_ifExp_with_while(x): y = [x] @@ -186,10 +223,10 @@ def body(i, ten, y): return y[0] -class TestDygraphIfElse6(TestDygraphIfElse): - def setUp(self): - self.x = np.random.random([10, 16]).astype('float32') - self.dyfunc = dyfunc_ifExp_with_while +# class TestDygraphIfElse6(TestDygraphIfElse): +# def setUp(self): +# self.x = np.random.random([10, 16]).astype('float32') +# self.dyfunc = dyfunc_ifExp_with_while def dyfunc_ifExp(x): @@ -269,6 +306,8 @@ def _run_dygraph(self, to_static=False): ret = self.dyfunc(x_v) return ret.numpy() + # Why add test_legacy_only? : PIR not support if true and false branch output with different dtype + @test_legacy_only def test_ast_to_func(self): self.assertTrue((self._run_dygraph() == self._run_static()).all()) @@ -297,6 +336,8 @@ def _run(self, to_static=False): ret = net(x_v) return ret.numpy() + # Why add test_legacy_only? : PIR not support if true and false branch output with different rank + @test_legacy_only def test_ast_to_func(self): self.assertTrue((self._run_dygraph() == self._run_static()).all()) @@ -461,6 +502,8 @@ def get_dy2stat_out(self): out = static_func(self.x) return out + # Why add test_legacy_only? : PIR not support if true and false branch output with different rank + @test_legacy_only @test_ast_only def test_ast_to_func(self): self.setUp() @@ -481,6 +524,8 @@ def setUp(self): self.dyfunc = paddle.jit.to_static(dyfunc_ifelse_ret_int3) self.out = self.get_dy2stat_out() + # Why add test_legacy_only? 
: PIR not support if true and false branch output with different rank + @test_legacy_only @test_ast_only def test_ast_to_func(self): self.setUp() diff --git a/test/dygraph_to_static/test_program_translator.py b/test/dygraph_to_static/test_program_translator.py index 2e373e5a57b6bd..812dd9d040e747 100644 --- a/test/dygraph_to_static/test_program_translator.py +++ b/test/dygraph_to_static/test_program_translator.py @@ -24,6 +24,7 @@ ToStaticMode, disable_test_case, test_ast_only, + test_legacy_only, ) from ifelse_simple_func import ( dyfunc_with_if_else_early_return1, @@ -304,6 +305,8 @@ def test_raise_error(self): class TestIfElseEarlyReturn(Dy2StTestBase): + # Why add test_legacy_only? : PIR not support if true and false branch output with different rank + @test_legacy_only def test_ifelse_early_return1(self): answer = np.zeros([2, 2]) + 1 static_func = paddle.jit.to_static(dyfunc_with_if_else_early_return1) diff --git a/test/dygraph_to_static/test_return.py b/test/dygraph_to_static/test_return.py index ceab96855c3d43..5d7e02ac8181e9 100644 --- a/test/dygraph_to_static/test_return.py +++ b/test/dygraph_to_static/test_return.py @@ -15,7 +15,11 @@ import unittest import numpy as np -from dygraph_to_static_utils import Dy2StTestBase, test_ast_only +from dygraph_to_static_utils import ( + Dy2StTestBase, + test_ast_only, + test_legacy_only, +) from ifelse_simple_func import dyfunc_with_if_else import paddle @@ -316,10 +320,54 @@ def init_dygraph_func(self): self.dygraph_func = test_inside_func_base -class TestReturnIf(TestReturnBase): +class TestReturnIf(Dy2StTestBase): + def setUp(self): + self.input = np.ones(1).astype('int32') + self.place = ( + base.CUDAPlace(0) + if base.is_compiled_with_cuda() + else base.CPUPlace() + ) + self.init_dygraph_func() + def init_dygraph_func(self): self.dygraph_func = test_return_if + def _run(self, to_static=False): + paddle.jit.enable_to_static(to_static) + with base.dygraph.guard(): + res = self.dygraph_func(self.input) + if 
isinstance(res, (tuple, list)): + return tuple(r.numpy() for r in res) + elif isinstance(res, core.eager.Tensor): + return res.numpy() + return res + + def _test_value_impl(self): + dygraph_res = self._run(to_static=False) + static_res = self._run(to_static=True) + if isinstance(dygraph_res, tuple): + self.assertTrue(isinstance(static_res, tuple)) + self.assertEqual(len(dygraph_res), len(static_res)) + for i in range(len(dygraph_res)): + np.testing.assert_allclose( + dygraph_res[i], static_res[i], rtol=1e-05 + ) + elif isinstance(dygraph_res, np.ndarray): + np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05) + else: + self.assertEqual(dygraph_res, static_res) + + # Why add test_legacy_only? : PIR not support if true and false branch output with different dtype + @test_legacy_only + @test_ast_only + def test_transformed_static_result(self): + if hasattr(self, "error"): + with self.assertRaisesRegex(Dygraph2StaticException, self.error): + self._test_value_impl() + else: + self._test_value_impl() + class TestReturnOnlyIf(TestReturnBase): def init_dygraph_func(self): @@ -331,20 +379,108 @@ def init_dygraph_func(self): self.dygraph_func = test_return_in_for -class TestReturnInWhile(TestReturnBase): +class TestReturnInWhile(Dy2StTestBase): + def setUp(self): + self.input = np.ones(1).astype('int32') + self.place = ( + base.CUDAPlace(0) + if base.is_compiled_with_cuda() + else base.CPUPlace() + ) + self.init_dygraph_func() + def init_dygraph_func(self): self.dygraph_func = test_return_in_while + def _run(self, to_static=False): + paddle.jit.enable_to_static(to_static) + with base.dygraph.guard(): + res = self.dygraph_func(self.input) + if isinstance(res, (tuple, list)): + return tuple(r.numpy() for r in res) + elif isinstance(res, core.eager.Tensor): + return res.numpy() + return res + + def _test_value_impl(self): + dygraph_res = self._run(to_static=False) + static_res = self._run(to_static=True) + if isinstance(dygraph_res, tuple): + 
self.assertTrue(isinstance(static_res, tuple)) + self.assertEqual(len(dygraph_res), len(static_res)) + for i in range(len(dygraph_res)): + np.testing.assert_allclose( + dygraph_res[i], static_res[i], rtol=1e-05 + ) + elif isinstance(dygraph_res, np.ndarray): + np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05) + else: + self.assertEqual(dygraph_res, static_res) + + # Why add test_legacy_only? : PIR not support if true and false branch output with different dtype + @test_legacy_only + @test_ast_only + def test_transformed_static_result(self): + if hasattr(self, "error"): + with self.assertRaisesRegex(Dygraph2StaticException, self.error): + self._test_value_impl() + else: + self._test_value_impl() + class TestReturnIfDiff(TestReturnBase): def init_dygraph_func(self): self.dygraph_func = test_diff_return -class TestReturnIfElse(TestReturnBase): +class TestReturnIfElse(Dy2StTestBase): + def setUp(self): + self.input = np.ones(1).astype('int32') + self.place = ( + base.CUDAPlace(0) + if base.is_compiled_with_cuda() + else base.CPUPlace() + ) + self.init_dygraph_func() + def init_dygraph_func(self): self.dygraph_func = test_return_if_else + def _run(self, to_static=False): + paddle.jit.enable_to_static(to_static) + with base.dygraph.guard(): + res = self.dygraph_func(self.input) + if isinstance(res, (tuple, list)): + return tuple(r.numpy() for r in res) + elif isinstance(res, core.eager.Tensor): + return res.numpy() + return res + + def _test_value_impl(self): + dygraph_res = self._run(to_static=False) + static_res = self._run(to_static=True) + if isinstance(dygraph_res, tuple): + self.assertTrue(isinstance(static_res, tuple)) + self.assertEqual(len(dygraph_res), len(static_res)) + for i in range(len(dygraph_res)): + np.testing.assert_allclose( + dygraph_res[i], static_res[i], rtol=1e-05 + ) + elif isinstance(dygraph_res, np.ndarray): + np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05) + else: + self.assertEqual(dygraph_res, static_res) + + # 
Why add test_legacy_only? : PIR not support if true and false branch output with different dtype + @test_legacy_only + @test_ast_only + def test_transformed_static_result(self): + if hasattr(self, "error"): + with self.assertRaisesRegex(Dygraph2StaticException, self.error): + self._test_value_impl() + else: + self._test_value_impl() + class TestReturnInWhile2(TestReturnBase): def init_dygraph_func(self): diff --git a/test/dygraph_to_static/test_warning.py b/test/dygraph_to_static/test_warning.py index 9eac0f6a8902bb..e1b9a02b2851dd 100644 --- a/test/dygraph_to_static/test_warning.py +++ b/test/dygraph_to_static/test_warning.py @@ -15,7 +15,11 @@ import unittest import warnings -from dygraph_to_static_utils import Dy2StTestBase, test_ast_only +from dygraph_to_static_utils import ( + Dy2StTestBase, + test_ast_only, + test_legacy_only, +) import paddle from paddle.static.nn import cond @@ -39,6 +43,8 @@ def false_fn(): class TestReturnNoneInIfelse(Dy2StTestBase): + # Why add test_legacy_only? : PIR not support if true and false branch output with different dtype + @test_legacy_only @test_ast_only def test_dy2static_warning(self): paddle.disable_static() diff --git a/test/legacy_test/test_cond.py b/test/legacy_test/test_cond.py index 163c1b921f7450..a2a616db1b2c72 100644 --- a/test/legacy_test/test_cond.py +++ b/test/legacy_test/test_cond.py @@ -253,7 +253,9 @@ def true_func(): def false_func(): return paddle.tensor.fill_constant( shape=[3, 4], dtype='int32', value=3 - ), paddle.tensor.fill_constant(shape=[4, 5], dtype='bool', value=2) + ), paddle.tensor.fill_constant( + shape=[4, 5], dtype='bool', value=False + ) main_program = paddle.static.Program() startup_program = paddle.static.Program()