Skip to content

Commit cc5d4b1

Browse files
author
feng_shuai
authored
Conv relu mkldnn fuse pass (#33664)
1 parent 79e75bc commit cc5d4b1

8 files changed

Lines changed: 172 additions & 2 deletions

File tree

paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
4949
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
5050
Graph* g) {
5151
VLOG(4) << "handle " + conv_type() + "+" + activation_type() + " fuse";
52+
53+
if (!IsCompat(subgraph, g)) {
54+
LOG(WARNING) << "Pass op compat failed.";
55+
return;
56+
}
5257
GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
5358
conv_activation_pattern); // Filter
5459
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out,
@@ -97,6 +102,113 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
97102
AddStatis(found_conv_activation_count);
98103
}
99104

105+
ConvActivationFusePass::ConvActivationFusePass() {
106+
AddOpCompat(OpCompat("conv2d"))
107+
.AddInput("Input")
108+
.IsTensor()
109+
.End()
110+
.AddInput("Filter")
111+
.IsTensor()
112+
.End()
113+
.AddInput("Bias")
114+
.IsOptional()
115+
.IsTensor()
116+
.End()
117+
.AddOutput("Output")
118+
.IsTensor()
119+
.End()
120+
.AddAttr("strides")
121+
.IsType<std::vector<int>>()
122+
.End()
123+
.AddAttr("paddings")
124+
.IsType<std::vector<int>>()
125+
.End()
126+
// IsStringIn({"EXPLICIT", "SAME", "VALID"}), MobileNetV2 has no this
127+
// attribute
128+
.AddAttr("padding_algorithm")
129+
.IsOptional()
130+
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
131+
.End()
132+
.AddAttr("groups")
133+
.IsNumGE(1)
134+
.End()
135+
.AddAttr("dilations")
136+
.IsType<std::vector<int>>()
137+
.End()
138+
// IsStringIn({"NHWC", "NCHW"}) MobileNetV2 has no this attribute
139+
.AddAttr("data_format")
140+
.IsOptional()
141+
.IsStringIn({"NHWC", "NCHW", "AnyLayout"})
142+
.End();
143+
144+
AddOpCompat(OpCompat("relu"))
145+
.AddInput("X")
146+
.IsTensor()
147+
.End()
148+
.AddOutput("Out")
149+
.IsTensor()
150+
.End();
151+
}
152+
Conv2DLeakyReLUFusePass::Conv2DLeakyReLUFusePass() {
153+
AddOpCompat(OpCompat("leaky_relu"))
154+
.AddInput("X")
155+
.IsTensor()
156+
.End()
157+
.AddOutput("Out")
158+
.IsTensor()
159+
.End()
160+
// float, default=0.02
161+
.AddAttr("alpha")
162+
.IsType<float>()
163+
.End();
164+
}
165+
Conv2DReLU6FusePass::Conv2DReLU6FusePass() {
166+
AddOpCompat(OpCompat("relu6"))
167+
.AddInput("X")
168+
.IsTensor()
169+
.End()
170+
.AddOutput("Out")
171+
.IsTensor()
172+
.End()
173+
// default = 6.0f
174+
.AddAttr("threshold")
175+
.IsType<float>()
176+
.End();
177+
}
178+
Conv2DSwishFusePass::Conv2DSwishFusePass() {
179+
AddOpCompat(OpCompat("swish"))
180+
.AddInput("X")
181+
.IsTensor()
182+
.End()
183+
.AddOutput("Out")
184+
.IsTensor()
185+
.End();
186+
}
187+
Conv2DHardSwishFusePass::Conv2DHardSwishFusePass() {
188+
AddOpCompat(OpCompat("hard_swish"))
189+
.AddInput("X")
190+
.IsTensor()
191+
.End()
192+
.AddOutput("Out")
193+
.IsTensor()
194+
.End()
195+
// float, optional, default=6.0
196+
.AddAttr("threshold")
197+
.IsOptional()
198+
.IsType<float>()
199+
.End()
200+
// float, optional, default=6.0
201+
.AddAttr("scale")
202+
.IsOptional()
203+
.IsType<float>()
204+
.End()
205+
// float, optional, default=3.0
206+
.AddAttr("offset")
207+
.IsOptional()
208+
.IsType<float>()
209+
.End();
210+
}
211+
100212
} // namespace ir
101213
} // namespace framework
102214
} // namespace paddle

paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class Graph;
3131

