Skip to content

Commit 8046e33

Browse files
authored
Add some passes which can be applied to Program (PaddlePaddle#34730)
* add inplace passes and tests * update * fix use_cuda undefined fix compile error of op compat * add more ut * fix CPU CI error * check adam unique * fix mac/windows ci, improve coverage * fix ci error * follow weihang's comment * fix BlockDesc::MoveFrom * follow qiuliang's comment * update * follow huihuang's comments
1 parent 5de576b commit 8046e33

File tree

18 files changed

+966
-34
lines changed

18 files changed

+966
-34
lines changed

paddle/fluid/framework/block_desc.cc

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,5 +238,41 @@ BlockDesc *BlockDesc::ForwardBlock() const {
238238
return prog_->MutableBlock(static_cast<size_t>(desc_->forward_block_idx()));
239239
}
240240

241+
// Transfer the contents of `block` into this block: variables are moved in
// (or updated in place when a variable of the same name already exists), ops
// are deep-copied, and `block` is left empty and flushed.
void BlockDesc::MoveFrom(BlockDesc *block) {
  PADDLE_ENFORCE_NOT_NULL(
      block, platform::errors::InvalidArgument("Block must be provided."));
  if (block == this) {
    return;  // moving a block into itself is a no-op
  }

  for (auto &name_and_var : block->vars_) {
    const auto &var_name = name_and_var.first;
    auto &src_var_ptr = name_and_var.second;
    auto &dst_var_ptr = vars_[var_name];
    if (dst_var_ptr == nullptr) {
      VLOG(10) << "Create new variable " << src_var_ptr->Name();
      dst_var_ptr = std::move(src_var_ptr);
    } else {
      // NOTE(zjl): cannot release dst_var_ptr, because Python
      // Variable holds the reference of the C++ VarDesc object.
      // If the C++ VarDesc object is destructed, any call to the
      // methods of Python Variable may raise segmentation fault.
      VLOG(10) << "Update old variable " << src_var_ptr->Name();
      *dst_var_ptr = *src_var_ptr;
    }
  }

  // Replace our ops with deep copies of the source block's ops.
  ops_.clear();
  for (const auto &src_op : block->ops_) {
    AppendOp()->CopyFrom(*src_op);
  }
  need_update_ = true;
  Flush();

  // Empty out the source block and sync its proto representation.
  block->ops_.clear();
  block->vars_.clear();
  block->need_update_ = true;
  block->Flush();
}
276+
241277
} // namespace framework
242278
} // namespace paddle

paddle/fluid/framework/block_desc.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,8 @@ class BlockDesc {
111111

112112
ProgramDesc *Program() const { return this->prog_; }
113113

114+
void MoveFrom(BlockDesc *block);
115+
114116
private:
115117
ProgramDesc *prog_; // not_own
116118
proto::BlockDesc *desc_; // not_own

paddle/fluid/framework/details/build_strategy.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,11 @@ struct BuildStrategy {
180180

181181
bool IsFinalized() const { return is_finalized_; }
182182

183+
// Reset the build strategy so that passes may be configured and applied
// again: clear the finalized flag and drop the cached pass builder.
void ClearFinalized() {
  is_finalized_ = false;
  pass_builder_ = nullptr;
}
187+
183188
bool IsMultiDevPass(const std::string &pass_name) const;
184189

185190
// Apply the passes built by the pass_builder_. The passes will be

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ if (WITH_TESTING)
5050
endif(WITH_TESTING)
5151
cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS ${GRAPH_PATTERN_DETECTOR_DEPS})
5252

53-
cc_library(op_compat_sensible_pass SRCS op_compat_sensible_pass.cc DEPS graph_pattern_detector op_def_api)
53+
cc_library(op_compat_sensible_pass SRCS op_compat_sensible_pass.cc DEPS graph_pattern_detector op_def_api pass)
5454
cc_library(subgraph_detector SRCS subgraph_detector.cc DEPS graph_pattern_detector executor)
5555
cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS op_compat_sensible_pass)
5656
cc_library(placement_pass_base SRCS placement_pass_base.cc DEPS pass)

paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_h
1010

1111
cc_library(memory_reuse_pass SRCS memory_reuse_pass.cc DEPS computation_op_handle reference_count_pass_helper share_tensor_buffer_op_handle graph pass multi_devices_helper)
1212

