139 changes: 139 additions & 0 deletions paddle/fluid/framework/new_executor/instruction/control_flow/select_output_instruction.cc
@@ -0,0 +1,139 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/new_executor/instruction/control_flow/select_output_instruction.h"
#include "paddle/fluid/framework/new_executor/instruction/instruction_util.h"
#include "paddle/fluid/framework/new_executor/new_executor_defs.h"
#include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h"

namespace paddle {
namespace framework {

SelectOutputInstruction::SelectOutputInstruction(
    size_t id,
    const platform::Place &place,
    ::pir::Operation *op,
    ValueExecutionInfo *value_exe_info)
    : InstructionBase(id, place), op_(op) {
  VLOG(6) << "construct select_output instruction";

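  // Operand 0 is the scalar mask; operand 1 is the data input to route.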
  std::unordered_map<pir::Value, std::vector<int>> inputs;
  mask_ = value_exe_info->GetVarByValue(op->operand_source(0));
  inputs.emplace(op->operand_source(0),
                 GetValueIds(op->operand_source(0), *value_exe_info));
  input_ = value_exe_info->GetVarByValue(op->operand_source(1));
  inputs.emplace(op->operand_source(1),
                 GetValueIds(op->operand_source(1), *value_exe_info));
  SetInputs(inputs);

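  // Every op result is a candidate output branch.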
  std::unordered_map<pir::Value, std::vector<int>> outputs;
  for (size_t i = 0; i < op->num_results(); ++i) {
    outputs_.push_back(value_exe_info->GetVarByValue(op->result(i)));
    outputs.emplace(op->result(i), GetValueIds(op->result(i), *value_exe_info));
  }
  SetOutputs(outputs);
}

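// Returns the branch index stored in the 1-element integer `mask` tensor.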
inline int GetBranchNumber(const phi::DenseTensor &mask) {
  PADDLE_ENFORCE_EQ(
      mask.numel(),
      1,
      phi::errors::Fatal("The numel of Input(Mask) in SelectInputOp or "
                         "SelectOutputOp must be 1. "
                         "But received %d, and its shape is [%s].",
                         mask.numel(),
                         mask.dims()));
  if (platform::is_cpu_place(mask.place())) {
    return mask.data<int>()[0];
  }
  // The mask lives on an accelerator; stage a synchronous copy to host first.
  std::unique_ptr<phi::DenseTensor> cpu_mask{new phi::DenseTensor()};
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
    defined(PADDLE_WITH_CUSTOM_DEVICE) || defined(PADDLE_WITH_XPU)
  framework::TensorCopySync(mask, platform::CPUPlace(), cpu_mask.get());
#else
  PADDLE_THROW(phi::errors::Fatal(
      "This version of PaddlePaddle does NOT support GPU, "
      "but got GPU tensor 'Mask' in SelectInputOp or SelectOutputOp. "
      "Please compile PaddlePaddle WITH_GPU first."));
#endif
  return cpu_mask->data<int>()[0];
}

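// Visitor that copies the payload of a Variable (DenseTensor, TensorArray,
// or SelectedRows) into the wrapped output variable `out_`.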
class AssignFunctor {
 public:
  explicit AssignFunctor(Variable *out) : out_(out) {}

  void operator()(const phi::DenseTensor &lod_tensor) const {
    auto &out_tensor = *out_->GetMutable<phi::DenseTensor>();
    copy_tensor(lod_tensor, &out_tensor);
  }

  void operator()(const phi::TensorArray &array) const {
    auto &out_array = *out_->GetMutable<phi::TensorArray>();
    out_array.resize(array.size());
    for (size_t i = 0; i < array.size(); ++i) {
      copy_tensor(array[i], &out_array[i]);
    }
  }

  void operator()(const phi::SelectedRows &rows) const {
    phi::SelectedRows &out_rows = *out_->GetMutable<phi::SelectedRows>();
    out_rows.set_rows(rows.rows());
    out_rows.set_height(rows.height());
    auto &t = rows.value();
    auto *m = out_rows.mutable_value();
    TensorCopy(t, t.place(), m);
  }

  template <typename T>
  void operator()(const T &v UNUSED) const {
    PADDLE_ENFORCE_EQ(
        true,
        false,
        platform::errors::PermissionDenied(
            "Unsupported type for assign op: %s", typeid(T).name()));
  }

 private:
  void copy_tensor(const phi::DenseTensor &lod_tensor,
                   phi::DenseTensor *out) const {
    if (!lod_tensor.IsInitialized()) return;
    auto &out_tensor = *out;
    TensorCopy(lod_tensor, lod_tensor.place(), &out_tensor);
    out_tensor.set_lod(lod_tensor.lod());
  }

  Variable *out_;
};

void SelectOutputInstruction::Run() {
  VLOG(6) << "run select_output instruction";
  auto &mask = mask_->Get<phi::DenseTensor>();
  size_t output_branch = static_cast<size_t>(GetBranchNumber(mask));
  PADDLE_ENFORCE_LT(
      output_branch,
      outputs_.size(),
      phi::errors::Fatal(
          "Input 'Mask' in SelectOutputOp is invalid. "
          "'Mask' must be less than the size of output vector 'X'. "
          "But received Mask = %d, Out's size = %d.",
          output_branch,
          outputs_.size()));
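  // Copy the input into the chosen branch; the other outputs are left as-is.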
  Variable *selected = outputs_[output_branch];
  VisitVarType(*input_, AssignFunctor(selected));
}

} // namespace framework
} // namespace paddle
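
For intuition, here is a minimal, self-contained sketch of what Run() does, with a plain std::string standing in for Paddle's Variable payload; the names below are illustrative only, not Paddle API.

#include <cassert>
#include <string>
#include <vector>

// Route one input into the output slot picked by a scalar mask;
// the other slots are left empty.
std::vector<std::string> SelectOutput(int mask,
                                      const std::string &input,
                                      size_t num_branches) {
  assert(mask >= 0 && static_cast<size_t>(mask) < num_branches);
  std::vector<std::string> outputs(num_branches);
  outputs[mask] = input;  // only the selected branch receives the value
  return outputs;
}

int main() {
  auto outs = SelectOutput(/*mask=*/1, "tensor-payload", /*num_branches=*/2);
  assert(outs[0].empty() && outs[1] == "tensor-payload");
}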
52 changes: 52 additions & 0 deletions paddle/fluid/framework/new_executor/instruction/control_flow/select_output_instruction.h
@@ -0,0 +1,52 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include "paddle/fluid/framework/new_executor/instruction/instruction_base.h"

namespace paddle {
namespace framework {
class ValueExecutionInfo;

class SelectOutputInstruction : public InstructionBase {
 public:
  SelectOutputInstruction(size_t id,
                          const platform::Place& place,
                          ::pir::Operation* op,
                          ValueExecutionInfo* value_exe_info);

  void Run() override;

  const std::string& Name() const override { return name_; }

  ::pir::Operation* Operation() const override { return op_; }

 private:
  ::pir::Operation* op_;

  OpFuncType type_;

  std::string name_{"pd_op.select_output"};

  Variable* mask_;  // not owned

  Variable* input_;  // not owned

  std::vector<Variable*> outputs_;  // not owned
};

} // namespace framework
} // namespace paddle
3 changes: 3 additions & 0 deletions paddle/fluid/framework/new_executor/pir_interpreter.cc
@@ -53,6 +53,7 @@
#include "paddle/fluid/framework/new_executor/instruction/control_flow/has_elements_instruction.h"
#include "paddle/fluid/framework/new_executor/instruction/control_flow/if_instruction.h"
#include "paddle/fluid/framework/new_executor/instruction/control_flow/select_input_instruction.h"
#include "paddle/fluid/framework/new_executor/instruction/control_flow/select_output_instruction.h"
#include "paddle/fluid/framework/new_executor/instruction/control_flow/tuple_pop_instruction.h"
#include "paddle/fluid/framework/new_executor/instruction/control_flow/tuple_push_instruction.h"
#include "paddle/fluid/framework/new_executor/instruction/control_flow/while_instruction.h"
@@ -725,6 +726,8 @@ void PirInterpreter::BuildInstruction() {
      CREATE_INSTR(AssertInstruction);
    } else if (op.isa<paddle::dialect::SelectInputOp>()) {
      CREATE_INSTR(SelectInputInstruction);
    } else if (op.isa<paddle::dialect::SelectOutputOp>()) {
      CREATE_INSTR(SelectOutputInstruction);
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Now only support pd_kernel and cinn dialect."));
48 changes: 47 additions & 1 deletion paddle/fluid/ir_adaptor/translator/op_translator.cc
@@ -2010,7 +2010,6 @@ struct SelectInputOpTranscriber : public OpTranscriber {
    VarDesc* var = op_desc.Block()->FindVarRecursive(Out_name);
    arg_to_idx[var->Name()] = {0, 0};

    // NOTE(zhangbo): Only support
    auto input1 = op_inputs[1].type();
    auto input2 = op_inputs[2].type();
    if (input1 == input2) {
@@ -2115,6 +2114,52 @@
  }
};

struct SelectOutputOpTranscriber : public OpTranscriber {
Reviewer (Contributor): This doesn't look like it needs a special transcriber?

Author (Contributor): For now only select_output with two outputs is supported, so a special transcriber is used as a temporary measure.

Reviewer (Contributor), commenting on the check below:

IR_ENFORCE(Out_names.size() == 2,
           "Expected SelectOutput's output size is 2.");

If this is only needed for the enforce check, wouldn't it be enough to run the enforce first and then call the base class method directly?

Author (Contributor): Under the old operator system, SelectOutput's output Out is a list holding multiple elements; under the new operator definition system, SelectOutput is defined with exactly two outputs. For the case where a single old-op output holds multiple elements, GenerateOperationOutput only supports outputs of VectorType, or else takes just the first of the elements.

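To make the special transcription concrete, here is a hedged, self-contained sketch of the output bookkeeping performed by the loop below; the container types and names are simplified stand-ins, not the translator's real API.

#include <cstdio>
#include <map>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Old-IR op: Out is a list of two variable names (see IR_ENFORCE above).
  std::vector<std::string> out_names = {"out_a", "out_b"};
  std::string input_type = "type-of-X";  // every result reuses X's type

  // Map each old output name to (result slot, index within the slot).
  std::map<std::string, std::pair<size_t, size_t>> arg_to_idx;
  std::vector<std::string> op_output_types;
  for (size_t idx = 0; idx < out_names.size(); ++idx) {
    arg_to_idx[out_names[idx]] = {idx, 0};
    op_output_types.push_back(input_type);
  }
  for (const auto &kv : arg_to_idx)
    std::printf("%s -> result #%zu\n", kv.first.c_str(), kv.second.first);
  return 0;
}
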
  pir::Operation* operator()(pir::IrContext* ctx,
                             TranslationContext* param_map,
                             const OpDesc& op_desc,
                             pir::Block* block) override {
    VLOG(10) << "[op select_output] start transcribing";
    auto op_info = this->LoopkUpOpInfo(ctx, op_desc);

    std::vector<pir::Value> op_inputs = {};
    auto Mask_name = op_desc.Input("Mask")[0];
    auto& Input_name = op_desc.Input("X")[0];
    IR_ENFORCE(param_map->count(Mask_name) > 0,
               "Expected op[%s]'s input %s has been parsed",
               op_desc.Type(),
               Mask_name);
    op_inputs.push_back(param_map->at(Mask_name).value);
    IR_ENFORCE(param_map->count(Input_name) > 0,
               "Expected op[%s]'s input %s has been parsed",
               op_desc.Type(),
               Input_name);
    op_inputs.push_back(param_map->at(Input_name).value);

    pir::AttributeMap attribute_map;
    TranslateOpDistAttribute(op_desc, &attribute_map);

    OpOutputMapping arg_to_idx;
    OpOutputTypeList op_output_types;
    auto Out_names = op_desc.Output("Out");
    IR_ENFORCE(Out_names.size() == 2,
               "Expected SelectOutput's output size is 2.");
    for (size_t idx = 0; idx < Out_names.size(); idx++) {
      VarDesc* var = op_desc.Block()->FindVarRecursive(Out_names[idx]);
      arg_to_idx[var->Name()] = {idx, 0};
      op_output_types.push_back(op_inputs[1].type());
    }

    pir::Operation* operation = pir::Operation::Create(
        op_inputs, attribute_map, op_output_types, op_info);
    block->push_back(operation);
    RecordOpResultMapping(ctx, param_map, op_desc, operation, arg_to_idx);

    VLOG(10) << "[op select_output] translation finished";
    return operation;
  }
};

pir::OpResult TranslateNumClassesForOneHot(
    pir::IrContext* ctx,
    TranslationContext* param_map,
@@ -3088,6 +3133,7 @@ OpTranslator::OpTranslator() {
special_handlers["mul"] = MulOpTranscriber();
special_handlers["mul_grad"] = MulGradOpTranscriber();
special_handlers["select_input"] = SelectInputOpTranscriber();
special_handlers["select_output"] = SelectOutputOpTranscriber();

// To adapt LodTensorArray
special_handlers["lod_array_length"] = LodArrayLengthOpTranscriber();