Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion paddle/fluid/framework/block_desc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,27 @@ void BlockDesc::MoveFrom(BlockDesc *block) {
}
ops_.clear();
for (const auto &src_op : block->ops_) {
AppendOp()->CopyFrom(*src_op);
auto *dst_op = AppendOp();
dst_op->CopyFrom(*src_op);
for (const auto &pair : src_op->GetAttrMap()) {
const auto &attr_name = pair.first;
const auto &attr_value = pair.second;
auto attr_type = static_cast<proto::AttrType>(attr_value.which() - 1);
if (attr_type == proto::AttrType::BLOCK) {
auto block_id = BOOST_GET_CONST(BlockDesc *, attr_value)->ID();
dst_op->SetBlockAttr(attr_name, prog_->MutableBlock(block_id));
VLOG(10) << "Set block attr " << attr_name << " id " << block_id;
} else if (attr_type == proto::AttrType::BLOCKS) {
auto old_blocks = BOOST_GET_CONST(std::vector<BlockDesc *>, attr_value);
std::vector<BlockDesc *> new_blocks;
new_blocks.reserve(old_blocks.size());
for (auto *b : old_blocks) {
VLOG(10) << "Set block attr " << attr_name << " id " << b->ID();
new_blocks.push_back(prog_->MutableBlock(b->ID()));
}
dst_op->SetBlocksAttr(attr_name, new_blocks);
}
}
}
need_update_ = true;
Flush();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ class MemoryReusePass : public Pass {
details::VarHandle *in_var,
details::VarHandle *out_var) const;

bool SupportApplyProgramViaGraph() const override { return false; }

private:
VarDesc *GetVarDesc(const details::VarHandle &var) const;

Expand Down
113 changes: 91 additions & 22 deletions paddle/fluid/framework/ir/pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License. */

#include <algorithm>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_proto_maker.h"