13-
cc_library(buffer_shared_inplace_op_pass SRCS buffer_shared_inplace_op_pass.cc DEPS memory_reuse_pass)
13+
cc_library(buffer_shared_inplace_op_pass SRCS buffer_shared_inplace_op_pass.cc DEPS memory_reuse_pass executor_gc_helper)
1414
cc_library(buffer_shared_cross_op_memory_reuse_pass SRCS buffer_shared_cross_op_memory_reuse_pass.cc DEPS memory_reuse_pass)
1515

1616
cc_library(inplace_addto_op_pass SRCS inplace_addto_op_pass.cc DEPS memory_reuse_pass)

paddle/fluid/framework/ir/memory_optimize_pass/buffer_shared_inplace_op_pass.cc

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <string>
1616

1717
#include "glog/logging.h"
18+
#include "paddle/fluid/framework/executor_gc_helper.h"
1819
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_reuse_pass.h"
1920
#include "paddle/fluid/framework/ir/pass.h"
2021
#include "paddle/fluid/platform/enforce.h"
@@ -30,6 +31,9 @@ class BufferSharedInplaceOpPass : public MemoryReusePass {
3031
std::string ReuseType() const override { return "inplace"; }
3132

3233
void Run(Graph *graph) const override;
34+
35+
void ApplyImpl(ProgramDesc *main_program,
36+
ProgramDesc *startup_program) const override;
3337
};
3438

3539
void BufferSharedInplaceOpPass::Run(Graph *graph) const {
@@ -149,6 +153,141 @@ void BufferSharedInplaceOpPass::Run(Graph *graph) const {
149153
}
150154
}
151155

