Skip to content

Commit a7deb55

Browse files
committed
add the fc fuse example for pass enhance, test=develop
1 parent dbc08d6 commit a7deb55

File tree

13 files changed

+501
-25
lines changed

13 files changed

+501
-25
lines changed

paddle/fluid/framework/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,12 @@ add_subdirectory(fleet)
2727
add_subdirectory(io)
2828
#ddim lib
2929
proto_library(framework_proto SRCS framework.proto)
30+
3031
proto_library(op_def_proto SRCS op_def.proto)
32+
set(OP_DEF_FOLDER "${PADDLE_SOURCE_DIR}/paddle/fluid/operators/compat/")
33+
configure_file("op_def_api.h.in" "op_def_api.h")
34+
cc_library(op_def_api SRCS op_def_api.cc DEPS op_def_proto)
35+
3136
proto_library(heter_service_proto SRCS heter_service.proto)
3237
proto_library(data_feed_proto SRCS data_feed.proto)
3338
proto_library(trainer_desc_proto SRCS trainer_desc.proto DEPS framework_proto

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ if (WITH_TESTING)
5050
endif(WITH_TESTING)
5151
cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS ${GRAPH_PATTERN_DETECTOR_DEPS})
5252

53-
cc_library(op_compat_sensible_pass SRCS op_compat_sensible_pass.cc DEPS graph_pattern_detector)
53+
cc_library(op_compat_sensible_pass SRCS op_compat_sensible_pass.cc DEPS graph_pattern_detector op_def_api)
5454
cc_library(subgraph_detector SRCS subgraph_detector.cc DEPS graph_pattern_detector executor)
5555
cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS op_compat_sensible_pass)
5656
cc_library(placement_pass_base SRCS placement_pass_base.cc DEPS pass)

paddle/fluid/framework/ir/fc_fuse_pass.cc

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
// limitations under the License.
1414

1515
#include "paddle/fluid/framework/ir/fc_fuse_pass.h"
16-
1716
#include <string>
17+
#include "paddle/fluid/framework/op_proto_maker.h"
1818

