From c80c5fa5f2491c4ca7e282e6e3f7746ae2faa7ba Mon Sep 17 00:00:00 2001 From: Avin0323 Date: Tue, 14 Sep 2021 02:18:17 +0000 Subject: [PATCH 1/4] GeneratePass for Python Pass, test=develop --- paddle/fluid/framework/CMakeLists.txt | 1 + paddle/fluid/framework/ir/CMakeLists.txt | 2 + paddle/fluid/framework/ir/generate_pass.cc | 138 ++++++++++++ paddle/fluid/framework/ir/generate_pass.h | 47 ++++ .../framework/ir/generate_pass_tester.cc | 202 ++++++++++++++++++ paddle/fluid/framework/pass_desc.proto | 39 ++++ 6 files changed, 429 insertions(+) create mode 100644 paddle/fluid/framework/ir/generate_pass.cc create mode 100644 paddle/fluid/framework/ir/generate_pass.h create mode 100644 paddle/fluid/framework/ir/generate_pass_tester.cc create mode 100644 paddle/fluid/framework/pass_desc.proto diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index f8a4d099244353..151e89740dd6f4 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -28,6 +28,7 @@ add_subdirectory(io) add_subdirectory(new_executor) #ddim lib proto_library(framework_proto SRCS framework.proto) +proto_library(pass_desc_proto SRCS pass_desc.proto DEPS framework_proto) proto_library(op_def_proto SRCS op_def.proto DEPS framework_proto) cc_library(op_def_api SRCS op_def_api.cc DEPS op_def_proto boost) diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 02e8b7b237e279..53310ad9c5a9f9 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -95,6 +95,7 @@ pass_library(multihead_matmul_fuse_pass inference) pass_library(adaptive_pool2d_convert_global_pass inference) pass_library(unsqueeze2_eltwise_fuse_pass inference) pass_library(layer_norm_fuse_pass inference) +pass_library(generate_pass DEPS pass_desc_proto) if(WITH_GPU OR WITH_ROCM) pass_library(cudnn_placement_pass base DEPS placement_pass_base) pass_library(embedding_eltwise_layernorm_fuse_pass inference) @@ -156,6 +157,7 @@ cc_test(test_conv_bn_fuse_pass_cc SRCS conv_bn_fuse_pass_tester.cc DEPS conv_bn_ cc_test(test_adaptive_pool2d_convert_global_pass SRCS adaptive_pool2d_convert_global_pass_tester.cc DEPS adaptive_pool2d_convert_global_pass) cc_test(test_unsqueeze2_eltwise_fuse_pass SRCS unsqueeze2_eltwise_fuse_pass_tester.cc DEPS unsqueeze2_eltwise_fuse_pass) cc_test(test_layer_norm_fuse_pass_cc SRCS layer_norm_fuse_pass_tester.cc DEPS layer_norm_fuse_pass pass_test_util naive_executor) +cc_test(test_generate_pass_cc SRCS generate_pass_tester.cc DEPS generate_pass pass_desc_proto) if(WITH_GPU OR WITH_ROCM) cc_test(test_embedding_eltwise_layernorm_fuse_pass SRCS embedding_eltwise_layernorm_fuse_pass_tester.cc DEPS embedding_eltwise_layernorm_fuse_pass) cc_test(test_cudnn_placement_pass SRCS cudnn_placement_pass_tester.cc DEPS cudnn_placement_pass) diff --git a/paddle/fluid/framework/ir/generate_pass.cc b/paddle/fluid/framework/ir/generate_pass.cc new file mode 100644 index 00000000000000..f706ffab99aa2d --- /dev/null +++ b/paddle/fluid/framework/ir/generate_pass.cc @@ -0,0 +1,138 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/generate_pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +void InitGeneratePattern(PDPattern* pattern, const proto::PassDesc& pass_desc) { + const proto::BlockDesc& block = pass_desc.pattern().blocks(0); + // 1. process Op and out Var + for (int index = 0; index < block.ops_size(); ++index) { + const proto::OpDesc& op = block.ops(index); + PDNode* op_pdnode = + pattern->NewNode(string::Sprintf("%s.%d", op.type(), index)); + op_pdnode->assert_is_op(op.type()); + for (const proto::OpDesc::Var& out : op.outputs()) { + PDNode* out_pdnode = pattern->NewNode(out.arguments(0)); + out_pdnode->AsOutput()->assert_is_op_output(op.type()); + pattern->AddEdge(op_pdnode, out_pdnode); + } + } + // 2. process in Var and out Var + for (int index = 0; index < block.ops_size(); ++index) { + const proto::OpDesc& op = block.ops(index); + PDNode* op_pdnode = + pattern->RetrieveNode(string::Sprintf("%s.%d", op.type(), index)); + for (const proto::OpDesc::Var& in : op.inputs()) { + PDNode* in_pdnode = pattern->RetrieveNode(in.arguments(0)); + if (nullptr != in_pdnode) { + // out Var used by Op in pattern is intermediate role + in_pdnode->AsIntermediate(); + } else { + in_pdnode = pattern->NewNode(in.arguments(0)); + in_pdnode->AsInput()->assert_is_op_input(op.type()); + } + // in_pdnode->assert_is_persistable_var(); + pattern->AddEdge(in_pdnode, op_pdnode); + } + } +} + +GraphPatternDetector::handle_t GetGenerateRewrite( + const PDPattern& pattern, const proto::PassDesc& pass_desc) { + GraphPatternDetector::handle_t handler = [&]( + const GraphPatternDetector::subgraph_t subgraph, Graph* graph) { + const proto::BlockDesc& block = pass_desc.replace().blocks(0); + std::unordered_set remove_nodes; + for (const auto& pdnode : pattern.nodes()) { + remove_nodes.emplace(subgraph.at(pdnode.get())); + } + std::map var_node_map; + // var_node_map from VarMap + for (const proto::PassDesc::VarMap& var_map : pass_desc.var_maps()) { + PDNode* pd_node = pattern.RetrieveNode(var_map.pattern_var()); + Node* node = subgraph.at(pd_node); + var_node_map.insert({var_map.replace_var(), node}); + remove_nodes.erase(node); + } + for (const proto::OpDesc& op : block.ops()) { + std::vector in_nodes, out_nodes; + OpDesc op_desc; + op_desc.SetType(op.type()); + for (const proto::OpDesc::Var& in : op.inputs()) { + std::vector args; + for (const std::string& argument : in.arguments()) { + Node* in_node = nullptr; + auto iter = var_node_map.find(argument); + if (iter != var_node_map.end()) { + in_node = iter->second; + } else { + // create node + } + in_nodes.push_back(in_node); + args.push_back(in_node->Name()); + } + op_desc.SetInput(in.parameter(), args); + } + for (const proto::OpDesc::Var& out : op.outputs()) { + std::vector args; + for (const std::string& argument : out.arguments()) { + Node* out_node = nullptr; + auto iter = var_node_map.find(argument); + if (iter != var_node_map.end()) { + out_node = iter->second; + } else { + // create node + } + out_nodes.push_back(out_node); + args.push_back(out_node->Name()); + } + op_desc.SetOutput(out.parameter(), args); + } + Node* op_node = graph->CreateOpNode(&op_desc); + for (Node* node : in_nodes) { + IR_NODE_LINK_TO(node, op_node); + } + for (Node* node : out_nodes) { + IR_NODE_LINK_TO(op_node, node); + } + } + GraphSafeRemoveNodes(graph, remove_nodes); + }; + return handler; +} + +GeneratePass::GeneratePass(const proto::MultiPassDesc& multi_pass_desc) + : multi_pass_desc_(multi_pass_desc) { + VerifyDesc(); +} + +void GeneratePass::ApplyImpl(Graph* graph) const { + for (const proto::PassDesc& pass_desc : multi_pass_desc_.pass_descs()) { + GraphPatternDetector detector; + InitGeneratePattern(detector.mutable_pattern(), pass_desc); + detector(graph, GetGenerateRewrite(detector.pattern(), pass_desc)); + } +} + +void GeneratePass::VerifyDesc() const {} + +bool GeneratePass::VerifyGraph() const { return true; } + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/generate_pass.h b/paddle/fluid/framework/ir/generate_pass.h new file mode 100644 index 00000000000000..28dda9a7fb1556 --- /dev/null +++ b/paddle/fluid/framework/ir/generate_pass.h @@ -0,0 +1,47 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/pass_desc.pb.h" + +namespace paddle { +namespace framework { +namespace ir { + +// Generate a substitute pass from protobuf. +class GeneratePass : public Pass { + public: + // from PassDesc/MultiPassDesc + explicit GeneratePass(const proto::MultiPassDesc& multi_pass_desc); + virtual ~GeneratePass() {} + + protected: + void ApplyImpl(Graph* graph) const override; + + private: + GeneratePass() = delete; + DISABLE_COPY_AND_ASSIGN(GeneratePass); + // Verify desc + void VerifyDesc() const; + // Verify graph + bool VerifyGraph() const; + + proto::MultiPassDesc multi_pass_desc_; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/generate_pass_tester.cc b/paddle/fluid/framework/ir/generate_pass_tester.cc new file mode 100644 index 00000000000000..20f68f44bf04f3 --- /dev/null +++ b/paddle/fluid/framework/ir/generate_pass_tester.cc @@ -0,0 +1,202 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/generate_pass.h" +#include "gtest/gtest.h" +#include "paddle/fluid/framework/ir/pass_tester_helper.h" + +namespace paddle { +namespace framework { +namespace ir { + +template +class CXXGeneratePass : public GeneratePass { + public: + CXXGeneratePass() : GeneratePass(Functor()) {} +}; + +#define REGISTER_GENERATE_PASS(pass_type, function) \ + REGISTER_PASS(pass_type, ::paddle::framework::ir::CXXGeneratePass<&function>) + +proto::MultiPassDesc generate_fc_fuse() { + proto::MultiPassDesc multi_pass_desc; + for (bool with_relu : {true, false}) { + proto::PassDesc* pass_desc = multi_pass_desc.add_pass_descs(); + proto::BlockDesc* pattern = pass_desc->mutable_pattern()->add_blocks(); + proto::OpDesc* mul = pattern->add_ops(); + mul->set_type("mul"); + proto::OpDesc::Var* mul_x = mul->add_inputs(); + mul_x->set_parameter("X"); + mul_x->add_arguments()->assign("x"); + proto::OpDesc::Var* mul_y = mul->add_inputs(); + mul_y->set_parameter("Y"); + mul_y->add_arguments()->assign("w"); + proto::OpDesc::Var* mul_out = mul->add_outputs(); + mul_out->set_parameter("Out"); + mul_out->add_arguments()->assign("mul_out"); + proto::OpDesc* ewadd = pattern->add_ops(); + ewadd->set_type("elementwise_add"); + proto::OpDesc::Var* ewadd_x = ewadd->add_inputs(); + ewadd_x->set_parameter("X"); + ewadd_x->add_arguments()->assign("mul_out"); + proto::OpDesc::Var* ewadd_y = ewadd->add_inputs(); + ewadd_y->set_parameter("Y"); + ewadd_y->add_arguments()->assign("b"); + proto::OpDesc::Var* ewadd_out = ewadd->add_outputs(); + ewadd_out->set_parameter("Out"); + ewadd_out->add_arguments()->assign("ewadd_out"); + proto::OpDesc* relu = nullptr; + proto::BlockDesc* replace = pass_desc->mutable_pattern()->add_blocks(); + proto::OpDesc* fc = replace->add_ops(); + fc->set_type("fc"); + proto::OpDesc::Var* fc_x = fc->add_inputs(); + fc_x->set_parameter("Input"); + fc_x->add_arguments()->assign("x"); + proto::OpDesc::Var* fc_w = fc->add_inputs(); + fc_w->set_parameter("W"); + fc_w->add_arguments()->assign("w"); + proto::OpDesc::Var* fc_b = fc->add_inputs(); + fc_b->set_parameter("Bias"); + fc_b->add_arguments()->assign("b"); + proto::OpDesc::Var* fc_out = fc->add_outputs(); + fc_out->set_parameter("Output"); + fc_out->add_arguments()->assign("fc_out"); + for (const char* var : {"x", "w", "b", "fc_out"}) { + proto::PassDesc::VarMap* var_map = pass_desc->add_var_maps(); + var_map->set_pattern_var(var); + var_map->set_replace_var(var); + } + proto::PassDesc::AttrMap* attr_map = pass_desc->add_attr_maps(); + attr_map->set_pattern_op_idx(0); + attr_map->set_pattern_name("x_num_col_dims"); + attr_map->set_replace_op_idx(0); + attr_map->set_replace_name("in_num_col_dims"); + if (with_relu) { + relu = pattern->add_ops(); + relu->set_type("relu"); + proto::OpDesc::Var* relu_x = relu->add_inputs(); + relu_x->set_parameter("X"); + relu_x->add_arguments()->assign("ewadd_out"); + proto::OpDesc::Var* relu_out = relu->add_outputs(); + relu_out->set_parameter("Out"); + relu_out->add_arguments()->assign("relu_out"); + pass_desc->mutable_var_maps(3)->set_pattern_var("relu_out"); + proto::OpDesc::Attr* attr = fc->add_attrs(); + attr->set_name("activation_type"); + attr->set_type(proto::AttrType::STRING); + attr->set_s("relu"); + } else { + pass_desc->mutable_var_maps(3)->set_pattern_var("ewadd_out"); + } + } + return multi_pass_desc; +} + +proto::MultiPassDesc generate_multi_add_to_addn() { + proto::MultiPassDesc multi_pass_desc; + return multi_pass_desc; +} + +proto::MultiPassDesc generate_combine_matmul() { + proto::MultiPassDesc multi_pass_desc; + return multi_pass_desc; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_GENERATE_PASS(generate_fc_fuse, + paddle::framework::ir::generate_fc_fuse); +REGISTER_GENERATE_PASS(generate_multi_add_to_addn, + paddle::framework::ir::generate_multi_add_to_addn); +REGISTER_GENERATE_PASS(generate_combine_matmul, + paddle::framework::ir::generate_combine_matmul); + +namespace paddle { +namespace framework { +namespace ir { + +void AddVarToScope(Scope* param_scope, const std::string& name, + const DDim& dims) { + auto* tensor = param_scope->Var(name)->GetMutable(); + tensor->Resize(dims); + tensor->mutable_data(platform::CPUPlace()); +} + +Scope* CreateParamScope() { + auto param_scope = new Scope(); + AddVarToScope(param_scope, "conv2d_filters_0", {}); + AddVarToScope(param_scope, "conv2d_bias_0", {}); + AddVarToScope(param_scope, "weights_0", {}); + AddVarToScope(param_scope, "weights_1", {}); + AddVarToScope(param_scope, "bias_1", {}); + AddVarToScope(param_scope, "bias_2", {}); + return param_scope; +} + +TEST(FCFusePass, basic) { + // inputs operator output + // -------------------------------------------------------- + // (a, filters_0 bias_0) conv2d -> conv2d_out + // conv2d_out relu -> relu_out_0 + // (relu_out_0, weights_0) mul -> mul_out_0 + // (mul_out_0, bias_1) elementwise_add -> add_out_0 + // add_out_0 relu -> relu_out_1 + // (relu_out_1, weights_1) mul -> mul_out_1 + // (mul_out_1, bias_2) elementwise_add -> add_out_1 + Layers layers; + auto* a = layers.data("a"); + auto* filters_0 = layers.data("conv2d_filters_0", {}, true); + auto* bias_0 = layers.data("conv2d_bias_0", {}, true); + auto* conv2d_out = layers.conv2d(a, filters_0, bias_0, false); + auto* relu_out_0 = layers.relu(conv2d_out); + auto* weights_0 = layers.data("weights_0", {}, true); + auto* mul_out_0 = layers.mul(relu_out_0, weights_0); + auto* bias_1 = layers.data("bias_1", {}, true); + auto* add_out_0 = layers.elementwise_add(mul_out_0, bias_1, nullptr, 1); + auto* relu_out_1 = layers.relu(add_out_0); + auto* weights_1 = layers.data("weights_1", {}, true); + auto* mul_out_1 = layers.mul(relu_out_1, weights_1); + auto* bias_2 = layers.data("bias_2", {}, true); + auto* add_out_1 = layers.elementwise_add(mul_out_1, bias_2, nullptr, 1); + VLOG(4) << add_out_1; + + std::unique_ptr graph(new ir::Graph(layers.main_program())); + auto pass = PassRegistry::Instance().Get("generate_fc_fuse"); + int num_nodes_before = graph->Nodes().size(); + int num_mul_nodes_before = GetNumOpNodes(graph, "mul"); + VLOG(3) << DebugString(graph); + + graph.reset(pass->Apply(graph.release())); + int num_nodes_after = graph->Nodes().size(); + int num_fc_nodes_after = GetNumOpNodes(graph, "fc"); + VLOG(3) << DebugString(graph); + + PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after + 6, + platform::errors::InvalidArgument( + "num_nodes_before=%d, num_nodes_after=%d.", + num_nodes_before, num_nodes_after)); + PADDLE_ENFORCE_EQ(num_fc_nodes_after, 2, + platform::errors::InvalidArgument("num_fc_nodes_after=%d.", + num_fc_nodes_after)); + PADDLE_ENFORCE_EQ(num_mul_nodes_before, num_fc_nodes_after, + platform::errors::InvalidArgument( + "num_mul_nodes_before=%d, num_fc_nodes_after=%d.", + num_mul_nodes_before, num_fc_nodes_after)); +} + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/pass_desc.proto b/paddle/fluid/framework/pass_desc.proto new file mode 100644 index 00000000000000..c95e40a1d25e87 --- /dev/null +++ b/paddle/fluid/framework/pass_desc.proto @@ -0,0 +1,39 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +syntax = "proto2"; + +import "framework.proto"; +package paddle.framework.proto; + +// Describes one subsitute subgraph. +message PassDesc { + message VarMap { + required string pattern_var = 1; + required string replace_var = 2; + } + message AttrMap { + required int32 pattern_op_idx = 1; + required int32 replace_op_idx = 2; + required string pattern_name = 3; + required string replace_name = 4; + } + required ProgramDesc pattern = 1; + required ProgramDesc replace = 2; + repeated VarMap var_maps = 3; + repeated AttrMap attr_maps = 4; +} + +// A series of PassDesc. +message MultiPassDesc { + optional string pass_type = 1; + repeated PassDesc pass_descs = 2; +} From f60d6eff0943a9ca3f7207b5777a69c184d2de97 Mon Sep 17 00:00:00 2001 From: Avin0323 Date: Tue, 14 Sep 2021 19:33:32 +0000 Subject: [PATCH 2/4] fix unittest error, test=develop --- paddle/fluid/framework/ir/CMakeLists.txt | 1 + paddle/fluid/framework/ir/generate_pass.cc | 207 +++++++++++++----- paddle/fluid/framework/ir/generate_pass.h | 4 +- .../framework/ir/generate_pass_tester.cc | 206 +++++++++++++++-- paddle/fluid/pybind/CMakeLists.txt | 2 +- paddle/fluid/pybind/pybind.cc | 14 ++ 6 files changed, 353 insertions(+), 81 deletions(-) diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 53310ad9c5a9f9..175bd591334126 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -96,6 +96,7 @@ pass_library(adaptive_pool2d_convert_global_pass inference) pass_library(unsqueeze2_eltwise_fuse_pass inference) pass_library(layer_norm_fuse_pass inference) pass_library(generate_pass DEPS pass_desc_proto) +target_link_libraries(generate_pass pass_desc_proto) if(WITH_GPU OR WITH_ROCM) pass_library(cudnn_placement_pass base DEPS placement_pass_base) pass_library(embedding_eltwise_layernorm_fuse_pass inference) diff --git a/paddle/fluid/framework/ir/generate_pass.cc b/paddle/fluid/framework/ir/generate_pass.cc index f706ffab99aa2d..bd2767a509ac6f 100644 --- a/paddle/fluid/framework/ir/generate_pass.cc +++ b/paddle/fluid/framework/ir/generate_pass.cc @@ -20,34 +20,54 @@ namespace ir { void InitGeneratePattern(PDPattern* pattern, const proto::PassDesc& pass_desc) { const proto::BlockDesc& block = pass_desc.pattern().blocks(0); - // 1. process Op and out Var + // Traverse all operators to create subgraph. for (int index = 0; index < block.ops_size(); ++index) { const proto::OpDesc& op = block.ops(index); + // Create a PDNode for current operator. Use the index as name to avoid + // multiple operators with same type. Get a PDNode from pattern subgraph + // through index in rewrite phase. PDNode* op_pdnode = - pattern->NewNode(string::Sprintf("%s.%d", op.type(), index)); - op_pdnode->assert_is_op(op.type()); - for (const proto::OpDesc::Var& out : op.outputs()) { - PDNode* out_pdnode = pattern->NewNode(out.arguments(0)); - out_pdnode->AsOutput()->assert_is_op_output(op.type()); - pattern->AddEdge(op_pdnode, out_pdnode); + pattern->NewNode(std::to_string(index))->assert_is_op(op.type()); + // Create PDNodes for inputs of current operator. + for (const proto::OpDesc::Var& var : op.inputs()) { + for (const std::string& argument : var.arguments()) { + // The input may be the output of other operator. + PDNode* var_pdnode = pattern->RetrieveNode(argument); + if (nullptr == var_pdnode) { + var_pdnode = pattern->NewNode(argument)->AsInput(); + } else if (var_pdnode->IsOutput()) { + var_pdnode->AsIntermediate(); + } + var_pdnode->assert_is_op_input(op.type()); + pattern->AddEdge(var_pdnode, op_pdnode); + } } - } - // 2. process in Var and out Var - for (int index = 0; index < block.ops_size(); ++index) { - const proto::OpDesc& op = block.ops(index); - PDNode* op_pdnode = - pattern->RetrieveNode(string::Sprintf("%s.%d", op.type(), index)); - for (const proto::OpDesc::Var& in : op.inputs()) { - PDNode* in_pdnode = pattern->RetrieveNode(in.arguments(0)); - if (nullptr != in_pdnode) { - // out Var used by Op in pattern is intermediate role - in_pdnode->AsIntermediate(); - } else { - in_pdnode = pattern->NewNode(in.arguments(0)); - in_pdnode->AsInput()->assert_is_op_input(op.type()); + // Create PDNodes for outputs of current operator. + for (const proto::OpDesc::Var& var : op.outputs()) { + for (const std::string& argument : var.arguments()) { + // The output may be the input of other operator. + PDNode* var_pdnode = pattern->RetrieveNode(argument); + if (nullptr == var_pdnode) { + var_pdnode = pattern->NewNode(argument)->AsOutput(); + } else if (var_pdnode->IsInput()) { + var_pdnode->AsIntermediate(); + } + var_pdnode->assert_is_op_output(op.type()); + pattern->AddEdge(op_pdnode, var_pdnode); } - // in_pdnode->assert_is_persistable_var(); - pattern->AddEdge(in_pdnode, op_pdnode); + } + // Set attribute condition for current operator. + for (const proto::OpDesc::Attr& attr : op.attrs()) { + op_pdnode->assert_more([&](Node* x) { + if (x && x->IsOp()) { + OpDesc* op_desc = x->Op(); + if (op_desc->HasAttr(attr.name())) { + return GetAttrValue(attr) == op_desc->GetAttr(attr.name()); + } + return false; + } + return false; + }); } } } @@ -56,53 +76,69 @@ GraphPatternDetector::handle_t GetGenerateRewrite( const PDPattern& pattern, const proto::PassDesc& pass_desc) { GraphPatternDetector::handle_t handler = [&]( const GraphPatternDetector::subgraph_t subgraph, Graph* graph) { - const proto::BlockDesc& block = pass_desc.replace().blocks(0); - std::unordered_set remove_nodes; - for (const auto& pdnode : pattern.nodes()) { - remove_nodes.emplace(subgraph.at(pdnode.get())); + // There are some duplicate patterns. + for (auto iter : subgraph) { + if (nullptr == graph->RetrieveNode(iter.second->id())) { + VLOG(3) << "Node [" << iter.second->Name() + << "] of subgraph has been removed. So skip this optimize."; + return; + } } - std::map var_node_map; - // var_node_map from VarMap + const proto::BlockDesc& block = pass_desc.replace().blocks(0); + // `var_node_maps` record the mapping of variable to the pattern subgraph. + std::map var_node_maps; for (const proto::PassDesc::VarMap& var_map : pass_desc.var_maps()) { - PDNode* pd_node = pattern.RetrieveNode(var_map.pattern_var()); - Node* node = subgraph.at(pd_node); - var_node_map.insert({var_map.replace_var(), node}); - remove_nodes.erase(node); + Node* node = subgraph.at(pattern.RetrieveNode(var_map.pattern_var())); + var_node_maps.insert({var_map.replace_var(), node}); } + // Traverse all operators to create subgraph. for (const proto::OpDesc& op : block.ops()) { - std::vector in_nodes, out_nodes; OpDesc op_desc; + std::vector in_nodes, out_nodes; op_desc.SetType(op.type()); - for (const proto::OpDesc::Var& in : op.inputs()) { - std::vector args; - for (const std::string& argument : in.arguments()) { - Node* in_node = nullptr; - auto iter = var_node_map.find(argument); - if (iter != var_node_map.end()) { - in_node = iter->second; + // Create Nodes for inputs of current operator. + for (const proto::OpDesc::Var& var : op.inputs()) { + std::vector arguments; + for (const std::string& argument : var.arguments()) { + // The input may be mapped on the operator of pattern subgraph. + Node* node = nullptr; + auto iter = var_node_maps.find(argument); + if (var_node_maps.end() == iter) { + VarDesc var_desc(patterns::UniqueKey(argument)); + node = graph->CreateVarNode(&var_desc); + var_node_maps.insert({argument, node}); } else { - // create node + node = iter->second; } - in_nodes.push_back(in_node); - args.push_back(in_node->Name()); + in_nodes.push_back(node); + arguments.push_back(node->Name()); } - op_desc.SetInput(in.parameter(), args); + op_desc.SetInput(var.parameter(), arguments); } - for (const proto::OpDesc::Var& out : op.outputs()) { - std::vector args; - for (const std::string& argument : out.arguments()) { - Node* out_node = nullptr; - auto iter = var_node_map.find(argument); - if (iter != var_node_map.end()) { - out_node = iter->second; + // Create Nodes for outputs of current operator. + for (const proto::OpDesc::Var& var : op.outputs()) { + std::vector arguments; + for (const std::string& argument : var.arguments()) { + // The output may be mapped on the operator of pattern subgraph. + Node* node = nullptr; + auto iter = var_node_maps.find(argument); + if (var_node_maps.end() == iter) { + VarDesc var_desc(patterns::UniqueKey(argument)); + node = graph->CreateVarNode(&var_desc); + var_node_maps.insert({argument, node}); } else { - // create node + node = iter->second; } - out_nodes.push_back(out_node); - args.push_back(out_node->Name()); + out_nodes.push_back(node); + arguments.push_back(node->Name()); } - op_desc.SetOutput(out.parameter(), args); + op_desc.SetOutput(var.parameter(), arguments); } + // Set attribute for current operator. + for (const proto::OpDesc::Attr& attr : op.attrs()) { + op_desc.SetAttr(attr.name(), GetAttrValue(attr)); + } + // Create a Node for current operator. Node* op_node = graph->CreateOpNode(&op_desc); for (Node* node : in_nodes) { IR_NODE_LINK_TO(node, op_node); @@ -111,11 +147,24 @@ GraphPatternDetector::handle_t GetGenerateRewrite( IR_NODE_LINK_TO(op_node, node); } } + // Remove nodes that are intermediate. + std::unordered_set remove_nodes; + for (const std::unique_ptr& pdnode : pattern.nodes()) { + remove_nodes.emplace(subgraph.at(pdnode.get())); + } + for (auto iter : var_node_maps) { + remove_nodes.erase(iter.second); + } GraphSafeRemoveNodes(graph, remove_nodes); }; return handler; } +GeneratePass::GeneratePass(const std::string& binary_str) { + multi_pass_desc_.ParseFromString(binary_str); + VerifyDesc(); +} + GeneratePass::GeneratePass(const proto::MultiPassDesc& multi_pass_desc) : multi_pass_desc_(multi_pass_desc) { VerifyDesc(); @@ -126,12 +175,54 @@ void GeneratePass::ApplyImpl(Graph* graph) const { GraphPatternDetector detector; InitGeneratePattern(detector.mutable_pattern(), pass_desc); detector(graph, GetGenerateRewrite(detector.pattern(), pass_desc)); + // The rewrited graph needs to be verified. Current Pass should be skipped + // if validation failed. Rewrite based on the original graph cannot + // implement rollback operation. + VerifyGraph(*graph); } } -void GeneratePass::VerifyDesc() const {} +void GeneratePass::VerifyDesc() const { + PADDLE_ENFORCE_NE(multi_pass_desc_.pass_descs_size(), 0, + platform::errors::InvalidArgument( + "Size of PassDesc should not be empty.")); + for (const proto::PassDesc& pass_desc : multi_pass_desc_.pass_descs()) { + // Check inputs/outputs of subgraph should in `var_maps`. + std::set pattern_var_sets, replace_var_sets; + for (const proto::PassDesc::VarMap& var_map : pass_desc.var_maps()) { + pattern_var_sets.emplace(var_map.pattern_var()); + replace_var_sets.emplace(var_map.replace_var()); + } + auto check_vars = [=](std::set* var_sets, + const proto::BlockDesc& block) { + for (const proto::OpDesc& op : block.ops()) { + for (const proto::OpDesc::Var& var : op.outputs()) { + for (const std::string& argument : var.arguments()) { + var_sets->emplace(argument); + } + } + } + for (const proto::OpDesc& op : block.ops()) { + for (const proto::OpDesc::Var& var : op.inputs()) { + for (const std::string& argument : var.arguments()) { + PADDLE_ENFORCE_NE( + var_sets->find(argument), var_sets->end(), + platform::errors::InvalidArgument( + "Subgraph of PassDesc has argument [%s] not in `var_maps`.", + argument)); + } + } + } + }; + check_vars(&pattern_var_sets, pass_desc.pattern().blocks(0)); + check_vars(&replace_var_sets, pass_desc.replace().blocks(0)); + } +} -bool GeneratePass::VerifyGraph() const { return true; } +bool GeneratePass::VerifyGraph(const Graph& graph) { + // Return true temporarily. + return true; +} } // namespace ir } // namespace framework diff --git a/paddle/fluid/framework/ir/generate_pass.h b/paddle/fluid/framework/ir/generate_pass.h index 28dda9a7fb1556..8e75f5f32addf1 100644 --- a/paddle/fluid/framework/ir/generate_pass.h +++ b/paddle/fluid/framework/ir/generate_pass.h @@ -24,6 +24,8 @@ namespace ir { // Generate a substitute pass from protobuf. class GeneratePass : public Pass { public: + // from binary_str + explicit GeneratePass(const std::string& binary_str); // from PassDesc/MultiPassDesc explicit GeneratePass(const proto::MultiPassDesc& multi_pass_desc); virtual ~GeneratePass() {} @@ -37,7 +39,7 @@ class GeneratePass : public Pass { // Verify desc void VerifyDesc() const; // Verify graph - bool VerifyGraph() const; + static bool VerifyGraph(const Graph& graph); proto::MultiPassDesc multi_pass_desc_; }; diff --git a/paddle/fluid/framework/ir/generate_pass_tester.cc b/paddle/fluid/framework/ir/generate_pass_tester.cc index 20f68f44bf04f3..3e2531fdb91604 100644 --- a/paddle/fluid/framework/ir/generate_pass_tester.cc +++ b/paddle/fluid/framework/ir/generate_pass_tester.cc @@ -57,7 +57,7 @@ proto::MultiPassDesc generate_fc_fuse() { ewadd_out->set_parameter("Out"); ewadd_out->add_arguments()->assign("ewadd_out"); proto::OpDesc* relu = nullptr; - proto::BlockDesc* replace = pass_desc->mutable_pattern()->add_blocks(); + proto::BlockDesc* replace = pass_desc->mutable_replace()->add_blocks(); proto::OpDesc* fc = replace->add_ops(); fc->set_type("fc"); proto::OpDesc::Var* fc_x = fc->add_inputs(); @@ -70,7 +70,7 @@ proto::MultiPassDesc generate_fc_fuse() { fc_b->set_parameter("Bias"); fc_b->add_arguments()->assign("b"); proto::OpDesc::Var* fc_out = fc->add_outputs(); - fc_out->set_parameter("Output"); + fc_out->set_parameter("Out"); fc_out->add_arguments()->assign("fc_out"); for (const char* var : {"x", "w", "b", "fc_out"}) { proto::PassDesc::VarMap* var_map = pass_desc->add_var_maps(); @@ -105,11 +105,120 @@ proto::MultiPassDesc generate_fc_fuse() { proto::MultiPassDesc generate_multi_add_to_addn() { proto::MultiPassDesc multi_pass_desc; + proto::PassDesc* pass_desc = multi_pass_desc.add_pass_descs(); + proto::BlockDesc* pattern = pass_desc->mutable_pattern()->add_blocks(); + proto::OpDesc* ewadd_0 = pattern->add_ops(); + ewadd_0->set_type("elementwise_add"); + proto::OpDesc::Var* ewadd_0_x = ewadd_0->add_inputs(); + ewadd_0_x->set_parameter("X"); + ewadd_0_x->add_arguments()->assign("a"); + proto::OpDesc::Var* ewadd_0_y = ewadd_0->add_inputs(); + ewadd_0_y->set_parameter("Y"); + ewadd_0_y->add_arguments()->assign("b"); + proto::OpDesc::Var* ewadd_0_out = ewadd_0->add_outputs(); + ewadd_0_out->set_parameter("Out"); + ewadd_0_out->add_arguments()->assign("ewadd_out_0"); + proto::OpDesc* ewadd_1 = pattern->add_ops(); + ewadd_1->set_type("elementwise_add"); + proto::OpDesc::Var* ewadd_1_x = ewadd_1->add_inputs(); + ewadd_1_x->set_parameter("X"); + ewadd_1_x->add_arguments()->assign("ewadd_out_0"); + proto::OpDesc::Var* ewadd_1_y = ewadd_1->add_inputs(); + ewadd_1_y->set_parameter("Y"); + ewadd_1_y->add_arguments()->assign("c"); + proto::OpDesc::Var* ewadd_1_out = ewadd_1->add_outputs(); + ewadd_1_out->set_parameter("Out"); + ewadd_1_out->add_arguments()->assign("ewadd_out_1"); + proto::BlockDesc* replace = pass_desc->mutable_replace()->add_blocks(); + proto::OpDesc* addn = replace->add_ops(); + addn->set_type("add_n"); + proto::OpDesc::Var* addn_x = addn->add_inputs(); + addn_x->set_parameter("X"); + addn_x->add_arguments()->assign("a"); + addn_x->add_arguments()->assign("b"); + addn_x->add_arguments()->assign("c"); + proto::OpDesc::Var* addn_out = addn->add_outputs(); + addn_out->set_parameter("Out"); + addn_out->add_arguments()->assign("addn_out"); + for (const char* var : {"a", "b", "c", "ewadd_out_1"}) { + proto::PassDesc::VarMap* var_map = pass_desc->add_var_maps(); + var_map->set_pattern_var(var); + var_map->set_replace_var(var); + } + pass_desc->mutable_var_maps(3)->set_replace_var("addn_out"); return multi_pass_desc; } proto::MultiPassDesc generate_combine_matmul() { proto::MultiPassDesc multi_pass_desc; + proto::PassDesc* pass_desc = multi_pass_desc.add_pass_descs(); + proto::BlockDesc* pattern = pass_desc->mutable_pattern()->add_blocks(); + proto::OpDesc* matmul_0 = pattern->add_ops(); + matmul_0->set_type("matmul"); + proto::OpDesc::Var* matmul_0_x = matmul_0->add_inputs(); + matmul_0_x->set_parameter("X"); + matmul_0_x->add_arguments()->assign("a"); + proto::OpDesc::Var* matmul_0_y = matmul_0->add_inputs(); + matmul_0_y->set_parameter("Y"); + matmul_0_y->add_arguments()->assign("b"); + proto::OpDesc::Var* matmul_0_out = matmul_0->add_outputs(); + matmul_0_out->set_parameter("Out"); + matmul_0_out->add_arguments()->assign("matmul_out_0"); + proto::OpDesc* matmul_1 = pattern->add_ops(); + matmul_1->set_type("matmul"); + proto::OpDesc::Var* matmul_1_x = matmul_1->add_inputs(); + matmul_1_x->set_parameter("X"); + matmul_1_x->add_arguments()->assign("a"); + proto::OpDesc::Var* matmul_1_y = matmul_1->add_inputs(); + matmul_1_y->set_parameter("Y"); + matmul_1_y->add_arguments()->assign("c"); + proto::OpDesc::Var* matmul_1_out = matmul_1->add_outputs(); + matmul_1_out->set_parameter("Out"); + matmul_1_out->add_arguments()->assign("matmul_out_1"); + proto::BlockDesc* replace = pass_desc->mutable_replace()->add_blocks(); + proto::OpDesc* concat = replace->add_ops(); + concat->set_type("concat"); + proto::OpDesc::Var* concat_x = concat->add_inputs(); + concat_x->set_parameter("X"); + concat_x->add_arguments()->assign("b"); + concat_x->add_arguments()->assign("c"); + proto::OpDesc::Var* concat_out = concat->add_outputs(); + concat_out->set_parameter("Out"); + concat_out->add_arguments()->assign("concat_out"); + proto::OpDesc* matmul = replace->add_ops(); + matmul->set_type("matmul"); + proto::OpDesc::Var* matmul_x = matmul->add_inputs(); + matmul_x->set_parameter("X"); + matmul_x->add_arguments()->assign("a"); + proto::OpDesc::Var* matmul_y = matmul->add_inputs(); + matmul_y->set_parameter("Y"); + matmul_y->add_arguments()->assign("concat_out"); + proto::OpDesc::Var* matmul_out = matmul->add_outputs(); + matmul_out->set_parameter("Out"); + matmul_out->add_arguments()->assign("matmul_out"); + proto::OpDesc* slice_0 = replace->add_ops(); + slice_0->set_type("slice"); + proto::OpDesc::Var* slice_0_x = slice_0->add_inputs(); + slice_0_x->set_parameter("X"); + slice_0_x->add_arguments()->assign("matmul_out"); + proto::OpDesc::Var* slice_0_out = slice_0->add_outputs(); + slice_0_out->set_parameter("Out"); + slice_0_out->add_arguments()->assign("slice_out_0"); + proto::OpDesc* slice_1 = replace->add_ops(); + slice_1->set_type("slice"); + proto::OpDesc::Var* slice_1_x = slice_1->add_inputs(); + slice_1_x->set_parameter("X"); + slice_1_x->add_arguments()->assign("matmul_out"); + proto::OpDesc::Var* slice_1_out = slice_1->add_outputs(); + slice_1_out->set_parameter("Out"); + slice_1_out->add_arguments()->assign("slice_out_1"); + for (const char* var : {"a", "b", "c", "matmul_out_0", "matmul_out_1"}) { + proto::PassDesc::VarMap* var_map = pass_desc->add_var_maps(); + var_map->set_pattern_var(var); + var_map->set_replace_var(var); + } + pass_desc->mutable_var_maps(3)->set_replace_var("slice_out_0"); + pass_desc->mutable_var_maps(4)->set_replace_var("slice_out_1"); return multi_pass_desc; } @@ -128,25 +237,7 @@ namespace paddle { namespace framework { namespace ir { -void AddVarToScope(Scope* param_scope, const std::string& name, - const DDim& dims) { - auto* tensor = param_scope->Var(name)->GetMutable(); - tensor->Resize(dims); - tensor->mutable_data(platform::CPUPlace()); -} - -Scope* CreateParamScope() { - auto param_scope = new Scope(); - AddVarToScope(param_scope, "conv2d_filters_0", {}); - AddVarToScope(param_scope, "conv2d_bias_0", {}); - AddVarToScope(param_scope, "weights_0", {}); - AddVarToScope(param_scope, "weights_1", {}); - AddVarToScope(param_scope, "bias_1", {}); - AddVarToScope(param_scope, "bias_2", {}); - return param_scope; -} - -TEST(FCFusePass, basic) { +TEST(GeneratePass, generate_fc_fuse) { // inputs operator output // -------------------------------------------------------- // (a, filters_0 bias_0) conv2d -> conv2d_out @@ -197,6 +288,79 @@ TEST(FCFusePass, basic) { num_mul_nodes_before, num_fc_nodes_after)); } +TEST(GeneratePass, generate_multi_add_to_addn) { + // inputs operator output + // -------------------------------------------------------- + // (a, b) elementwise_add -> add_out_0 + // (add_out_0, c) elementwise_add -> add_out_1 + Layers layers; + auto* a = layers.data("a"); + auto* b = layers.data("b"); + auto* c = layers.data("c"); + auto* add_out_0 = layers.elementwise_add(a, b); + layers.elementwise_add(add_out_0, c); + + std::unique_ptr graph(new ir::Graph(layers.main_program())); + auto pass = PassRegistry::Instance().Get("generate_multi_add_to_addn"); + int num_nodes_before = graph->Nodes().size(); + int num_add_nodes_before = GetNumOpNodes(graph, "elementwise_add"); + VLOG(3) << DebugString(graph); + + graph.reset(pass->Apply(graph.release())); + int num_nodes_after = graph->Nodes().size(); + int num_addn_nodes_after = GetNumOpNodes(graph, "add_n"); + VLOG(3) << DebugString(graph); + + PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after + 2, + platform::errors::InvalidArgument( + "num_nodes_before=%d, num_nodes_after=%d.", + num_nodes_before, num_nodes_after)); + PADDLE_ENFORCE_EQ(num_addn_nodes_after, 1, + platform::errors::InvalidArgument( + "num_addn_nodes_after=%d.", num_addn_nodes_after)); + PADDLE_ENFORCE_EQ(num_add_nodes_before, num_addn_nodes_after + 1, + platform::errors::InvalidArgument( + "num_add_nodes_before=%d, num_addn_nodes_after=%d.", + num_add_nodes_before, num_addn_nodes_after)); +} + +TEST(GeneratePass, generate_combine_matmul) { + // inputs operator output + // -------------------------------------------------------- + // (a, b) matmul -> matmul_out_0 + // (a, c) matmul -> matmul_out_1 + Layers layers; + auto* a = layers.data("a"); + auto* b = layers.data("b"); + auto* c = layers.data("c"); + layers.matmul(a, b); + layers.matmul(a, c); + + std::unique_ptr graph(new ir::Graph(layers.main_program())); + auto pass = PassRegistry::Instance().Get("generate_combine_matmul"); + int num_nodes_before = graph->Nodes().size(); + int num_matmul_nodes_before = GetNumOpNodes(graph, "matmul"); + VLOG(3) << DebugString(graph); + + graph.reset(pass->Apply(graph.release())); + int num_nodes_after = graph->Nodes().size(); + int num_matmul_nodes_after = GetNumOpNodes(graph, "matmul"); + VLOG(3) << DebugString(graph); + + PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after - 4, + platform::errors::InvalidArgument( + "num_nodes_before=%d, num_nodes_after=%d.", + num_nodes_before, num_nodes_after)); + PADDLE_ENFORCE_EQ(num_matmul_nodes_after, 1, + platform::errors::InvalidArgument( + "num_matmul_nodes_after=%d.", num_matmul_nodes_after)); + PADDLE_ENFORCE_EQ( + num_matmul_nodes_before, num_matmul_nodes_after + 1, + platform::errors::InvalidArgument( + "num_matmul_nodes_before=%d, num_matmul_nodes_after=%d.", + num_matmul_nodes_before, num_matmul_nodes_after)); +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index a4ad9333163378..4ca46758838353 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -4,7 +4,7 @@ include_directories(${PADDLE_SOURCE_DIR}/paddle/fluid/platform) include_directories(${PADDLE_SOURCE_DIR}/paddle/utils) set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapper prune - feed_fetch_method pass pass_builder parallel_executor profiler layer tracer engine scope_pool + feed_fetch_method pass generate_pass pass_builder parallel_executor profiler layer tracer engine scope_pool analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry ps_gpu_wrapper custom_operator) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index b0148e50afc548..41a4fa18f7fee0 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -38,6 +38,7 @@ limitations under the License. */ #include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/io/fs.h" #include "paddle/fluid/framework/ir/coalesce_grad_tensor_pass.h" +#include "paddle/fluid/framework/ir/generate_pass.h" #include "paddle/fluid/framework/ir/pass_builder.h" #include "paddle/fluid/framework/lod_rank_table.h" #include "paddle/fluid/framework/lod_tensor.h" @@ -2324,6 +2325,19 @@ All parameter, weight, gradient are variables in Paddle. m.def("disable_profiler", platform::DisableProfiler); m.def("is_profiler_enabled", platform::IsProfileEnabled); m.def("reset_profiler", platform::ResetProfiler); + m.def("register_pass", [](const std::string &pass_type, py::object callable) { + PADDLE_ENFORCE_EQ( + framework::ir::PassRegistry::Instance().Has(pass_type), false, + platform::errors::AlreadyExists( + "Pass '%s' is registered more than once.", pass_type)); + framework::ir::PassRegistry::Instance().Insert(pass_type, [pass_type, + callable]() { + py::gil_scoped_acquire guard; + std::unique_ptr pass( + new framework::ir::GeneratePass(py::cast(callable()))); + return pass; + }); + }); m.def("get_pass", [](const std::string &pass_type) { auto pass = framework::ir::PassRegistry::Instance().Get(pass_type); return std::shared_ptr(std::move(pass)); From 492b3ec5b280a2140fe06997487cb101ab4bc306 Mon Sep 17 00:00:00 2001 From: Avin0323 Date: Thu, 16 Sep 2021 07:41:40 +0000 Subject: [PATCH 3/4] modify according to review comments, test=develop --- paddle/fluid/framework/ir/generate_pass.cc | 4 ++-- paddle/fluid/pybind/pybind.cc | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/ir/generate_pass.cc b/paddle/fluid/framework/ir/generate_pass.cc index bd2767a509ac6f..9eba6fc89a2e96 100644 --- a/paddle/fluid/framework/ir/generate_pass.cc +++ b/paddle/fluid/framework/ir/generate_pass.cc @@ -18,7 +18,7 @@ namespace paddle { namespace framework { namespace ir { -void InitGeneratePattern(PDPattern* pattern, const proto::PassDesc& pass_desc) { +void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) { const proto::BlockDesc& block = pass_desc.pattern().blocks(0); // Traverse all operators to create subgraph. for (int index = 0; index < block.ops_size(); ++index) { @@ -173,7 +173,7 @@ GeneratePass::GeneratePass(const proto::MultiPassDesc& multi_pass_desc) void GeneratePass::ApplyImpl(Graph* graph) const { for (const proto::PassDesc& pass_desc : multi_pass_desc_.pass_descs()) { GraphPatternDetector detector; - InitGeneratePattern(detector.mutable_pattern(), pass_desc); + InitGeneratePattern(pass_desc, detector.mutable_pattern()); detector(graph, GetGenerateRewrite(detector.pattern(), pass_desc)); // The rewrited graph needs to be verified. Current Pass should be skipped // if validation failed. Rewrite based on the original graph cannot diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 3288774118773b..d9f0bd3c64263c 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -2346,11 +2346,13 @@ All parameter, weight, gradient are variables in Paddle. m.def("disable_profiler", platform::DisableProfiler); m.def("is_profiler_enabled", platform::IsProfileEnabled); m.def("reset_profiler", platform::ResetProfiler); - m.def("register_pass", [](const std::string &pass_type, py::object callable) { + m.def("register_pass", [](const std::string &pass_type, + const py::object &callable) { PADDLE_ENFORCE_EQ( framework::ir::PassRegistry::Instance().Has(pass_type), false, platform::errors::AlreadyExists( - "Pass '%s' is registered more than once.", pass_type)); + "Pass '%s' is registered more than once. Please use another name.", + pass_type)); framework::ir::PassRegistry::Instance().Insert(pass_type, [pass_type, callable]() { py::gil_scoped_acquire guard; From 76e1c91be534b8a7c249dc2a3a5533bf6dfc6a13 Mon Sep 17 00:00:00 2001 From: Avin0323 Date: Thu, 16 Sep 2021 07:55:43 +0000 Subject: [PATCH 4/4] add unittest to increase coverage, test=develop --- paddle/fluid/framework/ir/generate_pass.h | 1 - paddle/fluid/framework/ir/generate_pass_tester.cc | 10 ++++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/ir/generate_pass.h b/paddle/fluid/framework/ir/generate_pass.h index 8e75f5f32addf1..f73173233aed32 100644 --- a/paddle/fluid/framework/ir/generate_pass.h +++ b/paddle/fluid/framework/ir/generate_pass.h @@ -28,7 +28,6 @@ class GeneratePass : public Pass { explicit GeneratePass(const std::string& binary_str); // from PassDesc/MultiPassDesc explicit GeneratePass(const proto::MultiPassDesc& multi_pass_desc); - virtual ~GeneratePass() {} protected: void ApplyImpl(Graph* graph) const override; diff --git a/paddle/fluid/framework/ir/generate_pass_tester.cc b/paddle/fluid/framework/ir/generate_pass_tester.cc index 3e2531fdb91604..c3852d29c308ff 100644 --- a/paddle/fluid/framework/ir/generate_pass_tester.cc +++ b/paddle/fluid/framework/ir/generate_pass_tester.cc @@ -34,6 +34,8 @@ proto::MultiPassDesc generate_fc_fuse() { for (bool with_relu : {true, false}) { proto::PassDesc* pass_desc = multi_pass_desc.add_pass_descs(); proto::BlockDesc* pattern = pass_desc->mutable_pattern()->add_blocks(); + pattern->set_idx(0); + pattern->set_parent_idx(0); proto::OpDesc* mul = pattern->add_ops(); mul->set_type("mul"); proto::OpDesc::Var* mul_x = mul->add_inputs(); @@ -58,6 +60,8 @@ proto::MultiPassDesc generate_fc_fuse() { ewadd_out->add_arguments()->assign("ewadd_out"); proto::OpDesc* relu = nullptr; proto::BlockDesc* replace = pass_desc->mutable_replace()->add_blocks(); + replace->set_idx(0); + replace->set_parent_idx(0); proto::OpDesc* fc = replace->add_ops(); fc->set_type("fc"); proto::OpDesc::Var* fc_x = fc->add_inputs(); @@ -237,6 +241,12 @@ namespace paddle { namespace framework { namespace ir { +TEST(GeneratePass, construct_with_string) { + std::string binary_str; + generate_fc_fuse().SerializeToString(&binary_str); + GeneratePass generate_pass(binary_str); +} + TEST(GeneratePass, generate_fc_fuse) { // inputs operator output // --------------------------------------------------------