156+
static std::string GetFirstVarName(const OpDesc &op, const std::string &slot,
157+
bool is_input) {
158+
const auto &name_map = is_input ? op.Inputs() : op.Outputs();
159+
auto iter = name_map.find(slot);
160+
if (iter != name_map.end() && !iter->second.empty()) {
161+
return iter->second[0];
162+
}
163+
return kEmptyVarName;
164+
}
165+
166+
static std::vector<std::vector<std::pair<std::string, std::string>>>
167+
GetInplaceVars(const BlockDesc &block, bool use_cuda,
168+
const std::vector<std::string> &skip_vars) {
169+
PADDLE_ENFORCE_EQ(block.ID(), 0, platform::errors::Unimplemented(
170+
"Inplace can only perform in block 0."));
171+
// only take block 0 gc_vars
172+
const auto op_gc_vars =
173+
GetEagerDeletionCleanVars(*block.Program(), skip_vars)[0];
174+
const auto all_ops = block.AllOps();
175+
PADDLE_ENFORCE_EQ(op_gc_vars.size(), all_ops.size(),
176+
platform::errors::PermissionDenied(
177+
"GC analysis error: op number not match."));
178+
size_t n = all_ops.size();
179+
std::unordered_set<std::string> visited_vars;
180+
std::unordered_set<std::string> reused_in_vars(skip_vars.begin(),
181+
skip_vars.end());
182+
std::unordered_set<std::string> reused_out_vars(skip_vars.begin(),
183+
skip_vars.end());
184+
for (const auto *op : all_ops) {
185+
if (op->Type() == "share_buffer" || op->Type() == "share_data") {
186+
const auto &inputs = op->Input("X");
187+
const auto &outputs = op->Output("Out");
188+
reused_in_vars.insert(inputs.begin(), inputs.end());
189+
reused_out_vars.insert(outputs.begin(), outputs.end());
190+
}
191+
}
192+
193+
std::vector<std::vector<std::pair<std::string, std::string>>> result(n);
194+
for (size_t i = 0; i < n; ++i) {
195+
const auto &op = *all_ops[i];
196+
const auto &gc_vars = op_gc_vars[i];
197+
const auto inputs = op.InputArgumentNames();
198+
const auto outputs = op.OutputArgumentNames();
199+
visited_vars.insert(inputs.begin(), inputs.end());
200+
201+
auto &infer_inplace = OpInfoMap::Instance().Get(op.Type()).infer_inplace_;
202+
if (gc_vars.empty() || !infer_inplace) {
203+
visited_vars.insert(outputs.begin(), outputs.end());
204+
continue;
205+
}
206+
207+
const auto var_pair = infer_inplace(use_cuda);
208+
std::unordered_multiset<std::string> input_set(inputs.begin(),
209+
inputs.end());
210+
std::unordered_multiset<std::string> output_set(outputs.begin(),
211+
outputs.end());
212+
std::unordered_set<std::string> valid_vars;
213+
for (const auto &var : gc_vars) {
214+
if (var != kEmptyVarName && input_set.count(var) == 1 &&
215+
output_set.count(var) == 0 &&
216+
block.FindVar(var)->GetType() == proto::VarType::LOD_TENSOR) {
217+
valid_vars.insert(var);
218+
}
219+
}
220+
221+
if (valid_vars.empty()) {
222+
visited_vars.insert(outputs.begin(), outputs.end());
223+
continue;
224+
}
225+
226+
for (const auto &pair : var_pair) {
227+
const auto &input_slot = pair.first;
228+
const auto &output_slot = pair.second;
229+
auto input_var = GetFirstVarName(op, input_slot, /*is_input=*/true);
230+
if (input_var == kEmptyVarName || valid_vars.count(input_var) == 0) {
231+
continue;
232+
}
233+
auto output_var = GetFirstVarName(op, output_slot, /*is_input=*/false);
234+
if (output_var == kEmptyVarName || visited_vars.count(output_var) > 0) {
235+
continue;
236+
}
237+
auto output_var_desc = block.FindVar(output_var);
238+
if (output_var_desc == nullptr || output_var_desc->Persistable() ||
239+
output_var_desc->GetType() != proto::VarType::LOD_TENSOR) {
240+
continue;
241+
}
242+
243+
if (reused_in_vars.count(input_var) > 0 ||
244+
reused_out_vars.count(output_var) > 0) {
245+
continue;
246+
}
247+
248+
// input_var -> output_var is reusable
249+
VLOG(10) << "inplace occurs at op " << i << " " << op.Type() << ": "
250+
<< input_var << " -> " << output_var;
251+
result[i].emplace_back(input_var, output_var);
252+
reused_in_vars.insert(input_var);
253+
reused_out_vars.insert(output_var);
254+
}
255+
visited_vars.insert(outputs.begin(), outputs.end());
256+
std::sort(result[i].begin(), result[i].end());
257+
}
258+
return result;
259+
}
260+
261+
// Program-level entry of the inplace pass: analyze block 0 of the main
// program and insert a "share_buffer" op in front of each op that has
// reusable (input -> output) pairs.
void BufferSharedInplaceOpPass::ApplyImpl(ProgramDesc *main_program,
                                          ProgramDesc *startup_program) const {
  bool use_cuda = Get<bool>(kUseCuda);
  auto skip_vars = Get<std::vector<std::string>>("mem_opt_skip_vars");

  auto *block = main_program->MutableBlock(0);
  auto inplace_vars = GetInplaceVars(*block, use_cuda, skip_vars);
  PADDLE_ENFORCE_EQ(inplace_vars.size(), block->OpSize(),
                    platform::errors::PermissionDenied(
                        "Inplace analysis error: op number not match."));
  // Insert from back to front so that indices of ops not yet processed stay
  // valid after each insertion.
  for (int64_t idx = static_cast<int64_t>(inplace_vars.size()) - 1; idx >= 0;
       --idx) {
    const auto &pairs = inplace_vars[idx];
    if (pairs.empty()) {
      continue;
    }
    std::vector<std::string> in_args;
    std::vector<std::string> out_args;
    in_args.reserve(pairs.size());
    out_args.reserve(pairs.size());
    for (const auto &in_out : pairs) {
      in_args.push_back(in_out.first);
      out_args.push_back(in_out.second);
    }
    auto *share_op = block->InsertOp(idx);
    share_op->SetType("share_buffer");
    share_op->SetInput("X", in_args);
    share_op->SetOutput("Out", out_args);
    share_op->SetOutput("XOut", in_args);  // add necessary dependency
    share_op->SetAttr("share_dims", std::vector<bool>(in_args.size(), false));
  }
  block->Flush();
}
290+
152291
} // namespace ir
153292
} // namespace framework
154293
} // namespace paddle

0 commit comments

Comments
 (0)