1919
#include "paddle/fluid/framework/op_version_registry.h"
2020
#include "paddle/fluid/platform/enforce.h"
@@ -23,6 +23,65 @@ namespace paddle {
2323
namespace framework {
2424
namespace ir {
2525

26+
FCFusePass::FCFusePass() {
27+
AddOpCompat(OpCompat("mul"))
28+
.AddInput("X")
29+
.IsTensor()
30+
.End()
31+
.AddInput("Y")
32+
.IsTensor()
33+
.End()
34+
.AddOutput("Out")
35+
.IsTensor()
36+
.End()
37+
.AddAttr("x_num_col_dims")
38+
.IsNumGE(1)
39+
.End()
40+
.AddAttr("y_num_col_dims")
41+
.End();
42+
43+
AddOpCompat(OpCompat("elementwise_add"))
44+
.AddInput("X")
45+
.IsTensor()
46+
.End()
47+
.AddInput("Y")
48+
.IsTensor()
49+
.End()
50+
.AddOutput("Out")
51+
.IsTensor()
52+
.End()
53+
.AddAttr("axis")
54+
.End();
55+
56+
AddOpCompat(OpCompat("relu"))
57+
.AddInput("X")
58+
.IsTensor()
59+
.End()
60+
.AddOutput("Out")
61+
.IsTensor()
62+
.End();
63+
64+
AddOpCompat(OpCompat("fc"))
65+
.AddInput("Input")
66+
.IsTensor()
67+
.End()
68+
.AddInput("W")
69+
.IsTensor()
70+
.End()
71+
.AddInput("Bias")
72+
.IsTensor()
73+
.End()
74+
.AddOutput("Out")
75+
.IsTensor()
76+
.End()
77+
.AddAttr("in_num_col_dims")
78+
.IsNumGE(1)
79+
.End()
80+
.AddAttr("activation_type")
81+
.IsStringIn({"relu", ""})
82+
.End();
83+
}
84+
2685
void FCFusePass::ApplyImpl(ir::Graph* graph) const {
2786
PADDLE_ENFORCE_NOT_NULL(
2887
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
@@ -52,6 +111,10 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
52111
LOG(WARNING) << "The subgraph is empty.";
53112
return;
54113
}
114+
if (!IsCompat(subgraph, g)) {
115+
LOG(WARNING) << "Pass in op compat failed.";
116+
return;
117+
}
55118

56119
VLOG(4) << "handle FC fuse";
57120
GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
@@ -159,6 +222,11 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
159222
}
160223
desc.Flush();
161224

225+
if (!IsCompat(desc)) {
226+
LOG(WARNING) << "Fc fuse pass in out fc op compat failed.";
227+
return;
228+
}
229+
162230
auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied.
163231
if (with_relu) {
164232
GraphSafeRemoveNodes(

paddle/fluid/framework/ir/fc_fuse_pass.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class Graph;
3030

3131
class FCFusePass : public FusePassBase {
3232
public:
33+
FCFusePass();
3334
virtual ~FCFusePass() {}
3435

3536
protected:

paddle/fluid/framework/ir/op_compat_sensible_pass.cc

Lines changed: 45 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include <memory>
16-
1715
#include "paddle/fluid/framework/ir/op_compat_sensible_pass.h"
16+
#include <memory>
17+
#include <mutex>
18+
#include <unordered_map>
19+
#include "paddle/fluid/framework/op_def_api.h"
1820
#include "paddle/fluid/framework/op_info.h"
21+
1922
namespace paddle {
2023
namespace framework {
2124
namespace ir {
@@ -50,18 +53,17 @@ AttrCompat& AttrCompat::IsIntIn(const std::set<int>& candidates) {
5053
return *this;
5154
}
5255

53-
//! Todo: append the definition.
5456
AttrCompat& AttrCompat::IsLeftDefault() {
5557
const std::string& op_name = op_compat_->Name();
5658
if (!OpInfoMap::Instance().Has(op_name)) {
57-
VLOG(3) << "Op (" << op_name << ") is not registered!";
59+
LOG(WARNING) << "Op (" << op_name << ") is not registered!";
5860
conditions_.emplace_back([](const Attribute& attr) { return false; });
5961
return *this;
6062
}
6163
const OpInfo& op_info = OpInfoMap::Instance().Get(op_name);
6264
const AttributeMap attrs = op_info.Checker()->GetAttrsDefaultValuesMap();
6365
if (attrs.find(attr_name_) == attrs.end()) {
64-
VLOG(3) << "Op (" << op_name << ") has no default attr:" << attr_name_;
66+
LOG(WARNING) << "Op (" << op_name << ") has no default attr:" << attr_name_;
6567
conditions_.emplace_back([](const Attribute& attr) { return false; });
6668
} else {
6769
Attribute default_attr = attrs.at(attr_name_);
@@ -77,6 +79,10 @@ bool AttrCompat::operator()(const OpDesc& op_desc) {
7779
return true;
7880
}
7981
if (!op_desc.HasAttr(attr_name_)) {
82+
if (!optional_) {
83+
LOG(WARNING) << "The non-optional Attr(" << attr_name_ << ") of Op ("
84+
<< op_compat_->Name() << ") not find ! ";
85+
}
8086
return optional_;
8187
}
8288
const Attribute attr = op_desc.GetAttr(attr_name_);
@@ -149,19 +155,35 @@ InputOrOutputCompat& OpCompat::AddOutput(const std::string& name) {
149155
}
150156

151157
bool OpCompat::Judge(const OpDesc& op_desc) {
158+
if (is_first_judge_) {
159+
is_first_judge_ = false;
160+
const proto::OpDef& op_def = GetOpDef(op_name_);
161+
if (op_def.has_extra()) {
162+
for (const proto::OpDef_AttrDef& attr : op_def.extra().attrs()) {
163+
extra_attrs_.emplace(attr.name());
164+
}
165+
}
166+
}
167+
152168
for (auto& attr_map : op_desc.GetAttrMap()) {
153169
if (attr_compats_.find(attr_map.first) == attr_compats_.end()) {
170+
if (extra_attrs_.find(attr_map.first) != extra_attrs_.end()) {
171+
continue;
172+
}
154173
if (!AttrCompat(attr_map.first, this).IsLeftDefault()(op_desc)) {
155-
VLOG(3) << "The Attr(" << attr_map.first << ") of Op (" << op_name_
156-
<< ") not reigistered in OpCompat, not equal to default value!";
174+
LOG(WARNING)
175+
<< "The Attr(" << attr_map.first << ") of Op (" << op_name_
176+
<< ") not reigistered in OpCompat, not in extra attribute, not "
177+
"equal to default value!";
157178
return false;
158179
}
159180
}
160181
}
182+
161183
for (auto& attr_compat : attr_compats_) {
162184
if (!attr_compat.second(op_desc)) {
163-
VLOG(3) << " Check the Attr(" << attr_compat.first << ") of Op("
164-
<< op_name_ << ") failed!";
185+
LOG(WARNING) << " Check the Attr(" << attr_compat.first << ") of Op("
186+
<< op_name_ << ") failed!";
165187
return false;
166188
}
167189
}
@@ -170,23 +192,24 @@ bool OpCompat::Judge(const OpDesc& op_desc) {
170192
for (auto& input_desc : inputs_map) {
171193
if (input_compats_.find(input_desc.first) == input_compats_.end()) {
172194
if (!input_desc.second.empty()) {
173-
VLOG(3) << "The Input (" << input_desc.first << ") of Operator ("
174-
<< op_name_ << ") not reigistered in OpCompat!";
195+
LOG(WARNING) << "The Input (" << input_desc.first << ") of Operator ("
196+
<< op_name_ << ") not reigistered in OpCompat!";
175197
return false;
176198
}
177199
}
178200
}
179201
for (auto& input_val : input_compats_) {
180202
if (inputs_map.find(input_val.first) == inputs_map.end()) {
181203
if (!input_val.second.Optional()) {
182-
VLOG(3) << "The No optional Input (" << input_val.first
183-
<< ") of Operator (" << op_name_ << ") not find in op_desc!";
204+
LOG(WARNING) << "The No optional Input (" << input_val.first
205+
<< ") of Operator (" << op_name_
206+
<< ") not find in op_desc!";
184207
return false;
185208
}
186209
} else {
187210
if (!input_val.second(inputs_map.at(input_val.first))) {
188-
VLOG(3) << "The Input (" << input_val.first << ") of Operator ("
189-
<< op_name_ << ") compat check failed!";
211+
LOG(WARNING) << "The Input (" << input_val.first << ") of Operator ("
212+
<< op_name_ << ") compat check failed!";
190213
return false;
191214
}
192215
}
@@ -196,23 +219,24 @@ bool OpCompat::Judge(const OpDesc& op_desc) {
196219
for (auto& output_desc : outputs_map) {
197220
if (output_compats_.find(output_desc.first) == output_compats_.end()) {
198221
if (!output_desc.second.empty()) {
199-
VLOG(3) << "The Output (" << output_desc.first << ") of Operator ("
200-
<< op_name_ << ") not reigistered in OpCompat!";
222+
LOG(WARNING) << "The Output (" << output_desc.first << ") of Operator ("
223+
<< op_name_ << ") not reigistered in OpCompat!";
201224
return false;
202225
}
203226
}
204227
}
205228
for (auto& output_val : output_compats_) {
206229
if (outputs_map.find(output_val.first) == outputs_map.end()) {
207230
if (!output_val.second.Optional()) {
208-
VLOG(3) << "The No optional Output (" << output_val.first
209-
<< ") of Operator (" << op_name_ << ") not find in op_desc!";
231+
LOG(WARNING) << "The No optional Output (" << output_val.first
232+
<< ") of Operator (" << op_name_
233+
<< ") not find in op_desc!";
210234
return false;
211235
}
212236
} else {
213237
if (!output_val.second(outputs_map.at(output_val.first))) {
214-
VLOG(3) << "The Output (" << output_val.first << ") of Operator ("
215-
<< op_name_ << ") compat check failed!";
238+
LOG(WARNING) << "The Output (" << output_val.first << ") of Operator ("
239+
<< op_name_ << ") compat check failed!";
216240
return false;
217241
}
218242
}

paddle/fluid/framework/ir/op_compat_sensible_pass.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,8 @@ class OpCompat {
140140
std::unordered_map<std::string, AttrCompat> attr_compats_;
141141
std::unordered_map<std::string, InputOrOutputCompat> input_compats_;
142142
std::unordered_map<std::string, InputOrOutputCompat> output_compats_;
143+
std::unordered_set<std::string> extra_attrs_;
144+
bool is_first_judge_ = true;
143145
};
144146

145147
/**
@@ -203,6 +205,7 @@ class OpCompatSensiblePass : public Pass {
203205
if (!node_pair.second->IsOp()) continue;
204206
auto op_type = node_pair.second->Op()->Type();
205207
if (!op_compat_judgers_.count(op_type)) {
208+
LOG(WARNING) << op_type << "compat not registered!";
206209
return false;
207210
}
208211
auto& judger = *op_compat_judgers_.at(op_type);

paddle/fluid/framework/ir/op_compat_sensible_pass_tester.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ TEST(OpCompatSensiblePass, compatOp) {
2727
compat.AddAttr("in_num_col_dims")
2828
.IsIntIn({1, 2})
2929
.IsNumLE(1)
30-
.IsLeftDefault()
3130
.End()
3231
.AddAttr("activation_type")
3332
.IsStringIn({"tanh", "sigmoid"})
@@ -68,7 +67,7 @@ TEST(OpCompatSensiblePass, compatOp) {
6867
fc_op.SetOutput("Out", std::vector<std::string>{"test_output"});
6968

7069
EXPECT_STREQ(compat.Name().c_str(), "fc");
71-
EXPECT_FALSE(compat.Judge(fc_op));
70+
EXPECT_TRUE(compat.Judge(fc_op));
7271
}
7372

7473
TEST(OpCompatSensiblePass, compatOpAttribute) {
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#if defined _WIN32 || defined __APPLE__
16+
#else
17+
#define _LINUX
18+
#endif
19+
#include "paddle/fluid/framework/op_def_api.h"
20+
#include <fstream>
21+
#include <mutex>
22+
#include <string>
23+
#include <unordered_map>
24+
#ifdef _LINUX
25+
#include <stdio_ext.h>
26+
#include <sys/mman.h>
27+
#include <sys/stat.h>
28+
#endif
29+
#include <google/protobuf/io/zero_copy_stream_impl.h>
30+
#include <google/protobuf/text_format.h>
31+
#include "glog/logging.h"
32+
#include "io/fs.h"
33+
#include "paddle/fluid/framework/op_def.pb.h"
34+
35+
namespace paddle {
36+
namespace framework {
37+
38+
const proto::OpDef& GetOpDef(const std::string& op_name) {
39+
static std::unordered_map<std::string, proto::OpDef> ops_definition;
40+
static std::mutex mtx;
41+
if (ops_definition.find(op_name) == ops_definition.end()) {
42+
std::lock_guard<std::mutex> lk(mtx);
43+
if (ops_definition.find(op_name) == ops_definition.end()) {
44+
proto::OpDef op_def;
45+
std::string op_path = OP_DEF_FOLDER + op_name + ".pbtxt";
46+
int fd = open(op_path.c_str(), O_RDONLY);
47+
if (fd == -1) {
48+
LOG(WARNING) << op_path << " open failed!";
49+
} else {
50+
::google::protobuf::io::FileInputStream* input =
51+
new ::google::protobuf::io::FileInputStream(fd);
52+
if (!::google::protobuf::TextFormat::Parse(input, &op_def)) {
53+
LOG(WARNING) << "Failed to parse " << op_path;
54+
}
55+
delete input;
56+
close(fd);
57+
}
58+
ops_definition.emplace(std::make_pair(op_name, std::move(op_def)));
59+
}
60+
}
61+
return ops_definition.at(op_name);
62+
}
63+
} // namespace framework
64+
} // namespace paddle
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// the folder of pbtxt with op attribute definition
2+
#pragma once
3+
4+
#include "paddle/fluid/framework/op_def.pb.h"
5+
6+
#define OP_DEF_FOLDER "@OP_DEF_FOLDER@"
7+
8+
namespace paddle {
9+
namespace framework {
10+
const proto::OpDef& GetOpDef(const std::string& op_name);
11+
}
12+
}

0 commit comments

Comments
 (0)