Skip to content

Commit 6d3a29d

Browse files
sneaxiyAnnaTrainingG
authored andcommitted
Enable program passes on Fleet APIs (PaddlePaddle#34955)
* add fleet api for program pass * turn on apply pass for CI test * fix disable fuse_all_optimizer bug * try to test ci * fix CI * fill unspecified op role * fix fuse_allreduce * add ut to improve coverage * remove useless change * improve c++ coverage * follow some comments * test ir pass pipeline * update doc * reduce ut time again
1 parent 3e4b32e commit 6d3a29d

File tree

19 files changed

+523
-60
lines changed

19 files changed

+523
-60
lines changed

paddle/fluid/framework/block_desc.cc

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,27 @@ void BlockDesc::MoveFrom(BlockDesc *block) {
263263
}
264264
ops_.clear();
265265
for (const auto &src_op : block->ops_) {
266-
AppendOp()->CopyFrom(*src_op);
266+
auto *dst_op = AppendOp();
267+
dst_op->CopyFrom(*src_op);
268+
for (const auto &pair : src_op->GetAttrMap()) {
269+
const auto &attr_name = pair.first;
270+
const auto &attr_value = pair.second;
271+
auto attr_type = static_cast<proto::AttrType>(attr_value.which() - 1);
272+
if (attr_type == proto::AttrType::BLOCK) {
273+
auto block_id = BOOST_GET_CONST(BlockDesc *, attr_value)->ID();
274+
dst_op->SetBlockAttr(attr_name, prog_->MutableBlock(block_id));
275+
VLOG(10) << "Set block attr " << attr_name << " id " << block_id;
276+
} else if (attr_type == proto::AttrType::BLOCKS) {
277+
auto old_blocks = BOOST_GET_CONST(std::vector<BlockDesc *>, attr_value);
278+
std::vector<BlockDesc *> new_blocks;
279+
new_blocks.reserve(old_blocks.size());
280+
for (auto *b : old_blocks) {
281+
VLOG(10) << "Set block attr " << attr_name << " id " << b->ID();
282+
new_blocks.push_back(prog_->MutableBlock(b->ID()));
283+
}
284+
dst_op->SetBlocksAttr(attr_name, new_blocks);
285+
}
286+
}
267287
}
268288
need_update_ = true;
269289
Flush();

paddle/fluid/framework/ir/memory_optimize_pass/memory_reuse_pass.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ class MemoryReusePass : public Pass {
113113
details::VarHandle *in_var,
114114
details::VarHandle *out_var) const;
115115

116+
bool SupportApplyProgramViaGraph() const override { return false; }
117+
116118
private:
117119
VarDesc *GetVarDesc(const details::VarHandle &var) const;
118120

paddle/fluid/framework/ir/pass.cc

Lines changed: 88 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License. */
1616

1717
#include <algorithm>
1818
#include "paddle/fluid/framework/ir/graph_helper.h"
19+
#include "paddle/fluid/framework/op_proto_maker.h"
1920

2021
namespace paddle {
2122
namespace framework {
@@ -72,19 +73,6 @@ Graph *Pass::Apply(Graph *graph) const {
7273
return graph;
7374
}
7475

75-
void Pass::Apply(ProgramDesc *main_program,
76-
ProgramDesc *startup_program) const {
77-
VLOG(10) << "apply pass " << Type() << " to program";
78-
PADDLE_ENFORCE_NOT_NULL(main_program, platform::errors::InvalidArgument(
79-
"main program must be provided"));
80-
PADDLE_ENFORCE_NOT_NULL(
81-
startup_program,
82-
platform::errors::InvalidArgument("startup program must be provided"));
83-
84-
ApplyImpl(main_program, startup_program);
85-
VLOG(10) << "finish to apply pass " << Type() << " to program";
86-
}
87-
8876
template <typename Container, typename Visitor>
8977
static void VisitAllElements(Container &&container, Visitor &&visitor,
9078
bool reverse) {
@@ -95,8 +83,8 @@ static void VisitAllElements(Container &&container, Visitor &&visitor,
9583
}
9684
}
9785

98-
void Pass::MergePrograms(ProgramDesc *dst, const details::ProgramDescs &srcs,
99-
bool append) {
86+
static void MergePrograms(ProgramDesc *dst, const details::ProgramDescs &srcs,
87+
bool append) {
10088
PADDLE_ENFORCE_NOT_NULL(
10189
dst, platform::errors::InvalidArgument("Dst program must be provided."));
10290
bool reverse = !append;
@@ -137,27 +125,105 @@ void Pass::MergePrograms(ProgramDesc *dst, const details::ProgramDescs &srcs,
137125
VisitAllElements(srcs, create_op_visitor, reverse);
138126
}
139127

128+
static void FillNotSpecifiedOpRole(const ProgramDesc &main_program) {
129+
for (size_t block_idx = 0; block_idx < main_program.Size(); ++block_idx) {
130+
auto ops = main_program.Block(block_idx).AllOps();
131+
size_t n = ops.size();
132+
std::vector<OpRole> roles;
133+
roles.reserve(n);
134+
auto op_role_attr = OpProtoAndCheckerMaker::OpRoleAttrName();
135+
for (auto *op : ops) {
136+
OpRole role;
137+
if (op->HasAttr(op_role_attr)) {
138+
role = static_cast<OpRole>(op->GetAttrIfExists<int>(op_role_attr));
139+
} else {
140+
role = OpRole::kNotSpecified;
141+
}
142+
roles.emplace_back(role);
143+
}
144+
145+
// NOTE: The following codes may be wrong in some cases.
146+
// But how can we get the right OpRole? The right way
147+
// is that all passes should deal with unspecified OpRole.
148+
auto prev_role = OpRole::kForward;
149+
for (size_t i = 0; i < n; ++i) {
150+
if (roles[i] == OpRole::kNotSpecified) {
151+
VLOG(10) << "Fill op role of " << ops[i]->Type() << " as "
152+
<< static_cast<int>(prev_role);
153+
ops[i]->SetAttr(op_role_attr, static_cast<int>(prev_role));
154+
} else {
155+
prev_role = roles[i];
156+
}
157+
}
158+
}
159+
}
160+
161+
void Pass::ApplyPassesToProgram(const std::vector<const Pass *> &passes,
162+
ProgramDesc *main_program,
163+
ProgramDesc *startup_program) {
164+
VLOG(10) << "ApplyPassesToProgram is called";
165+
PADDLE_ENFORCE_NOT_NULL(
166+
main_program,
167+
platform::errors::InvalidArgument("The main program must be provided."));
168+
169+
PADDLE_ENFORCE_NOT_NULL(startup_program,
170+
platform::errors::InvalidArgument(
171+
"The startup program must be provided."));
172+
173+
for (auto *p : passes) {
174+
PADDLE_ENFORCE_NOT_NULL(p, platform::errors::InvalidArgument(
175+
"The provided pass cannot be nullptr."));
176+
VLOG(10) << "Pass " << p->Type();
177+
if (passes.size() > 1) {
178+
PADDLE_ENFORCE_EQ(p->SupportApplyProgramViaGraph(), true,
179+
platform::errors::PermissionDenied(
180+
"Each pass must support to be applied via Graph if "
181+
"multi-passes are applied."));
182+
}
183+
}
184+
185+
if (passes.size() == 1 && !passes[0]->SupportApplyProgramViaGraph()) {
186+
VLOG(10) << "apply pass " << passes[0]->Type() << " to program";
187+
passes[0]->ApplyImpl(main_program, startup_program);
188+
FillNotSpecifiedOpRole(*main_program);
189+
VLOG(10) << "finish to apply pass " << passes[0]->Type() << " to program";
190+
return;
191+
}
192+
193+
Graph graph(*main_program);
194+
for (auto *p : passes) {
195+
p->Apply(&graph);
196+
}
197+
ConvertToPrograms(&graph, main_program, startup_program);
198+
FillNotSpecifiedOpRole(*main_program);
199+
}
200+
140201
void Pass::ApplyImpl(ProgramDesc *main_program,
141202
ProgramDesc *startup_program) const {
142-
Graph graph(*main_program);
143-
Apply(&graph);
203+
PADDLE_THROW(platform::errors::Unimplemented(
204+
"The pass %s does not support to apply ProgramDesc directly", Type()));
205+
}
144206

207+
void Pass::ConvertToPrograms(Graph *graph, ProgramDesc *main_program,
208+
ProgramDesc *startup_program) {
145209
ProgramDesc new_main_program;
146-
GraphToProgram(graph, &new_main_program);
210+
GraphToProgram(*graph, &new_main_program);
147211
main_program->CopyFrom(*new_main_program.Proto());
148212

149-
if (graph.Has(details::kStartupProgramDescs)) {
213+
if (graph->Has(details::kStartupProgramDescs)) {
150214
const auto &startups =
151-
graph.Get<details::ProgramDescs>(details::kStartupProgramDescs);
215+
graph->Get<details::ProgramDescs>(details::kStartupProgramDescs);
152216
VLOG(10) << "Merge startup programs";
153217
MergePrograms(startup_program, startups, /*append=*/true);
218+
graph->Erase(details::kStartupProgramDescs);
154219
}
155220

156-
if (graph.Has(details::kProgramDescs)) {
221+
if (graph->Has(details::kProgramDescs)) {
157222
const auto &mains =
158-
graph.Get<details::ProgramDescs>(details::kProgramDescs);
223+
graph->Get<details::ProgramDescs>(details::kProgramDescs);
159224
VLOG(10) << "Merge main programs";
160225
MergePrograms(main_program, mains, /*append=*/false);
226+
graph->Erase(details::kProgramDescs);
161227
}
162228

163229
startup_program->Flush();

paddle/fluid/framework/ir/pass.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,6 @@ class Pass {
6565

6666
Graph *Apply(Graph *graph) const;
6767

68-
void Apply(ProgramDesc *main_program, ProgramDesc *startup_program) const;
69-
7068
// Get a reference to the attributed previously set.
7169
template <typename AttrType>
7270
AttrType &Get(const std::string &attr_name) const {
@@ -142,6 +140,12 @@ class Pass {
142140
attrs_[attr_name] = attr;
143141
}
144142

143+
static void ApplyPassesToProgram(const std::vector<const Pass *> &passes,
144+
ProgramDesc *main_program,
145+
ProgramDesc *startup_program);
146+
147+
virtual bool SupportApplyProgramViaGraph() const { return true; }
148+
145149
protected:
146150
virtual void ApplyImpl(Graph *graph) const {
147151
PADDLE_THROW(platform::errors::Unimplemented(
@@ -151,8 +155,8 @@ class Pass {
151155
virtual void ApplyImpl(ProgramDesc *main_program,
152156
ProgramDesc *startup_program) const;
153157

154-
static void MergePrograms(ProgramDesc *dst, const details::ProgramDescs &srcs,
155-
bool append);
158+
static void ConvertToPrograms(ir::Graph *graph, ProgramDesc *main_program,
159+
ProgramDesc *startup_program);
156160

157161
// Some Pass must be placed before this Pass, and some
158162
// Pass must be placed after this Pass.

paddle/fluid/framework/program_desc_test.cc

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,40 @@ namespace paddle {
2323
namespace framework {
2424
class VarDesc;
2525

26+
TEST(ProgramDesc, block_desc_move) {
27+
auto program = std::make_unique<ProgramDesc>();
28+
auto* global_block = program->MutableBlock(0);
29+
30+
auto* op = global_block->AppendOp();
31+
op->SetType("op_with_subblock");
32+
op->SetAttr("sub_block", program->AppendBlock(*global_block));
33+
34+
std::vector<BlockDesc*> sub_blocks;
35+
sub_blocks.push_back(program->AppendBlock(*global_block));
36+
sub_blocks.push_back(program->AppendBlock(*global_block));
37+
op->SetAttr("sub_blocks", sub_blocks);
38+
39+
program->Flush();
40+
41+
ProgramDesc program_move;
42+
for (size_t i = 1; i < program->Size(); ++i) {
43+
program_move.AppendBlock(program_move.Block(0));
44+
}
45+
for (size_t i = 0; i < program->Size(); ++i) {
46+
program_move.MutableBlock(i)->MoveFrom(program->MutableBlock(i));
47+
}
48+
program = nullptr;
49+
EXPECT_EQ(program_move.Size(), static_cast<size_t>(4));
50+
op = program_move.Block(0).Op(0);
51+
auto sub_block = op->GetAttrIfExists<BlockDesc*>("sub_block");
52+
EXPECT_EQ(sub_block, program_move.MutableBlock(1));
53+
54+
sub_blocks = op->GetAttrIfExists<std::vector<BlockDesc*>>("sub_blocks");
55+
EXPECT_EQ(sub_blocks.size(), static_cast<size_t>(2));
56+
EXPECT_EQ(sub_blocks[0], program_move.MutableBlock(2));
57+
EXPECT_EQ(sub_blocks[1], program_move.MutableBlock(3));
58+
}
59+
2660
TEST(ProgramDesc, copy_ctor) {
2761
ProgramDesc program;
2862
auto* global_block = program.MutableBlock(0);

paddle/fluid/platform/flags.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -624,3 +624,16 @@ DEFINE_bool(conv2d_disable_cudnn, false, "Disable cudnn in conv2d");
624624
DEFINE_int32(get_host_by_name_time, 120,
625625
"The maximum time for get host by name time");
626626
#endif
627+
628+
/**
629+
* Distributed related FLAG
630+
* Name: FLAGS_apply_pass_to_program
631+
* Since Version: 2.2.0
632+
* Value Range: bool, default=false
633+
* Example: FLAGS_apply_pass_to_program=true would apply IR Pass to
634+
* program when using Fleet APIs.
635+
* Note: Apply IR pass to program. Be only useful when using Fleet APIs.
636+
*/
637+
DEFINE_bool(
638+
apply_pass_to_program, false,
639+
"It controls whether to apply IR pass to program when using Fleet APIs");

paddle/fluid/pybind/global_value_getter_setter.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ DECLARE_bool(benchmark);
6767
DECLARE_int32(inner_op_parallelism);
6868
DECLARE_int32(max_inplace_grad_add);
6969
DECLARE_string(tracer_profile_fname);
70+
DECLARE_bool(apply_pass_to_program);
7071
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
7172
// cudnn
7273
DECLARE_uint64(conv_workspace_size_limit);
@@ -367,7 +368,8 @@ static void RegisterGlobalVarGetterSetter() {
367368
FLAGS_memory_fraction_of_eager_deletion, FLAGS_use_pinned_memory,
368369
FLAGS_benchmark, FLAGS_inner_op_parallelism, FLAGS_tracer_profile_fname,
369370
FLAGS_paddle_num_threads, FLAGS_use_mkldnn, FLAGS_max_inplace_grad_add,
370-
FLAGS_tracer_mkldnn_ops_on, FLAGS_tracer_mkldnn_ops_off);
371+
FLAGS_tracer_mkldnn_ops_on, FLAGS_tracer_mkldnn_ops_off,
372+
FLAGS_apply_pass_to_program);
371373

372374
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
373375
REGISTER_PUBLIC_GLOBAL_VAR(

paddle/fluid/pybind/ir.cc

Lines changed: 45 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,19 @@ static void SetAttrsToPass(
293293
}
294294
}
295295

296+
static std::vector<std::string> GetPassNames(const py::object &names) {
297+
try {
298+
return {py::cast<std::string>(names)};
299+
} catch (py::cast_error &) {
300+
try {
301+
return py::cast<std::vector<std::string>>(names);
302+
} catch (py::cast_error &) {
303+
PADDLE_THROW(platform::errors::InvalidArgument(
304+
"Pass names must be either str or list[str]"));
305+
}
306+
}
307+
}
308+
296309
void BindPass(py::module *m) {
297310
// NOTE: pass_attr_types is a dict to indicate the type of each attribute.
298311
// Python has only one integral type "int", but C++ has many integral types.
@@ -312,25 +325,38 @@ void BindPass(py::module *m) {
312325
REGISTER_PASS_ATTR_GETTER_SETTER("str", std::string);
313326
REGISTER_PASS_ATTR_GETTER_SETTER("list[str]", std::vector<std::string>);
314327

315-
m->def(
316-
"apply_pass",
317-
[](framework::ProgramDesc *main_program,
318-
framework::ProgramDesc *startup_program, const std::string &pass_name,
319-
const std::unordered_map<std::string, py::object> &pass_attrs,
320-
std::unordered_map<std::string, std::string> pass_attr_types) {
321-
auto pass = framework::ir::PassRegistry::Instance().Get(pass_name);
322-
SetAttrsToPass(pass_attrs, &pass_attr_types, pass.get());
323-
pass->Apply(main_program, startup_program);
324-
std::unordered_map<std::string, py::object> result_attrs;
325-
for (const auto &name_and_value : pass_attrs) {
326-
const auto &attr_name = name_and_value.first;
327-
const auto &attr_type = pass_attr_types.at(attr_name);
328-
result_attrs[attr_name] =
329-
PassAttrGetterSetterRegistry::Instance().Get(*pass, attr_name,
330-
attr_type);
331-
}
332-
return result_attrs;
333-
});
328+
m->def("apply_pass",
329+
[](framework::ProgramDesc *main_program,
330+
framework::ProgramDesc *startup_program,
331+
const py::object &py_pass_names,
332+
const std::unordered_map<std::string, py::object> &pass_attrs,
333+
std::unordered_map<std::string, std::string> pass_attr_types) {
334+
auto pass_names = GetPassNames(py_pass_names);
335+
std::vector<std::unique_ptr<framework::ir::Pass>> passes;
336+
std::vector<const framework::ir::Pass *> passes_not_owned;
337+
passes.reserve(pass_names.size());
338+
passes_not_owned.reserve(pass_names.size());
339+
for (const auto &name : pass_names) {
340+
auto pass = framework::ir::PassRegistry::Instance().Get(name);
341+
SetAttrsToPass(pass_attrs, &pass_attr_types, pass.get());
342+
passes.push_back(std::move(pass));
343+
passes_not_owned.push_back(passes.back().get());
344+
}
345+
346+
framework::ir::Pass::ApplyPassesToProgram(
347+
passes_not_owned, main_program, startup_program);
348+
std::unordered_map<std::string, py::object> result_attrs;
349+
for (const auto &pass : passes) {
350+
for (const auto &name_and_value : pass_attrs) {
351+
const auto &attr_name = name_and_value.first;
352+
const auto &attr_type = pass_attr_types.at(attr_name);
353+
result_attrs[attr_name] =
354+
PassAttrGetterSetterRegistry::Instance().Get(
355+
*pass, attr_name, attr_type);
356+
}
357+
}
358+
return result_attrs;
359+
});
334360
}
335361

336362
} // namespace pybind

0 commit comments

Comments
 (0)