3232
class ConvActivationFusePass : public FusePassBase {
3333
public:
34+
ConvActivationFusePass();
3435
virtual ~ConvActivationFusePass() {}
3536
virtual std::string conv_type() const { return "conv2d"; }
3637
virtual std::string activation_type() const { return "relu"; }
@@ -44,27 +45,31 @@ class ConvActivationFusePass : public FusePassBase {
4445
*/
4546
class Conv2DLeakyReLUFusePass : public ConvActivationFusePass {
4647
public:
48+
Conv2DLeakyReLUFusePass();
4749
std::string activation_type() const { return "leaky_relu"; }
4850
};
4951
/*
5052
* Fuse Conv and BoundedReLU class
5153
*/
5254
class Conv2DReLU6FusePass : public ConvActivationFusePass {
5355
public:
56+
Conv2DReLU6FusePass();
5457
std::string activation_type() const { return "relu6"; }
5558
};
5659
/*
5760
* Fuse Conv and Swish class
5861
*/
5962
class Conv2DSwishFusePass : public ConvActivationFusePass {
6063
public:
64+
Conv2DSwishFusePass();
6165
std::string activation_type() const { return "swish"; }
6266
};
6367
/*
6468
* Fuse Conv and HardSwish class
6569
*/
6670
class Conv2DHardSwishFusePass : public ConvActivationFusePass {
6771
public:
72+
Conv2DHardSwishFusePass();
6873
std::string activation_type() const { return "hard_swish"; }
6974
};
7075
} // namespace ir

paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h"
1616

1717
#include <gtest/gtest.h>
18+
#include <vector>
1819
#include "paddle/fluid/framework/op_proto_maker.h"
1920

2021
namespace paddle {
@@ -30,9 +31,16 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
3031
op->SetAttr("name", name);
3132
if (type == "conv2d") {
3233
op->SetAttr("use_mkldnn", use_mkldnn);
34+
op->SetAttr("groups", 1);
35+
op->SetAttr("padding_algorithm", std::string("EXPLICIT"));
36+
op->SetAttr("data_format", std::string("NCHW"));
37+
op->SetAttr("strides", std::vector<int>({1, 1}));
38+
op->SetAttr("dilations", std::vector<int>({1, 1}));
39+
op->SetAttr("paddings", std::vector<int>({0, 0}));
3340
op->SetInput("Input", {inputs[0]});
3441
op->SetInput("Filter", {inputs[1]});
3542
op->SetInput("Bias", {inputs[2]});
43+
op->SetOutput("Output", outputs);
3644
} else if (is_activation) {
3745
op->SetAttr("use_mkldnn", use_mkldnn);
3846
op->SetInput("X", inputs);
@@ -43,8 +51,9 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
4351
} else if (type == "swish") {
4452
op->SetAttr("beta", 1.0f);
4553
}
54+
op->SetOutput("Out", outputs);
4655
}
47-
op->SetOutput("Out", outputs);
56+
4857
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
4958
static_cast<int>(OpRole::kForward));
5059
}

paddle/fluid/operators/compat/hard_swish.pbtxt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,18 @@ extra {
2424
name: "op_role"
2525
type: INT
2626
}
27+
attrs {
28+
name: "use_mkldnn"
29+
type: BOOLEAN
30+
}
31+
attrs {
32+
name: "name"
33+
type: STRING
34+
}
35+
attrs {
36+
name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
37+
type: BOOLEAN
38+
}
2739
attrs {
2840
name: "op_role_var"
2941
type: STRINGS

paddle/fluid/operators/compat/leaky_relu.pbtxt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,18 @@ extra {
1616
name: "use_mkldnn"
1717
type: BOOLEAN
1818
}
19+
attrs {
20+
name: "name"
21+
type: STRING
22+
}
23+
attrs {
24+
name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
25+
type: BOOLEAN
26+
}
27+
attrs {
28+
name: "is_test"
29+
type: BOOLEAN
30+
}
1931
attrs {
2032
name: "op_role"
2133
type: INT

paddle/fluid/operators/compat/relu.pbtxt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,8 @@ extra {
5252
name: "is_test"
5353
type: BOOLEAN
5454
}
55+
attrs {
56+
name: "name"
57+
type: STRINGS
58+
}
5559
}

paddle/fluid/operators/compat/relu6.pbtxt

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,28 @@ def {
66
outputs {
77
name: "Out"
88
}
9+
attrs {
10+
name: "threshold"
11+
type: FLOAT
12+
}
913
}
1014
extra {
1115
attrs {
12-
name: "threshold"
16+
name: "name"
17+
type: STRING
18+
}
19+
attrs {
20+
name: "is_test"
1321
type: FLOAT
1422
}
1523
attrs {
1624
name: "use_mkldnn"
1725
type: BOOLEAN
1826
}
27+
attrs {
28+
name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
29+
type: BOOLEAN
30+
}
1931
attrs {
2032
name: "op_role"
2133
type: INT

paddle/fluid/operators/compat/swish.pbtxt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ extra {
1212
name: "beta"
1313
type: FLOAT
1414
}
15+
attrs {
16+
name: "name"
17+
type: STRING
18+
}
1519
attrs {
1620
name: "use_mkldnn"
1721
type: BOOLEAN

0 commit comments

Comments
 (0)