diff --git a/paddle/fluid/framework/new_executor/instruction/control_flow/select_output_instruction.cc b/paddle/fluid/framework/new_executor/instruction/control_flow/select_output_instruction.cc
new file mode 100644
index 00000000000000..4f5bfa06568641
--- /dev/null
+++ b/paddle/fluid/framework/new_executor/instruction/control_flow/select_output_instruction.cc
@@ -0,0 +1,139 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/new_executor/instruction/control_flow/select_output_instruction.h"
+#include "paddle/fluid/framework/new_executor/instruction/instruction_util.h"
+#include "paddle/fluid/framework/new_executor/new_executor_defs.h"
+#include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h"
+
+namespace paddle {
+namespace framework {
+
+SelectOutputInstruction::SelectOutputInstruction(
+    size_t id,
+    const platform::Place &place,
+    ::pir::Operation *op,
+    ValueExecutionInfo *value_exe_info)
+    : InstructionBase(id, place), op_(op) {
+  VLOG(6) << "construct select_output instruction";
+
+  std::unordered_map<pir::Value, std::vector<int>> inputs;
+  mask_ = value_exe_info->GetVarByValue(op->operand_source(0));
+  inputs.emplace(op->operand_source(0),
+                 GetValueIds(op->operand_source(0), *value_exe_info));
+  input_ = value_exe_info->GetVarByValue(op->operand_source(1));
+  inputs.emplace(op->operand_source(1),
+                 GetValueIds(op->operand_source(1), *value_exe_info));
+  SetInputs(inputs);
+
+  std::unordered_map<pir::Value, std::vector<int>> outputs;
+  for (size_t i = 0; i < op->num_results(); ++i) {
+    outputs_.push_back(value_exe_info->GetVarByValue(op->result(i)));
+    outputs.emplace(op->result(i), GetValueIds(op->result(i), *value_exe_info));
+  }
+  SetOutputs(outputs);
+}
+
+inline int GetBranchNumber(const phi::DenseTensor &mask) {
+  PADDLE_ENFORCE_EQ(
+      mask.numel(),
+      1,
+      phi::errors::Fatal("The numel of Input(Mask) in SelectInputOp or "
+                         "SelectOutputOp must be 1. "
+                         "But received %d, and it's shape is [%s].",
+                         mask.numel(),
+                         mask.dims()));
+  if (platform::is_cpu_place(mask.place())) {
+    return mask.data<int>()[0];
+  }
+  // when platform::is_gpu_place(mask.place()) is true
+  std::unique_ptr<phi::DenseTensor> cpu_mask{new phi::DenseTensor()};
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
+    defined(PADDLE_WITH_CUSTOM_DEVICE) || defined(PADDLE_WITH_XPU)
+  framework::TensorCopySync(mask, platform::CPUPlace(), cpu_mask.get());
+#else
+  PADDLE_THROW(phi::errors::Fatal(
+      "This version of PaddlePaddle does NOT support GPU, "
+      "but got GPU tensor 'Mask' in SelectInputOp or SelectOutputOp. "
+      "Please compile PaddlePaddle WITH_GPU first."));
+#endif
+  return cpu_mask->data<int>()[0];
+}
+
+class AssignFunctor {
+ public:
+  explicit AssignFunctor(Variable *out) : out_(out) {}
+
+  void operator()(const phi::DenseTensor &lod_tensor) const {
+    auto &out_tensor = *out_->GetMutable<phi::DenseTensor>();
+    copy_tensor(lod_tensor, &out_tensor);
+  }
+
+  void operator()(const phi::TensorArray &array) const {
+    auto &out_array = *out_->GetMutable<phi::TensorArray>();
+    out_array.resize(array.size());
+    for (size_t i = 0; i < array.size(); ++i) {
+      copy_tensor(array[i], &out_array[i]);
+    }
+  }
+
+  void operator()(const phi::SelectedRows &rows) const {
+    phi::SelectedRows &out_rows = *out_->GetMutable<phi::SelectedRows>();
+    out_rows.set_rows(rows.rows());
+    out_rows.set_height(rows.height());
+    auto &t = rows.value();
+    auto *m = out_rows.mutable_value();
+    TensorCopy(t, t.place(), m);
+  }
+
+  template <typename T>
+  void operator()(const T &v UNUSED) const {
+    PADDLE_ENFORCE_EQ(
+        true,
+        false,
+        platform::errors::PermissionDenied(
+            "Not support type for assign op with type %s", typeid(T).name()));
+  }
+
+ private:
+  void copy_tensor(const phi::DenseTensor &lod_tensor,
+                   phi::DenseTensor *out) const {
+    if (!lod_tensor.IsInitialized()) return;
+    auto &out_tensor = *out;
+    TensorCopy(lod_tensor, lod_tensor.place(), &out_tensor);
+    out_tensor.set_lod(lod_tensor.lod());
+  }
+
+  Variable *out_;
+};
+
+void SelectOutputInstruction::Run() {
+  VLOG(6) << "run select_output instruction";
+  auto &mask = mask_->Get<phi::DenseTensor>();
+  size_t output_branch = static_cast<size_t>(GetBranchNumber(mask));
+  PADDLE_ENFORCE_LE(
+      output_branch,
+      outputs_.size(),
+      phi::errors::Fatal(
+          "Input 'Mask' in SelectInputOp is invalid. "
+          "'Mask' must be less than the size of output vector 'X'. "
+          "But received Mask = %d, Out's size = %d.",
+          output_branch,
+          outputs_.size()));
+  Variable *selected = outputs_[output_branch];
+  VisitVarType(*input_, AssignFunctor(selected));
+}
+
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/new_executor/instruction/control_flow/select_output_instruction.h b/paddle/fluid/framework/new_executor/instruction/control_flow/select_output_instruction.h
new file mode 100644
index 00000000000000..9b19511f29c884
--- /dev/null
+++ b/paddle/fluid/framework/new_executor/instruction/control_flow/select_output_instruction.h
@@ -0,0 +1,52 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#pragma once + +#include +#include "paddle/fluid/framework/new_executor/instruction/instruction_base.h" + +namespace paddle { +namespace framework { +class ValueExecutionInfo; + +class SelectOutputInstruction : public InstructionBase { + public: + SelectOutputInstruction(size_t id, + const platform::Place& place, + ::pir::Operation* op, + ValueExecutionInfo* value_exe_info); + + void Run() override; + + const std::string& Name() const override { return name_; } + + ::pir::Operation* Operation() const override { return op_; } + + private: + ::pir::Operation* op_; + + OpFuncType type_; + + std::string name_{"pd_op.select_output"}; + + Variable* mask_; // not owned + + Variable* input_; // not owned + + std::vector outputs_; // not owned +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_executor/pir_interpreter.cc b/paddle/fluid/framework/new_executor/pir_interpreter.cc index a0b270bc39b8f3..94a48165ff9d4d 100644 --- a/paddle/fluid/framework/new_executor/pir_interpreter.cc +++ b/paddle/fluid/framework/new_executor/pir_interpreter.cc @@ -53,6 +53,7 @@ #include "paddle/fluid/framework/new_executor/instruction/control_flow/has_elements_instruction.h" #include "paddle/fluid/framework/new_executor/instruction/control_flow/if_instruction.h" #include "paddle/fluid/framework/new_executor/instruction/control_flow/select_input_instruction.h" +#include "paddle/fluid/framework/new_executor/instruction/control_flow/select_output_instruction.h" #include "paddle/fluid/framework/new_executor/instruction/control_flow/tuple_pop_instruction.h" #include "paddle/fluid/framework/new_executor/instruction/control_flow/tuple_push_instruction.h" #include "paddle/fluid/framework/new_executor/instruction/control_flow/while_instruction.h" @@ -725,6 +726,8 @@ void PirInterpreter::BuildInstruction() { CREATE_INSTR(AssertInstruction); } else if (op.isa()) { CREATE_INSTR(SelectInputInstruction); + } else if (op.isa()) { + CREATE_INSTR(SelectOutputInstruction); } else { PADDLE_THROW(platform::errors::Unimplemented( "Now only support pd_kernel and cinn dialect.")); diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index bd66bb241f4f83..d9fd12e53da801 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -2010,7 +2010,6 @@ struct SelectInputOpTranscriber : public OpTranscriber { VarDesc* var = op_desc.Block()->FindVarRecursive(Out_name); arg_to_idx[var->Name()] = {0, 0}; - // NOTE(zhangbo): Only support auto input1 = op_inputs[1].type(); auto input2 = op_inputs[2].type(); if (input1 == input2) { @@ -2115,6 +2114,52 @@ struct SelectInputOpTranscriber : public OpTranscriber { } }; +struct SelectOutputOpTranscriber : public OpTranscriber { + pir::Operation* operator()(pir::IrContext* ctx, + TranslationContext* param_map, + const OpDesc& op_desc, + pir::Block* block) override { + VLOG(10) << "[op select_output] start transcribing"; + auto op_info = this->LoopkUpOpInfo(ctx, op_desc); + + std::vector op_inputs = {}; + auto Mask_name = op_desc.Input("Mask")[0]; + auto& Input_name = op_desc.Input("X")[0]; + IR_ENFORCE(param_map->count(Mask_name) > 0, + "Expected op[%s]'s input %s has been parsed", + op_desc.Type(), + Mask_name); + op_inputs.push_back(param_map->at(Mask_name).value); + IR_ENFORCE(param_map->count(Input_name) > 0, + "Expected op[%s]'s input %s has been parsed", + op_desc.Type(), + Input_name); + 
op_inputs.push_back(param_map->at(Input_name).value); + + pir::AttributeMap attribute_map; + TranslateOpDistAttribute(op_desc, &attribute_map); + + OpOutputMapping arg_to_idx; + OpOutputTypeList op_output_types; + auto Out_names = op_desc.Output("Out"); + IR_ENFORCE(Out_names.size() == 2, + "Expected SelectOutput's output size is 2."); + for (size_t idx = 0; idx < Out_names.size(); idx++) { + VarDesc* var = op_desc.Block()->FindVarRecursive(Out_names[idx]); + arg_to_idx[var->Name()] = {idx, 0}; + op_output_types.push_back(op_inputs[1].type()); + } + + pir::Operation* operation = pir::Operation::Create( + op_inputs, attribute_map, op_output_types, op_info); + block->push_back(operation); + RecordOpResultMapping(ctx, param_map, op_desc, operation, arg_to_idx); + + VLOG(10) << "[op assign_value] translation finished"; + return operation; + } +}; + pir::OpResult TranslateNumClassesForOneHot( pir::IrContext* ctx, TranslationContext* param_map, @@ -3088,6 +3133,7 @@ OpTranslator::OpTranslator() { special_handlers["mul"] = MulOpTranscriber(); special_handlers["mul_grad"] = MulGradOpTranscriber(); special_handlers["select_input"] = SelectInputOpTranscriber(); + special_handlers["select_output"] = SelectOutputOpTranscriber(); // To adapt LodTensorArray special_handlers["lod_array_length"] = LodArrayLengthOpTranscriber(); diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.cc b/paddle/fluid/ir_adaptor/translator/program_translator.cc index 7eca5767750b9a..c1e84a865041ea 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/program_translator.cc @@ -39,6 +39,7 @@ #include "paddle/pir/core/value.h" #include "paddle/pir/dialect/control_flow/ir/cf_dialect.h" #include "paddle/pir/dialect/control_flow/ir/cf_op.h" +#include "paddle/pir/dialect/control_flow/ir/cf_type.h" namespace paddle { namespace translator { @@ -57,7 +58,6 @@ const std::unordered_set ProgramTranslator::no_cast_var_names = { }; const std::unordered_set ProgramTranslator::unsupported_ops = { - "conditional_block_grad", "while_grad", }; @@ -121,6 +121,41 @@ TranslationContext* TranslationContext::CreateInnerContext() { return sons_.back().get(); } +static std::vector GetExternalInputs(const BlockDesc& block) { + std::vector external_inputs; + std::unordered_set inner_outputs; + for (auto op_desc : block.AllOps()) { + for (const auto& n : op_desc->Inputs()) { + const auto& input_var_names = n.second; + for (const auto& var_name : input_var_names) { + if (inner_outputs.count(var_name) == 0) { + external_inputs.push_back(var_name); + } + } + } + for (const auto& n : op_desc->Outputs()) { + const auto& output_var_names = n.second; + for (const auto& var_name : output_var_names) { + inner_outputs.insert(var_name); + } + } + } + return external_inputs; +} + +static std::vector GetInnerOutputs(const BlockDesc& block) { + std::vector inner_outputs; + for (auto op_desc : block.AllOps()) { + for (const auto& n : op_desc->Outputs()) { + const auto& output_var_names = n.second; + for (const auto& var_name : output_var_names) { + inner_outputs.push_back(var_name); + } + } + } + return inner_outputs; +} + ProgramTranslator::ProgramTranslator(const ProgramDesc* legacy_program, pir::Program* program) : legacy_program_(legacy_program), program_(program) { @@ -133,6 +168,8 @@ void ProgramTranslator::Translate() { InsertDataOpForSingleBlock(legacy_program_->Block(0)); + PreAnalysisForCond(); + TranslateBlock(legacy_program_->Block(0), 0, legacy_program_->Block(0).OpSize(), @@ -179,6 
+216,8 @@ void ProgramTranslator::TranslateBlock(const BlockDesc& src_block, if (op->Type() == "conditional_block") { TranslateIfOperation(op, translation_ctx, dst_block); + } else if (op->Type() == "conditional_block_grad") { + TranslateIfOperation(op, translation_ctx, dst_block, true); } else if (op->Type() == "while") { TranslateWhileOperation(op, translation_ctx, dst_block); } else { @@ -187,7 +226,7 @@ void ProgramTranslator::TranslateBlock(const BlockDesc& src_block, } } -pir::Operation* ProgramTranslator::InsertDataOpOrCreateArrayToBlock( +pir::Operation* ProgramTranslator::InsertInitOpOrCreateArrayToBlock( pir::Block* insert_block, pir::Type type) { pir::Builder builder(ctx_, insert_block, insert_block->begin()); if (type.isa()) { @@ -207,13 +246,12 @@ pir::Operation* ProgramTranslator::InsertDataOpOrCreateArrayToBlock( VLOG(10) << "[translator][data insertion] after type: " << normalized_tensor_type; - auto data_op = builder.Build( - "data_" + nano_timestamp(), + auto init_op = builder.Build( shape, + 0.0, paddle::dialect::TransToPhiDataType(normalized_tensor_type.dtype()), phi::CPUPlace()); - data_op.out().set_type(normalized_tensor_type); - return data_op.operation(); + return init_op.operation(); } else if (type.isa()) { auto array_type = type.dyn_cast(); paddle::dialect::CreateArrayOp array_op = @@ -225,12 +263,52 @@ pir::Operation* ProgramTranslator::InsertDataOpOrCreateArrayToBlock( return nullptr; } -// NOTE(zhangbo): All condition_block_op will be translated as an if_op with -// only a true branch. +// NOTE(zhangbo): This function is used to analyze cond and cond_grad in +// program, insert tuple_push and tuple_pop op for forward variables used in +// backward operator. +void ProgramTranslator::PreAnalysisForCond() { + const BlockDesc& block = legacy_program_->Block(0); + std::unordered_map scope_var_to_cond; + for (uint64_t op_id = 0; op_id < block.OpSize(); op_id++) { + auto op = block.Op(static_cast(op_id)); + if (op->Type() == "conditional_block") { + scope_var_to_cond[op->Output("Scope")[0]] = op; + } else if (op->Type() == "conditional_block_grad") { + cond_grad_to_cond_[op] = scope_var_to_cond[op->Input("Scope")[0]]; + } + } + for (auto& pair : cond_grad_to_cond_) { + const OpDesc* cond_op = pair.second; + const BlockDesc& cond_block = + legacy_program_->Block(cond_op->GetBlockAttrId("sub_block")); + const OpDesc* cond_grad_op = pair.first; + const BlockDesc& cond_grad_block = + legacy_program_->Block(cond_grad_op->GetBlockAttrId("sub_block")); + + std::vector cond_inner_outputs = GetInnerOutputs(cond_block); + auto cond_grad_external_inputs = GetExternalInputs(cond_grad_block); + + std::vector push_pop_vars; + for (auto& var : cond_grad_external_inputs) { + if (std::find(cond_inner_outputs.begin(), + cond_inner_outputs.end(), + var) != cond_inner_outputs.end()) { + push_pop_vars.emplace_back(var); + } + } + push_pop_var_names_[cond_op] = push_pop_vars; + push_pop_var_names_[cond_grad_op] = push_pop_vars; + } +} + +// NOTE(zhangbo): All condition_block_op will be translated as an if_op with a +// true branch and a fake false branch, fake false branch will insert some init +// op such as full to facilitate alignment of the outputs of the true branch. 
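
PreAnalysisForCond pairs every conditional_block_grad with its forward conditional_block through the shared Scope variable, then records as push/pop variables exactly those names that the backward sub-block reads from outside itself and that the forward sub-block produced. The following is an illustrative sketch of that matching over plain string lists (variable names are made up; BlockDesc is not used), under the assumption that the grad block's external inputs and the forward block's inner outputs are computed as in GetExternalInputs/GetInnerOutputs above.

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Variables produced inside the forward sub-block and consumed from outside
// by the backward sub-block must be tuple_push'ed in the forward if_op and
// tuple_pop'ed in the backward if_op.
std::vector<std::string> MatchPushPopVars(
    const std::vector<std::string> &fwd_inner_outputs,
    const std::vector<std::string> &bwd_external_inputs) {
  std::vector<std::string> push_pop_vars;
  for (const auto &var : bwd_external_inputs) {
    if (std::find(fwd_inner_outputs.begin(), fwd_inner_outputs.end(), var) !=
        fwd_inner_outputs.end()) {
      push_pop_vars.push_back(var);
    }
  }
  return push_pop_vars;
}

int main() {
  std::vector<std::string> fwd_outs{"relu_0.tmp_0", "matmul_0.tmp_0"};
  std::vector<std::string> bwd_ins{"relu_0.tmp_0", "x", "out_0@GRAD"};
  for (const auto &v : MatchPushPopVars(fwd_outs, bwd_ins)) {
    std::cout << v << "\n";  // prints: relu_0.tmp_0
  }
  return 0;
}
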
void ProgramTranslator::TranslateIfOperation( const OpDesc* op, TranslationContext* translation_ctx, - pir::Block* dst_block) { + pir::Block* dst_block, + bool for_bwd) { VLOG(8) << "=============>Start to translate if op:" << op; auto& type_translator = TypeTranslator::instance(); @@ -240,7 +318,25 @@ void ProgramTranslator::TranslateIfOperation( VLOG(6) << "[general op][conditional_block][inputs: " << input_name << "]"; GetValueOrCreateInTop(input_name, translation_ctx); } - auto& cond_op_outputs = op->Output("Out"); + if (for_bwd) { + auto& cond_op_outs = op->Input("Out"); + for (auto out_name : cond_op_outs) { + VLOG(6) << "[general op][conditional_block][outs: " << out_name << "]"; + GetValueOrCreateInTop(out_name, translation_ctx); + } + auto& cond_op_out_grads = op->Input("Out@GRAD"); + for (auto out_grad_name : cond_op_out_grads) { + VLOG(6) << "[general op][conditional_block][out_grad: " << out_grad_name + << "]"; + GetValueOrCreateInTop(out_grad_name, translation_ctx); + } + } + + std::vector cond_op_outputs = + for_bwd ? op->Output("Input@GRAD") : op->Output("Out"); + cond_op_outputs.erase( + std::remove(cond_op_outputs.begin(), cond_op_outputs.end(), "@EMPTY@"), + cond_op_outputs.end()); std::vector<::paddle::framework::VarDesc*> cond_op_output_vars; for (auto out_name : cond_op_outputs) { cond_op_output_vars.emplace_back(op->Block()->FindVarRecursive(out_name)); @@ -257,27 +353,81 @@ void ProgramTranslator::TranslateIfOperation( if_op_output_types.emplace_back(translated_var_type); } auto if_op_info = ctx_->GetRegisteredOpInfo(paddle::dialect::IfOp::name()); - pir::Operation* operation = pir::Operation::Create( + pir::Operation* if_op = pir::Operation::Create( if_op_inputs, attribute_map, if_op_output_types, if_op_info, 2); - dst_block->push_back(operation); + // NOTE(zhangbo): If program has if_grad_op and if_grad_op sub_block use some + // value defined in if_op, we should insert tuple_push_op into if_op sub_block + // and tuple_pop_op into if_grad_op sub_block. + if (!for_bwd && push_pop_var_names_[op].size() != 0) { + pir::Operation* create_stack_op = pir::Operation::Create( + {}, + {}, + {pir::StackType::get(ctx_), + pir::InletType::get(ctx_), + pir::OutletType::get(ctx_)}, + ctx_->GetRegisteredOpInfo(pir::StackCreateOp::name())); + dst_block->push_back(create_stack_op); + cond_to_stack_value_[op] = { + create_stack_op->dyn_cast().inlet(), + create_stack_op->dyn_cast().outlet()}; + } + + dst_block->push_back(if_op); VLOG(4) << "[general op][conditional_block] IfOp creation end."; if (op->GetBlockAttrId("sub_block") != -1) { // Translate true branch by sub_block. 
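
A short sketch of how the if_op's result list is chosen above: the forward conditional_block contributes its "Out" names, while conditional_block_grad contributes "Input@GRAD" with the "@EMPTY@" placeholder entries erased, matching the erase/remove call in the diff. PickIfOpOutputs is a hypothetical helper name used only for illustration.

#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

std::vector<std::string> PickIfOpOutputs(
    bool for_bwd,
    const std::vector<std::string> &outs,
    const std::vector<std::string> &input_grads) {
  std::vector<std::string> result = for_bwd ? input_grads : outs;
  // Drop "@EMPTY@" placeholders, as TranslateIfOperation does for Input@GRAD.
  result.erase(std::remove(result.begin(), result.end(), "@EMPTY@"),
               result.end());
  return result;
}

int main() {
  auto fwd = PickIfOpOutputs(false, {"out_0"}, {});
  auto bwd = PickIfOpOutputs(true, {}, {"x@GRAD", "@EMPTY@", "y@GRAD"});
  assert(fwd.size() == 1);
  assert(bwd.size() == 2 && bwd[1] == "y@GRAD");
  return 0;
}
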
auto& sub_block = legacy_program_->Block(op->GetBlockAttrId("sub_block")); - pir::Region& true_region = operation->region(0); + pir::Region& true_region = if_op->region(0); if (true_region.empty()) true_region.emplace_back(); auto* true_block_context = translation_ctx->CreateInnerContext(); + + // insert tuple_pop op to if_grad + if (for_bwd && push_pop_var_names_[op].size() != 0) { + pir::Operation* tuple_pop_op = pir::Operation::Create( + {cond_to_stack_value_[cond_grad_to_cond_[op]][1]}, + {}, + push_pop_var_types_[cond_grad_to_cond_[op]], + ctx_->GetRegisteredOpInfo(pir::TuplePopOp::name())); + true_region.front().push_back(tuple_pop_op); + for (size_t i = 0; i < push_pop_var_names_[op].size(); ++i) { + true_block_context->PushValue( + push_pop_var_names_[op][i], + VariableDefiningInfo(tuple_pop_op->result(i))); + } + } + TranslateBlock(sub_block, 0, sub_block.OpSize(), true_block_context, &true_region.front()); + + // insert tuple_push op to true block before yeild op + if (!for_bwd && push_pop_var_names_[op].size() != 0) { + std::vector local_values; + local_values.push_back(cond_to_stack_value_[op][0]); + for (auto& var_name : push_pop_var_names_[op]) { + local_values.push_back(true_block_context->at(var_name).value); + push_pop_var_types_[op].push_back( + true_block_context->at(var_name).value.type()); + } + pir::Operation* tuple_push_op = pir::Operation::Create( + local_values, + {}, + {}, + ctx_->GetRegisteredOpInfo(pir::TuplePushOp::name())); + true_region.front().push_back(tuple_push_op); + } + // insert yeild op to true block auto yeild_info = ctx_->GetRegisteredOpInfo(pir::YieldOp::name()); std::vector true_yeild_inputs; for (auto& out_name : cond_op_outputs) { + if (for_bwd && out_name.find("@RENAME@block") != std::string::npos) { + out_name = out_name.substr(0, out_name.find("@RENAME@block")); + } true_yeild_inputs.push_back(true_block_context->at(out_name).value); } true_region.front().push_back( @@ -288,7 +438,7 @@ void ProgramTranslator::TranslateIfOperation( // consistent. Only inconsistent shape is allowed. To be compatible with the // old IR design, only true branches are allowed. The false branch may // require yeild some fake variables. 
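
The tuple_push/tuple_pop pairing inserted above follows a simple stack contract: the forward if_op's true block pushes one tuple of forward values onto the stack created next to the if_op, and the matching backward if_op's true block pops that tuple and rebinds the values under the same names, in the same order. A plain C++ model of that contract follows; std::stack is only an analogy for cf.stack_create / cf.tuple_push / cf.tuple_pop, not the dialect API.

#include <cassert>
#include <stack>
#include <string>
#include <vector>

using Tuple = std::vector<std::string>;

int main() {
  std::stack<Tuple> stack;                         // cf.stack_create
  stack.push({"relu_0.tmp_0", "matmul_0.tmp_0"});  // tuple_push in forward block
  Tuple popped = stack.top();                      // tuple_pop in backward block
  stack.pop();
  assert(popped.size() == 2 && popped[0] == "relu_0.tmp_0");  // same order
  return 0;
}
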
- pir::Region& false_region = operation->region(1); + pir::Region& false_region = if_op->region(1); if (false_region.empty()) false_region.emplace_back(); auto* false_block_context = translation_ctx->CreateInnerContext(); std::vector false_yeild_inputs; @@ -296,7 +446,7 @@ void ProgramTranslator::TranslateIfOperation( if (false_block_context->count(cond_op_outputs[id]) == 0) { auto true_type = true_yeild_inputs[id].type(); pir::Operation* init_op = - InsertDataOpOrCreateArrayToBlock(&false_region.front(), true_type); + InsertInitOpOrCreateArrayToBlock(&false_region.front(), true_type); PADDLE_ENFORCE_NOT_NULL( init_op, phi::errors::PreconditionNotMet( @@ -315,13 +465,13 @@ void ProgramTranslator::TranslateIfOperation( for (size_t i = 0; i < cond_op_output_vars.size(); i++) { translation_ctx->PushValue(cond_op_output_vars[i]->Name(), - VariableDefiningInfo(operation->result(i))); + VariableDefiningInfo(if_op->result(i))); VLOG(4) << "[general op][conditional_block] var " << cond_op_output_vars[i]->Name() << " was mapped to If's " << i << "-th output."; } - operation->Verify(); + if_op->Verify(); VLOG(4) << "[general op][conditional_block] IfOp translate end."; } @@ -426,26 +576,12 @@ void ProgramTranslator::InsertDataOpForSingleBlock(const BlockDesc& block) { for (auto& var : block.AllVars()) { all_var_names.insert(var->Name()); } - - std::unordered_set inner_outputs; - for (auto op_desc : block.AllOps()) { - for (const auto& n : op_desc->Inputs()) { - const auto& input_var_names = n.second; - for (const auto& var_name : input_var_names) { - if (param_map_.count(var_name) != 0) continue; - if (no_cast_var_names.count(var_name) != 0) continue; - if (all_var_names.count(var_name) == 0) continue; - if (inner_outputs.count(var_name) == 0) { - CreateUndefinedVariable(var_name, block); - } - } - } - for (const auto& n : op_desc->Outputs()) { - const auto& output_var_names = n.second; - for (const auto& var_name : output_var_names) { - inner_outputs.insert(var_name); - } - } + auto external_inputs = GetExternalInputs(block); + for (auto& var_name : external_inputs) { + if (all_var_names.count(var_name) == 0) continue; + if (param_map_.count(var_name) != 0) continue; + if (no_cast_var_names.count(var_name) != 0) continue; + CreateUndefinedVariable(var_name, block); } } diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.h b/paddle/fluid/ir_adaptor/translator/program_translator.h index cff7684226c520..f2a9bf3c5b3c59 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.h +++ b/paddle/fluid/ir_adaptor/translator/program_translator.h @@ -133,12 +133,20 @@ class ProgramTranslator { const VariableDefiningInfo& CreateUndefinedVariable( const std::string& var_name, const BlockDesc& block); - pir::Operation* InsertDataOpOrCreateArrayToBlock(pir::Block* insert_block, + pir::Operation* InsertInitOpOrCreateArrayToBlock(pir::Block* insert_block, pir::Type type); + std::unordered_map> + push_pop_var_names_; + std::unordered_map> push_pop_var_types_; + std::unordered_map cond_grad_to_cond_; + std::unordered_map> + cond_to_stack_value_; + void PreAnalysisForCond(); void TranslateIfOperation(const OpDesc* op, TranslationContext* translation_ctx, - pir::Block* dst_block); + pir::Block* dst_block, + bool for_bwd = false); void TranslateWhileOperation(const OpDesc* op, TranslationContext* translation_ctx, diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc index 30d5ce5a1b685e..92083c131175ba 100644 --- 
a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc @@ -14,7 +14,8 @@ #ifdef GET_OP_LIST #undef GET_OP_LIST paddle::dialect::IfOp, paddle::dialect::WhileOp, paddle::dialect::HasElementsOp, - paddle::dialect::AssertOp + paddle::dialect::AssertOp, paddle::dialect::SelectInputOp, + paddle::dialect::SelectOutputOp #else #include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" @@ -671,6 +672,159 @@ void AssertOp::VerifySig() { VLOG(4) << "End Verifying for: AssertOp."; } +void SelectInputOp::VerifySig() { + VLOG(4) << "Verifying inputs, outputs and attributes for: SelectInputOp."; + VLOG(4) << "Verifying inputs:"; + { + auto in_size = num_operands(); + IR_ENFORCE(in_size == 3u, "Size %d of inputs must be 3.", in_size); + auto input1 = (*this)->operand_source(1).type(); + auto input2 = (*this)->operand_source(2).type(); + if (input1.isa() && + input2.isa()) { + auto tensor1 = input1.dyn_cast(); + auto tensor2 = input2.dyn_cast(); + IR_ENFORCE( + tensor1.dtype() == tensor2.dtype(), + "The 1st input dtype %s should be equal to 2ed input dtype %s.", + tensor1.dtype(), + tensor2.dtype()); + IR_ENFORCE(tensor1.data_layout() == tensor2.data_layout(), + "The 1st input data_layout %s should be equal to 2ed input " + "data_layout %s.", + tensor1.data_layout(), + tensor2.data_layout()); + IR_ENFORCE(tensor1.lod() == tensor2.lod(), + "The 1st input lod %s should be equal to 2ed input lod %s.", + tensor1.lod(), + tensor2.lod()); + IR_ENFORCE( + tensor1.offset() == tensor2.offset(), + "The 1st input offset %s should be equal to 2ed input offset %s.", + tensor1.offset(), + tensor2.offset()); + } else if (input1.isa() && + input2.isa()) { + auto tensor1 = + input1.dyn_cast(); + auto tensor2 = + input1.dyn_cast(); + IR_ENFORCE( + tensor1.dtype() == tensor2.dtype(), + "The 1st input dtype %s should be equal to 2ed input dtype %s.", + tensor1.dtype(), + tensor2.dtype()); + IR_ENFORCE(tensor1.data_layout() == tensor2.data_layout(), + "The 1st input data_layout %s should be equal to 2ed input " + "data_layout %s.", + tensor1.data_layout(), + tensor2.data_layout()); + IR_ENFORCE(tensor1.lod() == tensor2.lod(), + "The 1st input lod %s should be equal to 2ed input lod %s.", + tensor1.lod(), + tensor2.lod()); + IR_ENFORCE( + tensor1.offset() == tensor2.offset(), + "The 1st input offset %s should be equal to 2ed input offset %s.", + tensor1.offset(), + tensor2.offset()); + IR_ENFORCE( + tensor1.place() == tensor2.place(), + "The 1st input place %s should be equal to 2ed input place %s.", + tensor1.place(), + tensor2.place()); + } else { + IR_ENFORCE(input1 == input2, + "The 1st input type %s should be equal to 2ed input type %s.", + input1, + input2); + } + } + VLOG(4) << "Verifying outputs:"; + { + auto out_size = num_results(); + IR_ENFORCE( + out_size == 1u, "Size %d of outputs must be equal to 1.", out_size); + } + VLOG(4) << "End Verifying for: AssignArray_Op."; +} + +void SelectOutputOp::VerifySig() { + VLOG(4) << "Verifying inputs, outputs and attributes for: SelectOutputOp."; + VLOG(4) << "Verifying inputs:"; + { + auto in_size = num_operands(); + IR_ENFORCE(in_size == 2u, "Size %d of inputs must be 2.", in_size); + } + VLOG(4) << "Verifying outputs:"; + { + auto out_size = num_results(); + IR_ENFORCE( + out_size == 2u, "Size %d of outputs must be equal to 2.", out_size); + + auto out1 = (*this)->result(0).type(); + auto out2 = (*this)->result(1).type(); + if (out1.isa() && + out2.isa()) { + auto tensor1 = out1.dyn_cast(); + auto 
tensor2 = out2.dyn_cast(); + IR_ENFORCE( + tensor1.dtype() == tensor2.dtype(), + "The 1st input dtype %s should be equal to 2ed input dtype %s.", + tensor1.dtype(), + tensor2.dtype()); + IR_ENFORCE(tensor1.data_layout() == tensor2.data_layout(), + "The 1st input data_layout %s should be equal to 2ed input " + "data_layout %s.", + tensor1.data_layout(), + tensor2.data_layout()); + IR_ENFORCE(tensor1.lod() == tensor2.lod(), + "The 1st input lod %s should be equal to 2ed input lod %s.", + tensor1.lod(), + tensor2.lod()); + IR_ENFORCE( + tensor1.offset() == tensor2.offset(), + "The 1st input offset %s should be equal to 2ed input offset %s.", + tensor1.offset(), + tensor2.offset()); + } else if (out1.isa() && + out2.isa()) { + auto tensor1 = out1.dyn_cast(); + auto tensor2 = out2.dyn_cast(); + IR_ENFORCE( + tensor1.dtype() == tensor2.dtype(), + "The 1st input dtype %s should be equal to 2ed input dtype %s.", + tensor1.dtype(), + tensor2.dtype()); + IR_ENFORCE(tensor1.data_layout() == tensor2.data_layout(), + "The 1st input data_layout %s should be equal to 2ed input " + "data_layout %s.", + tensor1.data_layout(), + tensor2.data_layout()); + IR_ENFORCE(tensor1.lod() == tensor2.lod(), + "The 1st input lod %s should be equal to 2ed input lod %s.", + tensor1.lod(), + tensor2.lod()); + IR_ENFORCE( + tensor1.offset() == tensor2.offset(), + "The 1st input offset %s should be equal to 2ed input offset %s.", + tensor1.offset(), + tensor2.offset()); + IR_ENFORCE( + tensor1.place() == tensor2.place(), + "The 1st input place %s should be equal to 2ed input place %s.", + tensor1.place(), + tensor2.place()); + } else { + IR_ENFORCE(out1 == out2, + "The 1st input type %s should be equal to 2ed input type %s.", + out1, + out2); + } + } + VLOG(4) << "End Verifying for: AssignArray_Op."; +} + } // namespace dialect } // namespace paddle @@ -678,5 +832,7 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::IfOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::WhileOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::HasElementsOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AssertOp) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SelectInputOp) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SelectOutputOp) #endif diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h index 3c86d56d116165..90df8eb2ff8f99 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h @@ -161,6 +161,28 @@ class AssertOp : public pir::Op { pir::Value data() { return operand_source(1); } }; +class SelectInputOp : public pir::Op { + public: + using Op::Op; + static const char *name() { return "pd_op.select_input"; } + static constexpr const char **attributes_name = nullptr; + static constexpr uint32_t attributes_num = 0; + void VerifySig(); + pir::Value mask() { return operand_source(0); } + pir::OpResult out() { return result(0); } +}; + +class SelectOutputOp : public pir::Op { + public: + using Op::Op; + static const char *name() { return "pd_op.select_output"; } + static constexpr const char **attributes_name = nullptr; + static constexpr uint32_t attributes_num = 0; + void VerifySig(); + pir::Value mask() { return operand_source(0); } + pir::Value x() { return operand_source(1); } +}; + } // namespace dialect } // namespace paddle @@ -168,3 +190,5 @@ IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::IfOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::WhileOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::HasElementsOp); 
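
A simplified picture of what SelectOutputOp::VerifySig enforces on its two results: dtype, data layout, lod and offset must agree, while the shape is not compared, consistent with the translator note that only inconsistent shapes are allowed between branches. TensorMeta and OutputsCompatible below are illustrative stand-ins, not paddle::dialect types; lod is flattened to an integer level for brevity.

#include <cassert>
#include <cstddef>
#include <string>

struct TensorMeta {
  std::string dtype;
  std::string layout;
  int lod_level;
  std::size_t offset;
};

bool OutputsCompatible(const TensorMeta &a, const TensorMeta &b) {
  // dtype, data layout, lod and offset must match; dims are left unchecked.
  return a.dtype == b.dtype && a.layout == b.layout &&
         a.lod_level == b.lod_level && a.offset == b.offset;
}

int main() {
  TensorMeta out0{"float32", "NCHW", 0, 0};
  TensorMeta out1{"float32", "NCHW", 0, 0};
  assert(OutputsCompatible(out0, out1));
  return 0;
}
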
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AssertOp); +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SelectInputOp) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SelectOutputOp) diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc index f4deb33b659263..682400c1d3033d 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc @@ -23,9 +23,8 @@ paddle::dialect::AddNOp, paddle::dialect::AddN_Op, paddle::dialect::SliceArrayOp, paddle::dialect::SliceArrayDenseOp, paddle::dialect::AssignArrayOp, paddle::dialect::AssignArray_Op, paddle::dialect::ArrayToTensorOp, paddle::dialect::TensorToArrayOp, - paddle::dialect::SelectInputOp, paddle::dialect::IncrementOp, - paddle::dialect::Increment_Op, paddle::dialect::ShapeBroadcastOp, - paddle::dialect::MemcpyD2hMultiIoOp + paddle::dialect::IncrementOp, paddle::dialect::Increment_Op, + paddle::dialect::ShapeBroadcastOp, paddle::dialect::MemcpyD2hMultiIoOp #else #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" @@ -4329,83 +4328,6 @@ phi::DataType ExpandOp::GetKernelTypeForVar( return expected_kernel_dtype; } -void SelectInputOp::VerifySig() { - VLOG(4) << "Verifying inputs, outputs and attributes for: SelectInputOp."; - VLOG(4) << "Verifying inputs:"; - { - auto in_size = num_operands(); - IR_ENFORCE(in_size == 3u, "Size %d of inputs must be >= 3.", in_size); - auto input1 = (*this)->operand_source(1).type(); - auto input2 = (*this)->operand_source(2).type(); - if (input1.isa() && - input2.isa()) { - auto tensor1 = input1.dyn_cast(); - auto tensor2 = input1.dyn_cast(); - IR_ENFORCE( - tensor1.dtype() == tensor2.dtype(), - "The 1st input dtype %s should be equal to 2ed input dtype %s.", - tensor1.dtype(), - tensor2.dtype()); - IR_ENFORCE(tensor1.data_layout() == tensor2.data_layout(), - "The 1st input data_layout %s should be equal to 2ed input " - "data_layout %s.", - tensor1.data_layout(), - tensor2.data_layout()); - IR_ENFORCE(tensor1.lod() == tensor2.lod(), - "The 1st input lod %s should be equal to 2ed input lod %s.", - tensor1.lod(), - tensor2.lod()); - IR_ENFORCE( - tensor1.offset() == tensor2.offset(), - "The 1st input offset %s should be equal to 2ed input offset %s.", - tensor1.offset(), - tensor2.offset()); - } else if (input1.isa() && - input2.isa()) { - auto tensor1 = - input1.dyn_cast(); - auto tensor2 = - input1.dyn_cast(); - IR_ENFORCE( - tensor1.dtype() == tensor2.dtype(), - "The 1st input dtype %s should be equal to 2ed input dtype %s.", - tensor1.dtype(), - tensor2.dtype()); - IR_ENFORCE(tensor1.data_layout() == tensor2.data_layout(), - "The 1st input data_layout %s should be equal to 2ed input " - "data_layout %s.", - tensor1.data_layout(), - tensor2.data_layout()); - IR_ENFORCE(tensor1.lod() == tensor2.lod(), - "The 1st input lod %s should be equal to 2ed input lod %s.", - tensor1.lod(), - tensor2.lod()); - IR_ENFORCE( - tensor1.offset() == tensor2.offset(), - "The 1st input offset %s should be equal to 2ed input offset %s.", - tensor1.offset(), - tensor2.offset()); - IR_ENFORCE( - tensor1.place() == tensor2.place(), - "The 1st input place %s should be equal to 2ed input place %s.", - tensor1.place(), - tensor2.place()); - } else { - IR_ENFORCE(input1 == input2, - "The 1st input type %s should be equal to 2ed input type %s.", - input1, - input2); - } - } - VLOG(4) << "Verifying outputs:"; - { - auto out_size = num_results(); - IR_ENFORCE( - out_size == 1u, "Size %d of outputs must be equal to 1.", 
out_size); - } - VLOG(4) << "End Verifying for: AssignArray_Op."; -} - const char *IncrementOp::attributes_name[1] = {"value"}; OpInfoTuple IncrementOp::GetOpInfo() { @@ -5253,7 +5175,6 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AssignArray_Op) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ArrayToTensorOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::TensorToArrayOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ExpandOp) -IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SelectInputOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::IncrementOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::Increment_Op) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::MemcpyD2hMultiIoOp) diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.h b/paddle/fluid/pir/dialect/operator/ir/manual_op.h index 6fb238a435c44a..983080049fc4f6 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.h @@ -578,17 +578,6 @@ class ExpandOp : public pir::Op { - public: - using Op::Op; - static const char *name() { return "pd_op.select_input"; } - static constexpr const char **attributes_name = nullptr; - static constexpr uint32_t attributes_num = 0; - void VerifySig(); - pir::Value mask() { return operand_source(0); } - pir::OpResult out() { return result(0); } -}; - class IncrementOp : public pir::Op SpecialLowerOps = { HasElementsOp::name(), AssertOp::name(), SelectInputOp::name(), + SelectOutputOp::name(), "cinn_runtime.jit_kernel"}; const std::unordered_map NoBufferRelatedOps = { @@ -1162,7 +1163,8 @@ void HandleForIfOp( &true_block, ctx, map_op_pair, - map_value_pair); + map_value_pair, + true); // process false block auto& false_block = new_ifop.false_block(); @@ -1171,7 +1173,8 @@ void HandleForIfOp( &false_block, ctx, map_op_pair, - map_value_pair); + map_value_pair, + true); // update map (*map_op_pair)[op_item] = new_ifop; @@ -1253,6 +1256,27 @@ phi::Place ParsePhiPlace(pir::Type type) { return type.dyn_cast().place(); } else if (type.isa()) { return type.dyn_cast().place(); + } else if (type.isa()) { + return ParsePhiPlace(type.dyn_cast()[0]); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "ParsePhiPlace only support AllocatedDenseTensorType or " + "AllocatedSelectedRowsType or AllocatedDenseTensorArrayType")); + } +} + +phi::DataType ParsePhiDType(pir::Type type) { + if (type.isa()) { + return TransToPhiDataType( + type.dyn_cast().dtype()); + } else if (type.isa()) { + return TransToPhiDataType( + type.dyn_cast().dtype()); + } else if (type.isa()) { + return TransToPhiDataType( + type.dyn_cast().dtype()); + } else if (type.isa()) { + return ParsePhiDType(type.dyn_cast()[0]); } else { PADDLE_THROW(phi::errors::Unimplemented( "ParsePhiPlace only support AllocatedDenseTensorType or " @@ -1266,7 +1290,8 @@ void HandleForSpecialOp( pir::Block* block, pir::IrContext* ctx, std::unordered_map* map_op_pair, - std::unordered_map* map_value_pair) { + std::unordered_map* map_value_pair, + bool for_if_block) { if (op_item->isa()) { HandleForIfOp(place, op_item, block, ctx, map_op_pair, map_value_pair); return; @@ -1373,6 +1398,23 @@ void HandleForSpecialOp( } auto new_in = GetNewInput( cur_in, *map_value_pair, static_cast(i), op_item->name()); + + if (for_if_block && (!new_in.type().isa()) && + (ParsePhiPlace(new_in.type()).GetType() != + phi::AllocationType::UNDEFINED) && + (ParsePhiPlace(new_in.type()) != place)) { + phi::KernelKey kernel_key(TransToPhiBackend(place), + phi::DataLayout::ALL_LAYOUT, + ParsePhiDType(new_in.type())); + new_in = AddPlaceTransferOp( + 
new_in, + ConvertOpTypeToKernelType(ctx, cur_in.type(), place), + ParsePhiPlace(new_in.type()), + place, + kernel_key, + block); + } + vec_inputs.push_back(new_in); } } @@ -1388,6 +1430,7 @@ void HandleForSpecialOp( } auto new_in = GetNewInput( cur_in, *map_value_pair, static_cast(i), op_item->name()); + // layout transfer(only for onednn) #ifdef PADDLE_WITH_DNNL auto new_in_type = new_in.type(); @@ -1532,6 +1575,18 @@ void HandleForSpecialOp( } } + if (op_item->isa()) { + for (size_t i = 0; i < op_item->num_operands(); ++i) { + auto cur_in = op_item->operand_source(i); + auto new_in = GetNewInput( + cur_in, *map_value_pair, static_cast(i), op_item->name()); + vec_inputs.push_back(new_in); + } + for (size_t i = 0; i < op_item->num_results(); ++i) { + op_output_types.push_back(vec_inputs[1].type()); + } + } + if (op_item->name() == "cinn_runtime.jit_kernel") { if (op_item->num_operands() > 0) { for (size_t i = 0; i < op_item->num_operands(); ++i) { @@ -2319,7 +2374,8 @@ void ProcessBlock( pir::Block* new_block, pir::IrContext* ctx, std::unordered_map* map_op_pair, - std::unordered_map* map_value_pair) { + std::unordered_map* map_value_pair, + bool for_if_block) { auto inputs_by_data_op = GetInputsByDataOp(block); for (auto iter = block->begin(); iter != block->end(); ++iter) { @@ -2338,8 +2394,13 @@ void ProcessBlock( if (SpecialLowerOps.count(op_item->name())) { VLOG(6) << "Handle Special Op: [" << op_item->name() << "] while lowering to kernel pass"; - HandleForSpecialOp( - place, op_item, new_block, ctx, map_op_pair, map_value_pair); + HandleForSpecialOp(place, + op_item, + new_block, + ctx, + map_op_pair, + map_value_pair, + for_if_block); continue; } diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h index 59415d931a6f89..395bc3fa6ec081 100644 --- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h @@ -30,6 +30,7 @@ void ProcessBlock( pir::Block* new_block, pir::IrContext* ctx, std::unordered_map* map_op_pair, - std::unordered_map* map_value_pair); + std::unordered_map* map_value_pair, + bool for_if_block = false); } // namespace dialect } // namespace paddle diff --git a/test/legacy_test/test_cond.py b/test/legacy_test/test_cond.py index ff018eed8db8ed..96d10d1c987663 100644 --- a/test/legacy_test/test_cond.py +++ b/test/legacy_test/test_cond.py @@ -711,7 +711,8 @@ def backward_value_helper(self, cond_func, use_cuda): name='image', shape=[-1, 9], dtype='float32' ) img.stop_gradient = False - img.persistable = True + if paddle.framework.in_pir_mode(): + img.persistable = True label = paddle.static.data( name='label', shape=[-1, 1], dtype='int64' ) @@ -798,12 +799,13 @@ def add_optimizer_helper(self, cond_func, use_cuda): with paddle.static.scope_guard(paddle.static.Scope()): with paddle.static.program_guard(main_program, startup_program): img = paddle.static.data( - name='image', shape=[-1, 784], dtype='float32' + name='image', shape=[16, 784], dtype='float32' ) img.stop_gradient = False - img.persistable = True + if paddle.framework.in_pir_mode(): + img.persistable = True label = paddle.static.data( - name='label', shape=[-1, 1], dtype='int64' + name='label', shape=[16, 1], dtype='int64' ) i = paddle.static.data(name="i", shape=[1], dtype='int32') loss = cond_func(i, img, label) @@ -829,7 +831,7 @@ def add_optimizer_helper(self, cond_func, use_cuda): fetch_list=[loss], ) - @test_with_pir_api + @compare_legacy_with_pt def test_cond_backward(self): 
paddle.enable_static()