Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ cc_library(
SRCS instruction_base.cc
phi_kernel_instruction.cc
legacy_kernel_instruction.cc
cond_instruction.cc
if_instruction.cc
while_instruction.cc
select_input_instruction.cc
has_elements_instruction.cc
tuple_push_instruction.cc
tuple_pop_instruction.cc
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/new_executor/instruction/cond_instruction.h"
#include "paddle/fluid/framework/new_executor/instruction/if_instruction.h"

#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
#include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h"
Expand All @@ -39,11 +39,11 @@
namespace paddle {
namespace framework {

CondInstruction::CondInstruction(size_t id,
const platform::Place& place,
pir::Operation* op,
ValueExecutionInfo* value_exec_info,
const std::set<std::string>& skip_gc_vars)
IfInstruction::IfInstruction(size_t id,
const platform::Place& place,
pir::Operation* op,
ValueExecutionInfo* value_exec_info,
const std::set<std::string>& skip_gc_vars)
: InstructionBase(id, place) {
PADDLE_ENFORCE(
op->isa<paddle::dialect::IfOp>(),
Expand All @@ -66,12 +66,14 @@ CondInstruction::CondInstruction(size_t id,
// OpOperand of IfOp, and the other is external Values used in true_block or
// false_block.
auto& true_branch_block = if_op.true_block();
auto& false_branch_block = if_op.false_block();

std::unordered_map<pir::Value, std::vector<int>> inputs;
GetInputIds(op, *value_exec_info, &inputs);
auto true_outside_inputs =
GetExternalInputs(&true_branch_block, *value_exec_info, &inputs);
auto false_outside_inputs =
std::vector<pir::Value> false_outside_inputs;
auto& false_branch_block = if_op.false_block();
false_outside_inputs =
GetExternalInputs(&false_branch_block, *value_exec_info, &inputs);
SetInputs(inputs);

Expand All @@ -90,8 +92,10 @@ CondInstruction::CondInstruction(size_t id,
}
}
InsertTuplePushContinerToOuts(&true_branch_block, *value_exec_info, &outputs);

InsertTuplePushContinerToOuts(
&false_branch_block, *value_exec_info, &outputs);
&if_op.false_block(), *value_exec_info, &outputs);

SetOutputs(outputs);
VLOG(6) << "finish process inputs outputs index";

Expand Down Expand Up @@ -126,11 +130,10 @@ CondInstruction::CondInstruction(size_t id,
false_branch_inter_ =
new PirInterpreter(place,
{},
&false_branch_block,
&if_op.false_block(),
Comment on lines 130 to +133
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议改用make_unique

false_scope,
value_exec_info->NewChild(false_scope),
{});

std::set<std::string> false_skip_gc_names_set;
for (auto value : GetYiedOpInputs(&false_branch_block)) {
false_branch_outputs_.push_back(false_branch_inter_->GetNameByValue(value));
Expand All @@ -146,10 +149,11 @@ CondInstruction::CondInstruction(size_t id,
false_skip_gc_names_set.insert(var_name);
}
false_branch_inter_->SetSkipGcVars(false_skip_gc_names_set);

VLOG(6) << "finish process false branch interpreter";
}

CondInstruction::~CondInstruction() {
IfInstruction::~IfInstruction() {
if (true_branch_inter_ != nullptr) {
delete true_branch_inter_;
}
Expand All @@ -158,8 +162,8 @@ CondInstruction::~CondInstruction() {
}
}

void CondInstruction::CopyBranchOutput(
const std::vector<std::string>& var_names, const PirInterpreter* inter) {
void IfInstruction::CopyBranchOutput(const std::vector<std::string>& var_names,
const PirInterpreter* inter) {
for (size_t i = 0; i < var_names.size(); ++i) {
auto* inner_var = inter->InnerScope()->GetVar(var_names[i]);

Expand All @@ -179,7 +183,7 @@ void CondInstruction::CopyBranchOutput(
}
}

void CondInstruction::Run() {
void IfInstruction::Run() {
DeviceContext().Wait();
if (cond_var_->Get<phi::DenseTensor>().data<bool>()[0]) {
true_branch_inter_->Run({}, false);
Expand All @@ -188,7 +192,6 @@ void CondInstruction::Run() {
false_branch_inter_->Run({}, false);
CopyBranchOutput(false_branch_outputs_, false_branch_inter_);
}

// copy ouptut
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ class Value;
class PirInterpreter;
class ValueExecutionInfo;

class CondInstruction : public InstructionBase {
class IfInstruction : public InstructionBase {
public:
CondInstruction(size_t id,
const platform::Place& place,
::pir::Operation* op,
ValueExecutionInfo* value_exe_info,
const std::set<std::string>& skip_gc_vars);
IfInstruction(size_t id,
const platform::Place& place,
::pir::Operation* op,
ValueExecutionInfo* value_exe_info,
const std::set<std::string>& skip_gc_vars);

~CondInstruction();
~IfInstruction();

void Run() override;

Expand All @@ -53,15 +53,15 @@ class CondInstruction : public InstructionBase {

::pir::Operation* op_;

std::string cond_name_{"cond_instruction"};
std::string cond_name_{"if_instruction"};

Variable* cond_var_;

std::vector<Variable*> output_vars_;

PirInterpreter* true_branch_inter_;
PirInterpreter* true_branch_inter_ = nullptr;

PirInterpreter* false_branch_inter_;
PirInterpreter* false_branch_inter_ = nullptr;
Comment on lines +62 to +64
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

换用unique_ptr表明所有权更好吧


std::vector<std::string> true_branch_outputs_;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/new_executor/instruction/select_input_instruction.h"

#include <memory>

#include "paddle/fluid/framework/new_executor/instruction/instruction_util.h"
#include "paddle/fluid/framework/new_executor/new_executor_defs.h"
#include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h"

namespace paddle {
namespace framework {

// Builds a SelectInput instruction: operand 0 is the scalar mask that picks
// the branch, the remaining operands are the candidate inputs, and result 0
// receives the selected variable at run time.
SelectInputInstruction::SelectInputInstruction(
    size_t id,
    const platform::Place &place,
    ::pir::Operation *op,
    ValueExecutionInfo *value_exe_info)
    : InstructionBase(id, place), op_(op) {
  VLOG(6) << "construct select_input instruction";

  const pir::Value mask_value = op->operand_source(0);
  mask_ = value_exe_info->GetVarByValue(mask_value);

  // Register the mask plus every candidate input with the scheduler.
  std::unordered_map<pir::Value, std::vector<int>> input_ids;
  input_ids.emplace(mask_value, GetValueIds(mask_value, *value_exe_info));
  const size_t operand_num = op->num_operands();
  for (size_t idx = 1; idx < operand_num; ++idx) {
    const pir::Value candidate = op->operand_source(idx);
    inputs_.push_back(value_exe_info->GetVarByValue(candidate));
    input_ids.emplace(candidate, GetValueIds(candidate, *value_exe_info));
  }
  SetInputs(input_ids);

  // The single result is the variable the selected input is copied into.
  const pir::Value result = op->result(0);
  out_ = value_exe_info->GetVarByValue(result);
  std::unordered_map<pir::Value, std::vector<int>> output_ids;
  output_ids.emplace(result, GetValueIds(result, *value_exe_info));
  SetOutputs(output_ids);
}

// Reads the scalar branch index out of `mask`.
//
// `mask` must contain exactly one int element. If the tensor lives on a
// non-CPU device it is first copied synchronously to host memory; on builds
// without any device support, a non-CPU mask is a fatal error.
inline int GetBranchNumber(const phi::DenseTensor &mask) {
  PADDLE_ENFORCE_EQ(
      mask.numel(),
      1,
      phi::errors::Fatal("The numel of Input(Mask) in SelectInputOp or "
                         "SelectOutputOp must be 1. "
                         "But received %d, and it's shape is [%s].",
                         mask.numel(),
                         mask.dims()));
  if (platform::is_cpu_place(mask.place())) {
    return mask.data<int>()[0];
  }
  // Non-CPU tensor: stage the scalar through a host-side copy before reading.
  auto cpu_mask = std::make_unique<phi::DenseTensor>();
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
    defined(PADDLE_WITH_CUSTOM_DEVICE) || defined(PADDLE_WITH_XPU)
  framework::TensorCopySync(mask, platform::CPUPlace(), cpu_mask.get());
#else
  PADDLE_THROW(phi::errors::Fatal(
      "This version of PaddlePaddle does NOT support GPU, "
      "but got GPU tensor 'Mask' in SelectInputOp or SelectOutputOp. "
      "Please compile PaddlePaddle WITH_GPU first."));
#endif
  return cpu_mask->data<int>()[0];
}

// Visitor that deep-copies the payload of one Variable into `out_`,
// dispatched on the variable's concrete type via VisitVarType.
class AssignFunctor {
 public:
  explicit AssignFunctor(Variable *out) : out_(out) {}

  // DenseTensor: copy the data and LoD into the destination tensor.
  void operator()(const phi::DenseTensor &lod_tensor) const {
    copy_tensor(lod_tensor, out_->GetMutable<phi::DenseTensor>());
  }

  // TensorArray: element-wise copy into a destination array of equal size.
  void operator()(const phi::TensorArray &array) const {
    auto *out_array = out_->GetMutable<phi::TensorArray>();
    out_array->resize(array.size());
    for (size_t i = 0; i < array.size(); ++i) {
      copy_tensor(array[i], &out_array->at(i));
    }
  }

  // SelectedRows: copy rows/height metadata plus the value tensor.
  void operator()(const phi::SelectedRows &rows) const {
    auto *out_rows = out_->GetMutable<phi::SelectedRows>();
    out_rows->set_rows(rows.rows());
    out_rows->set_height(rows.height());
    const auto &src = rows.value();
    TensorCopy(src, src.place(), out_rows->mutable_value());
  }

  // Any other variable type is unsupported and aborts with an error.
  template <typename T>
  void operator()(const T &v UNUSED) const {
    PADDLE_ENFORCE_EQ(
        true,
        false,
        platform::errors::PermissionDenied(
            "Not support type for assign op with type %s", typeid(T).name()));
  }

 private:
  void copy_tensor(const phi::DenseTensor &lod_tensor,
                   phi::DenseTensor *out) const {
    // An uninitialized source is silently skipped — there is no storage to
    // copy. NOTE(review): confirm callers rely on this best-effort behavior
    // rather than expecting an error here.
    if (!lod_tensor.IsInitialized()) return;
    TensorCopy(lod_tensor, lod_tensor.place(), out);
    out->set_lod(lod_tensor.lod());
  }

  Variable *out_;
};

void SelectInputInstruction::Run() {
VLOG(6) << "run select_input instruction";
auto &mask = mask_->Get<phi::DenseTensor>();
size_t output_branch = static_cast<size_t>(GetBranchNumber(mask));
PADDLE_ENFORCE_LT(
output_branch,
inputs_.size(),
phi::errors::Fatal(
"Input 'Mask' in SelectInputOp is invalid. "
"'Mask' must be less than the size of input vector 'X'. "
"But received Mask = %d, X's size = %d.",
output_branch,
inputs_.size()));
Variable *selected = inputs_[output_branch];
VisitVarType(*selected, AssignFunctor(out_));
}

} // namespace framework
} // namespace paddle
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include "paddle/fluid/framework/new_executor/instruction/instruction_base.h"

namespace paddle {
namespace framework {
class ValueExecutionInfo;

// Instruction implementing pd_op.select_input: at run time it reads a scalar
// mask variable and forwards the mask-th candidate input variable to the
// single output variable.
class SelectInputInstruction : public InstructionBase {
 public:
  SelectInputInstruction(size_t id,
                         const platform::Place& place,
                         ::pir::Operation* op,
                         ValueExecutionInfo* value_exe_info);

  void Run() override;

  const std::string& Name() const override { return name_; }

  ::pir::Operation* Operation() const override { return op_; }

 private:
  // Default-initialize every raw pointer member to nullptr, consistent with
  // the sibling IfInstruction members in this change set.
  ::pir::Operation* op_ = nullptr;  // not owned

  OpFuncType type_;

  std::string name_{"pd_op.select_input"};

  // Scalar mask variable that selects which input is forwarded.
  Variable* mask_ = nullptr;  // not owned

  // Candidate input variables, indexed by the mask value.
  std::vector<Variable*> inputs_;  // not owned

  // Destination variable that receives the selected input.
  Variable* out_ = nullptr;  // not owned
};

} // namespace framework
} // namespace paddle
Original file line number Diff line number Diff line change
Expand Up @@ -175,20 +175,14 @@ void WhileInstruction::CopyOutputsToBlockArgs() {
auto* dst_tensor_array = inner_var->GetMutable<phi::TensorArray>();
dst_tensor_array->set_type(src_tensor_array.dtype());
dst_tensor_array->set_layout(src_tensor_array.layout());
if (dst_tensor_array->empty()) {
for (auto src_tensor : src_tensor_array) {
phi::DenseTensor* tmp_dst_tensor = new phi::DenseTensor();
tmp_dst_tensor->set_meta(src_tensor.meta());
framework::TensorCopy(src_tensor, src_tensor.place(), tmp_dst_tensor);
dst_tensor_array->push_back(*tmp_dst_tensor);
}
} else {
for (size_t id = 0; id < dst_tensor_array->size(); id++) {
auto& src_tensor = src_tensor_array[id];
phi::DenseTensor* tmp_dst_tensor = &dst_tensor_array->at(id);
tmp_dst_tensor->set_meta(src_tensor.meta());
framework::TensorCopy(src_tensor, src_tensor.place(), tmp_dst_tensor);
}
while (dst_tensor_array->size() < src_tensor_array.size()) {
dst_tensor_array->emplace_back();
}
for (size_t id = 0; id < dst_tensor_array->size(); id++) {
auto& src_tensor = src_tensor_array[id];
phi::DenseTensor* tmp_dst_tensor = &dst_tensor_array->at(id);
tmp_dst_tensor->set_meta(src_tensor.meta());
framework::TensorCopy(src_tensor, src_tensor.place(), tmp_dst_tensor);
}
} else {
PADDLE_THROW(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@
namespace paddle {
namespace framework {

class CondInstruction;
class IfInstruction;
class WhileInstruction;
class ValueExecutionInfo {
public:
friend class CondInstruction;
friend class IfInstruction;
friend class WhileInstruction;

explicit ValueExecutionInfo(Scope* scope) : scope_(scope) {}
Expand Down
Loading