Skip to content

Commit e5fbff4

Browse files
authored
[PIR] Refine conditional_block op translator (#59723)
* refine * fix * add select_input * fix * fix * refine if_op without falseblock * fix bug * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix
1 parent 217cc54 commit e5fbff4

File tree

22 files changed

+845
-294
lines changed

22 files changed

+845
-294
lines changed

paddle/fluid/framework/new_executor/instruction/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@ cc_library(
33
SRCS instruction_base.cc
44
phi_kernel_instruction.cc
55
legacy_kernel_instruction.cc
6-
cond_instruction.cc
6+
if_instruction.cc
77
while_instruction.cc
8+
select_input_instruction.cc
89
has_elements_instruction.cc
910
tuple_push_instruction.cc
1011
tuple_pop_instruction.cc

paddle/fluid/framework/new_executor/instruction/cond_instruction.cc renamed to paddle/fluid/framework/new_executor/instruction/if_instruction.cc

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
#include "paddle/fluid/framework/new_executor/instruction/cond_instruction.h"
15+
#include "paddle/fluid/framework/new_executor/instruction/if_instruction.h"
1616

1717
#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
1818
#include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h"
@@ -39,11 +39,11 @@
3939
namespace paddle {
4040
namespace framework {
4141

42-
CondInstruction::CondInstruction(size_t id,
43-
const platform::Place& place,
44-
pir::Operation* op,
45-
ValueExecutionInfo* value_exec_info,
46-
const std::set<std::string>& skip_gc_vars)
42+
IfInstruction::IfInstruction(size_t id,
43+
const platform::Place& place,
44+
pir::Operation* op,
45+
ValueExecutionInfo* value_exec_info,
46+
const std::set<std::string>& skip_gc_vars)
4747
: InstructionBase(id, place) {
4848
PADDLE_ENFORCE(
4949
op->isa<paddle::dialect::IfOp>(),
@@ -66,12 +66,14 @@ CondInstruction::CondInstruction(size_t id,
6666
// OpOperand of IfOp, and the other is external Values used in true_block or
6767
// false_block.
6868
auto& true_branch_block = if_op.true_block();
69-
auto& false_branch_block = if_op.false_block();
69+
7070
std::unordered_map<pir::Value, std::vector<int>> inputs;
7171
GetInputIds(op, *value_exec_info, &inputs);
7272
auto true_outside_inputs =
7373
GetExternalInputs(&true_branch_block, *value_exec_info, &inputs);
74-
auto false_outside_inputs =
74+
std::vector<pir::Value> false_outside_inputs;
75+
auto& false_branch_block = if_op.false_block();
76+
false_outside_inputs =
7577
GetExternalInputs(&false_branch_block, *value_exec_info, &inputs);
7678
SetInputs(inputs);
7779

@@ -90,8 +92,10 @@ CondInstruction::CondInstruction(size_t id,
9092
}
9193
}
9294
InsertTuplePushContinerToOuts(&true_branch_block, *value_exec_info, &outputs);
95+
9396
InsertTuplePushContinerToOuts(
94-
&false_branch_block, *value_exec_info, &outputs);
97+
&if_op.false_block(), *value_exec_info, &outputs);
98+
9599
SetOutputs(outputs);
96100
VLOG(6) << "finish process inputs outputs index";
97101

@@ -126,11 +130,10 @@ CondInstruction::CondInstruction(size_t id,
126130
false_branch_inter_ =
127131
new PirInterpreter(place,
128132
{},
129-
&false_branch_block,
133+
&if_op.false_block(),
130134
false_scope,
131135
value_exec_info->NewChild(false_scope),
132136
{});
133-
134137
std::set<std::string> false_skip_gc_names_set;
135138
for (auto value : GetYiedOpInputs(&false_branch_block)) {
136139
false_branch_outputs_.push_back(false_branch_inter_->GetNameByValue(value));
@@ -146,10 +149,11 @@ CondInstruction::CondInstruction(size_t id,
146149
false_skip_gc_names_set.insert(var_name);
147150
}
148151
false_branch_inter_->SetSkipGcVars(false_skip_gc_names_set);
152+
149153
VLOG(6) << "finish process false branch interpreter";
150154
}
151155

152-
CondInstruction::~CondInstruction() {
156+
IfInstruction::~IfInstruction() {
153157
if (true_branch_inter_ != nullptr) {
154158
delete true_branch_inter_;
155159
}
@@ -158,8 +162,8 @@ CondInstruction::~CondInstruction() {
158162
}
159163
}
160164

161-
void CondInstruction::CopyBranchOutput(
162-
const std::vector<std::string>& var_names, const PirInterpreter* inter) {
165+
void IfInstruction::CopyBranchOutput(const std::vector<std::string>& var_names,
166+
const PirInterpreter* inter) {
163167
for (size_t i = 0; i < var_names.size(); ++i) {
164168
auto* inner_var = inter->InnerScope()->GetVar(var_names[i]);
165169

@@ -179,7 +183,7 @@ void CondInstruction::CopyBranchOutput(
179183
}
180184
}
181185

182-
void CondInstruction::Run() {
186+
void IfInstruction::Run() {
183187
DeviceContext().Wait();
184188
if (cond_var_->Get<phi::DenseTensor>().data<bool>()[0]) {
185189
true_branch_inter_->Run({}, false);
@@ -188,7 +192,6 @@ void CondInstruction::Run() {
188192
false_branch_inter_->Run({}, false);
189193
CopyBranchOutput(false_branch_outputs_, false_branch_inter_);
190194
}
191-
192195
// copy ouptut
193196
}
194197

paddle/fluid/framework/new_executor/instruction/cond_instruction.h renamed to paddle/fluid/framework/new_executor/instruction/if_instruction.h

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,15 @@ class Value;
2727
class PirInterpreter;
2828
class ValueExecutionInfo;
2929

30-
class CondInstruction : public InstructionBase {
30+
class IfInstruction : public InstructionBase {
3131
public:
32-
CondInstruction(size_t id,
33-
const platform::Place& place,
34-
::pir::Operation* op,
35-
ValueExecutionInfo* value_exe_info,
36-
const std::set<std::string>& skip_gc_vars);
32+
IfInstruction(size_t id,
33+
const platform::Place& place,
34+
::pir::Operation* op,
35+
ValueExecutionInfo* value_exe_info,
36+
const std::set<std::string>& skip_gc_vars);
3737

38-
~CondInstruction();
38+
~IfInstruction();
3939

4040
void Run() override;
4141

@@ -53,15 +53,15 @@ class CondInstruction : public InstructionBase {
5353

5454
::pir::Operation* op_;
5555

56-
std::string cond_name_{"cond_instruction"};
56+
std::string cond_name_{"if_instruction"};
5757

5858
Variable* cond_var_;
5959

6060
std::vector<Variable*> output_vars_;
6161

62-
PirInterpreter* true_branch_inter_;
62+
PirInterpreter* true_branch_inter_ = nullptr;
6363

64-
PirInterpreter* false_branch_inter_;
64+
PirInterpreter* false_branch_inter_ = nullptr;
6565

6666
std::vector<std::string> true_branch_outputs_;
6767

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/new_executor/instruction/select_input_instruction.h"
16+
#include "paddle/fluid/framework/new_executor/instruction/instruction_util.h"
17+
#include "paddle/fluid/framework/new_executor/new_executor_defs.h"
18+
#include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h"
19+
20+
namespace paddle {
21+
namespace framework {
22+
23+
SelectInputInstruction::SelectInputInstruction(
24+
size_t id,
25+
const platform::Place &place,
26+
::pir::Operation *op,
27+
ValueExecutionInfo *value_exe_info)
28+
: InstructionBase(id, place), op_(op) {
29+
VLOG(6) << "construct select_input instruction";
30+
31+
std::unordered_map<pir::Value, std::vector<int>> inputs;
32+
mask_ = value_exe_info->GetVarByValue(op->operand_source(0));
33+
inputs.emplace(op->operand_source(0),
34+
GetValueIds(op->operand_source(0), *value_exe_info));
35+
36+
for (size_t i = 1; i < op->num_operands(); ++i) {
37+
inputs_.push_back(value_exe_info->GetVarByValue(op->operand_source(i)));
38+
inputs.emplace(op->operand_source(i),
39+
GetValueIds(op->operand_source(i), *value_exe_info));
40+
}
41+
SetInputs(inputs);
42+
43+
std::unordered_map<pir::Value, std::vector<int>> outputs;
44+
out_ = value_exe_info->GetVarByValue(op->result(0));
45+
outputs.emplace(op->result(0), GetValueIds(op->result(0), *value_exe_info));
46+
SetOutputs(outputs);
47+
}
48+
49+
inline int GetBranchNumber(const phi::DenseTensor &mask) {
50+
PADDLE_ENFORCE_EQ(
51+
mask.numel(),
52+
1,
53+
phi::errors::Fatal("The numel of Input(Mask) in SelectInputOp or "
54+
"SelectOutputOp must be 1. "
55+
"But received %d, and it's shape is [%s].",
56+
mask.numel(),
57+
mask.dims()));
58+
if (platform::is_cpu_place(mask.place())) {
59+
return mask.data<int>()[0];
60+
}
61+
// when platform::is_gpu_place(mask.place()) is true
62+
std::unique_ptr<phi::DenseTensor> cpu_mask{new phi::DenseTensor()};
63+
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
64+
defined(PADDLE_WITH_CUSTOM_DEVICE) || defined(PADDLE_WITH_XPU)
65+
framework::TensorCopySync(mask, platform::CPUPlace(), cpu_mask.get());
66+
#else
67+
PADDLE_THROW(phi::errors::Fatal(
68+
"This version of PaddlePaddle does NOT support GPU, "
69+
"but got GPU tensor 'Mask' in SelectInputOp or SelectOutputOp. "
70+
"Please compile PaddlePaddle WITH_GPU first."));
71+
#endif
72+
return cpu_mask->data<int>()[0];
73+
}
74+
75+
class AssignFunctor {
76+
public:
77+
explicit AssignFunctor(Variable *out) : out_(out) {}
78+
79+
void operator()(const phi::DenseTensor &lod_tensor) const {
80+
auto &out_tensor = *out_->GetMutable<phi::DenseTensor>();
81+
copy_tensor(lod_tensor, &out_tensor);
82+
}
83+
84+
void operator()(const phi::TensorArray &array) const {
85+
auto &out_array = *out_->GetMutable<phi::TensorArray>();
86+
out_array.resize(array.size());
87+
for (size_t i = 0; i < array.size(); ++i) {
88+
copy_tensor(array[i], &out_array[i]);
89+
}
90+
}
91+
92+
void operator()(const phi::SelectedRows &rows) const {
93+
phi::SelectedRows &out_rows = *out_->GetMutable<phi::SelectedRows>();
94+
out_rows.set_rows(rows.rows());
95+
out_rows.set_height(rows.height());
96+
auto &t = rows.value();
97+
auto *m = out_rows.mutable_value();
98+
TensorCopy(t, t.place(), m);
99+
}
100+
101+
template <typename T>
102+
void operator()(const T &v UNUSED) const {
103+
PADDLE_ENFORCE_EQ(
104+
true,
105+
false,
106+
platform::errors::PermissionDenied(
107+
"Not support type for assign op with type %s", typeid(T).name()));
108+
}
109+
110+
private:
111+
void copy_tensor(const phi::DenseTensor &lod_tensor,
112+
phi::DenseTensor *out) const {
113+
if (!lod_tensor.IsInitialized()) return;
114+
auto &out_tensor = *out;
115+
TensorCopy(lod_tensor, lod_tensor.place(), &out_tensor);
116+
out_tensor.set_lod(lod_tensor.lod());
117+
}
118+
119+
Variable *out_;
120+
};
121+
122+
void SelectInputInstruction::Run() {
123+
VLOG(6) << "run select_input instruction";
124+
auto &mask = mask_->Get<phi::DenseTensor>();
125+
size_t output_branch = static_cast<size_t>(GetBranchNumber(mask));
126+
PADDLE_ENFORCE_LT(
127+
output_branch,
128+
inputs_.size(),
129+
phi::errors::Fatal(
130+
"Input 'Mask' in SelectInputOp is invalid. "
131+
"'Mask' must be less than the size of input vector 'X'. "
132+
"But received Mask = %d, X's size = %d.",
133+
output_branch,
134+
inputs_.size()));
135+
Variable *selected = inputs_[output_branch];
136+
VisitVarType(*selected, AssignFunctor(out_));
137+
}
138+
139+
} // namespace framework
140+
} // namespace paddle
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <string>
18+
#include "paddle/fluid/framework/new_executor/instruction/instruction_base.h"
19+
20+
namespace paddle {
21+
namespace framework {
22+
class ValueExecutionInfo;
23+
24+
class SelectInputInstruction : public InstructionBase {
25+
public:
26+
SelectInputInstruction(size_t id,
27+
const platform::Place& place,
28+
::pir::Operation* op,
29+
ValueExecutionInfo* value_exe_info);
30+
31+
void Run() override;
32+
33+
const std::string& Name() const override { return name_; }
34+
35+
::pir::Operation* Operation() const override { return op_; }
36+
37+
private:
38+
::pir::Operation* op_;
39+
40+
OpFuncType type_;
41+
42+
std::string name_{"pd_op.select_input"};
43+
44+
Variable* mask_; // not owned
45+
46+
std::vector<Variable*> inputs_; // not owned
47+
48+
Variable* out_; // not owned
49+
};
50+
51+
} // namespace framework
52+
} // namespace paddle

paddle/fluid/framework/new_executor/instruction/while_instruction.cc

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -175,20 +175,14 @@ void WhileInstruction::CopyOutputsToBlockArgs() {
175175
auto* dst_tensor_array = inner_var->GetMutable<phi::TensorArray>();
176176
dst_tensor_array->set_type(src_tensor_array.dtype());
177177
dst_tensor_array->set_layout(src_tensor_array.layout());
178-
if (dst_tensor_array->empty()) {
179-
for (auto src_tensor : src_tensor_array) {
180-
phi::DenseTensor* tmp_dst_tensor = new phi::DenseTensor();
181-
tmp_dst_tensor->set_meta(src_tensor.meta());
182-
framework::TensorCopy(src_tensor, src_tensor.place(), tmp_dst_tensor);
183-
dst_tensor_array->push_back(*tmp_dst_tensor);
184-
}
185-
} else {
186-
for (size_t id = 0; id < dst_tensor_array->size(); id++) {
187-
auto& src_tensor = src_tensor_array[id];
188-
phi::DenseTensor* tmp_dst_tensor = &dst_tensor_array->at(id);
189-
tmp_dst_tensor->set_meta(src_tensor.meta());
190-
framework::TensorCopy(src_tensor, src_tensor.place(), tmp_dst_tensor);
191-
}
178+
while (dst_tensor_array->size() < src_tensor_array.size()) {
179+
dst_tensor_array->emplace_back();
180+
}
181+
for (size_t id = 0; id < dst_tensor_array->size(); id++) {
182+
auto& src_tensor = src_tensor_array[id];
183+
phi::DenseTensor* tmp_dst_tensor = &dst_tensor_array->at(id);
184+
tmp_dst_tensor->set_meta(src_tensor.meta());
185+
framework::TensorCopy(src_tensor, src_tensor.place(), tmp_dst_tensor);
192186
}
193187
} else {
194188
PADDLE_THROW(

paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,11 @@
4343
namespace paddle {
4444
namespace framework {
4545

46-
class CondInstruction;
46+
class IfInstruction;
4747
class WhileInstruction;
4848
class ValueExecutionInfo {
4949
public:
50-
friend class CondInstruction;
50+
friend class IfInstruction;
5151
friend class WhileInstruction;
5252

5353
explicit ValueExecutionInfo(Scope* scope) : scope_(scope) {}

0 commit comments

Comments
 (0)