namespace paddle {
namespace framework {
Expand All @@ -28,6 +29,9 @@ class Graph;
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

DEFINE_bool(apply_pass_to_program, false,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better put it in flags.cc

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

"Whether to apply IR pass to program");

namespace paddle {
namespace framework {
namespace ir {
Expand Down Expand Up @@ -72,19 +76,6 @@ Graph *Pass::Apply(Graph *graph) const {
return graph;
}

void Pass::Apply(ProgramDesc *main_program,
ProgramDesc *startup_program) const {
VLOG(10) << "apply pass " << Type() << " to program";
PADDLE_ENFORCE_NOT_NULL(main_program, platform::errors::InvalidArgument(
"main program must be provided"));
PADDLE_ENFORCE_NOT_NULL(
startup_program,
platform::errors::InvalidArgument("startup program must be provided"));

ApplyImpl(main_program, startup_program);
VLOG(10) << "finish to apply pass " << Type() << " to program";
}

template <typename Container, typename Visitor>
static void VisitAllElements(Container &&container, Visitor &&visitor,
bool reverse) {
Expand All @@ -95,8 +86,8 @@ static void VisitAllElements(Container &&container, Visitor &&visitor,
}
}

void Pass::MergePrograms(ProgramDesc *dst, const details::ProgramDescs &srcs,
bool append) {
static void MergePrograms(ProgramDesc *dst, const details::ProgramDescs &srcs,
bool append) {
PADDLE_ENFORCE_NOT_NULL(
dst, platform::errors::InvalidArgument("Dst program must be provided."));
bool reverse = !append;
Expand Down Expand Up @@ -137,27 +128,105 @@ void Pass::MergePrograms(ProgramDesc *dst, const details::ProgramDescs &srcs,
VisitAllElements(srcs, create_op_visitor, reverse);
}

static void FillNotSpecifiedOpRole(const ProgramDesc &main_program) {
for (size_t block_idx = 0; block_idx < main_program.Size(); ++block_idx) {
auto ops = main_program.Block(block_idx).AllOps();
size_t n = ops.size();
std::vector<OpRole> roles;
roles.reserve(n);
auto op_role_attr = OpProtoAndCheckerMaker::OpRoleAttrName();
for (auto *op : ops) {
OpRole role;
if (op->HasAttr(op_role_attr)) {
role = static_cast<OpRole>(op->GetAttrIfExists<int>(op_role_attr));
} else {
role = OpRole::kNotSpecified;
}
roles.emplace_back(role);
}

// NOTE: The following codes may be wrong in some cases.
// But how can we get the right OpRole? The right way
// is that all passes should deal with unspecified OpRole.
auto prev_role = OpRole::kForward;
for (size_t i = 0; i < n; ++i) {
if (roles[i] == OpRole::kNotSpecified) {
VLOG(10) << "Fill op role of " << ops[i]->Type() << " as "
<< static_cast<int>(prev_role);
ops[i]->SetAttr(op_role_attr, static_cast<int>(prev_role));
} else {
prev_role = roles[i];
}
}
}
}

void Pass::ApplyPassesToProgram(const std::vector<const Pass *> &passes,
ProgramDesc *main_program,
ProgramDesc *startup_program) {
VLOG(10) << "ApplyPassesToProgram is called";
PADDLE_ENFORCE_NOT_NULL(
main_program,
platform::errors::InvalidArgument("The main program must be provided."));

PADDLE_ENFORCE_NOT_NULL(startup_program,
platform::errors::InvalidArgument(
"The startup program must be provided."));

for (auto *p : passes) {
PADDLE_ENFORCE_NOT_NULL(p, platform::errors::InvalidArgument(
"The provided pass cannot be nullptr."));
VLOG(10) << "Pass " << p->Type();
if (passes.size() > 1) {
PADDLE_ENFORCE_EQ(p->SupportApplyProgramViaGraph(), true,
platform::errors::PermissionDenied(
"Each pass must support to be applied via Graph if "
"multi-passes are applied."));
}
}

if (passes.size() == 1 && !passes[0]->SupportApplyProgramViaGraph()) {
VLOG(10) << "apply pass " << passes[0]->Type() << " to program";
passes[0]->ApplyImpl(main_program, startup_program);
FillNotSpecifiedOpRole(*main_program);
VLOG(10) << "finish to apply pass " << passes[0]->Type() << " to program";
return;
}

Graph graph(*main_program);
for (auto *p : passes) {
p->Apply(&graph);
}
ConvertToPrograms(&graph, main_program, startup_program);
FillNotSpecifiedOpRole(*main_program);
}

void Pass::ApplyImpl(ProgramDesc *main_program,
ProgramDesc *startup_program) const {
Graph graph(*main_program);
Apply(&graph);
PADDLE_THROW(platform::errors::Unimplemented(
"The pass %s does not support to apply ProgramDesc directly", Type()));
}

void Pass::ConvertToPrograms(Graph *graph, ProgramDesc *main_program,
ProgramDesc *startup_program) {
ProgramDesc new_main_program;
GraphToProgram(graph, &new_main_program);
GraphToProgram(*graph, &new_main_program);
main_program->CopyFrom(*new_main_program.Proto());

if (graph.Has(details::kStartupProgramDescs)) {
if (graph->Has(details::kStartupProgramDescs)) {
const auto &startups =
graph.Get<details::ProgramDescs>(details::kStartupProgramDescs);
graph->Get<details::ProgramDescs>(details::kStartupProgramDescs);
VLOG(10) << "Merge startup programs";
MergePrograms(startup_program, startups, /*append=*/true);
graph->Erase(details::kStartupProgramDescs);
}

if (graph.Has(details::kProgramDescs)) {
if (graph->Has(details::kProgramDescs)) {
const auto &mains =
graph.Get<details::ProgramDescs>(details::kProgramDescs);
graph->Get<details::ProgramDescs>(details::kProgramDescs);
VLOG(10) << "Merge main programs";
MergePrograms(main_program, mains, /*append=*/false);
graph->Erase(details::kProgramDescs);
}

startup_program->Flush();
Expand Down
12 changes: 8 additions & 4 deletions paddle/fluid/framework/ir/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,6 @@ class Pass {

Graph *Apply(Graph *graph) const;

void Apply(ProgramDesc *main_program, ProgramDesc *startup_program) const;

// Get a reference to the attributed previously set.
template <typename AttrType>
AttrType &Get(const std::string &attr_name) const {
Expand Down Expand Up @@ -142,6 +140,12 @@ class Pass {
attrs_[attr_name] = attr;
}

static void ApplyPassesToProgram(const std::vector<const Pass *> &passes,
ProgramDesc *main_program,
ProgramDesc *startup_program);

virtual bool SupportApplyProgramViaGraph() const { return true; }

protected:
virtual void ApplyImpl(Graph *graph) const {
PADDLE_THROW(platform::errors::Unimplemented(
Expand All @@ -151,8 +155,8 @@ class Pass {
virtual void ApplyImpl(ProgramDesc *main_program,
ProgramDesc *startup_program) const;

static void MergePrograms(ProgramDesc *dst, const details::ProgramDescs &srcs,
bool append);
static void ConvertToPrograms(ir::Graph *graph, ProgramDesc *main_program,
ProgramDesc *startup_program);

// Some Pass must be placed before this Pass, and some
// Pass must be placed after this Pass.
Expand Down
34 changes: 34 additions & 0 deletions paddle/fluid/framework/program_desc_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,40 @@ namespace paddle {
namespace framework {
class VarDesc;

TEST(ProgramDesc, block_desc_move) {
auto program = std::make_unique<ProgramDesc>();
auto* global_block = program->MutableBlock(0);

auto* op = global_block->AppendOp();
op->SetType("op_with_subblock");
op->SetAttr("sub_block", program->AppendBlock(*global_block));

std::vector<BlockDesc*> sub_blocks;
sub_blocks.push_back(program->AppendBlock(*global_block));
sub_blocks.push_back(program->AppendBlock(*global_block));
op->SetAttr("sub_blocks", sub_blocks);

program->Flush();

ProgramDesc program_move;
for (size_t i = 1; i < program->Size(); ++i) {
program_move.AppendBlock(program_move.Block(0));
}
for (size_t i = 0; i < program->Size(); ++i) {
program_move.MutableBlock(i)->MoveFrom(program->MutableBlock(i));
}
program = nullptr;
EXPECT_EQ(program_move.Size(), static_cast<size_t>(4));
op = program_move.Block(0).Op(0);
auto sub_block = op->GetAttrIfExists<BlockDesc*>("sub_block");
EXPECT_EQ(sub_block, program_move.MutableBlock(1));

sub_blocks = op->GetAttrIfExists<std::vector<BlockDesc*>>("sub_blocks");
EXPECT_EQ(sub_blocks.size(), static_cast<size_t>(2));
EXPECT_EQ(sub_blocks[0], program_move.MutableBlock(2));
EXPECT_EQ(sub_blocks[1], program_move.MutableBlock(3));
}

TEST(ProgramDesc, copy_ctor) {
ProgramDesc program;
auto* global_block = program.MutableBlock(0);
Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/pybind/global_value_getter_setter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ DECLARE_bool(benchmark);
DECLARE_int32(inner_op_parallelism);
DECLARE_int32(max_inplace_grad_add);
DECLARE_string(tracer_profile_fname);
DECLARE_bool(apply_pass_to_program);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// cudnn
DECLARE_uint64(conv_workspace_size_limit);
Expand Down Expand Up @@ -365,7 +366,8 @@ static void RegisterGlobalVarGetterSetter() {
FLAGS_memory_fraction_of_eager_deletion, FLAGS_use_pinned_memory,
FLAGS_benchmark, FLAGS_inner_op_parallelism, FLAGS_tracer_profile_fname,
FLAGS_paddle_num_threads, FLAGS_use_mkldnn, FLAGS_max_inplace_grad_add,
FLAGS_tracer_mkldnn_ops_on, FLAGS_tracer_mkldnn_ops_off);
FLAGS_tracer_mkldnn_ops_on, FLAGS_tracer_mkldnn_ops_off,
FLAGS_apply_pass_to_program);

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
REGISTER_PUBLIC_GLOBAL_VAR(
Expand Down
64 changes: 45 additions & 19 deletions paddle/fluid/pybind/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,19 @@ static void SetAttrsToPass(
}
}

static std::vector<std::string> GetPassNames(const py::object &names) {
try {
return {py::cast<std::string>(names)};
} catch (py::cast_error &) {
try {
return py::cast<std::vector<std::string>>(names);
} catch (py::cast_error &) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Pass names must be either str or list[str]"));
}
}
}

void BindPass(py::module *m) {
// NOTE: pass_attr_types is a dict to indicate the type of each attribute.
// Python has only one integral type "int", but C++ has many integral types.
Expand All @@ -312,25 +325,38 @@ void BindPass(py::module *m) {
REGISTER_PASS_ATTR_GETTER_SETTER("str", std::string);
REGISTER_PASS_ATTR_GETTER_SETTER("list[str]", std::vector<std::string>);

m->def(
"apply_pass",
[](framework::ProgramDesc *main_program,
framework::ProgramDesc *startup_program, const std::string &pass_name,
const std::unordered_map<std::string, py::object> &pass_attrs,
std::unordered_map<std::string, std::string> pass_attr_types) {
auto pass = framework::ir::PassRegistry::Instance().Get(pass_name);
SetAttrsToPass(pass_attrs, &pass_attr_types, pass.get());
pass->Apply(main_program, startup_program);
std::unordered_map<std::string, py::object> result_attrs;
for (const auto &name_and_value : pass_attrs) {
const auto &attr_name = name_and_value.first;
const auto &attr_type = pass_attr_types.at(attr_name);
result_attrs[attr_name] =
PassAttrGetterSetterRegistry::Instance().Get(*pass, attr_name,
attr_type);
}
return result_attrs;
});
m->def("apply_pass",
[](framework::ProgramDesc *main_program,
framework::ProgramDesc *startup_program,
const py::object &py_pass_names,
const std::unordered_map<std::string, py::object> &pass_attrs,
std::unordered_map<std::string, std::string> pass_attr_types) {
auto pass_names = GetPassNames(py_pass_names);
std::vector<std::unique_ptr<framework::ir::Pass>> passes;
std::vector<const framework::ir::Pass *> passes_not_owned;
passes.reserve(pass_names.size());
passes_not_owned.reserve(pass_names.size());
for (const auto &name : pass_names) {
auto pass = framework::ir::PassRegistry::Instance().Get(name);
SetAttrsToPass(pass_attrs, &pass_attr_types, pass.get());
passes.push_back(std::move(pass));
passes_not_owned.push_back(passes.back().get());
}

framework::ir::Pass::ApplyPassesToProgram(
passes_not_owned, main_program, startup_program);
std::unordered_map<std::string, py::object> result_attrs;
for (const auto &pass : passes) {
for (const auto &name_and_value : pass_attrs) {
const auto &attr_name = name_and_value.first;
const auto &attr_type = pass_attr_types.at(attr_name);
result_attrs[attr_name] =
PassAttrGetterSetterRegistry::Instance().Get(
*pass, attr_name, attr_type);
}
}
return result_attrs;
});
}

} // namespace pybind
Expand Down
Loading