Skip to content

Commit 33edb62

Browse files
authored
pass_enhance_conv_concat_relu_mkldnn (#33867)
1 parent 7c4e515 commit 33edb62

File tree

2 files changed

+62
-5
lines changed

2 files changed

+62
-5
lines changed

paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,67 @@ namespace paddle {
2323
namespace framework {
2424
namespace ir {
2525

26-
class Graph;
26+
ConvConcatReLUFusePass::ConvConcatReLUFusePass() {
27+
AddOpCompat(OpCompat("conv2d"))
28+
.AddInput("Input")
29+
.IsTensor()
30+
.End()
31+
.AddInput("Filter")
32+
.IsTensor()
33+
.End()
34+
.AddInput("Bias")
35+
.IsTensor()
36+
.IsOptional()
37+
.End()
38+
.AddInput("ResidualData")
39+
.IsTensor()
40+
.IsOptional()
41+
.End()
42+
.AddOutput("Output")
43+
.IsTensor()
44+
.End()
45+
.AddAttr("strides")
46+
.IsType<std::vector<int>>()
47+
.End()
48+
.AddAttr("paddings")
49+
.IsType<std::vector<int>>()
50+
.End()
51+
.AddAttr("padding_algorithm")
52+
.IsOptional()
53+
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
54+
.End()
55+
.AddAttr("groups")
56+
.IsNumGE(1)
57+
.End()
58+
.AddAttr("dilations")
59+
.IsType<std::vector<int>>()
60+
.End()
61+
.AddAttr("data_format")
62+
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
63+
.End();
64+
65+
AddOpCompat(OpCompat("concat"))
66+
.AddInput("X") // Input("X"): vector<tensors>
67+
.End()
68+
.AddInput("AxisTensor")
69+
.IsTensor()
70+
.IsOptional()
71+
.End()
72+
.AddOutput("Out")
73+
.IsTensor()
74+
.End()
75+
.AddAttr("axis")
76+
.IsNumGE(0)
77+
.End();
78+
79+
AddOpCompat(OpCompat("relu"))
80+
.AddInput("X")
81+
.IsTensor()
82+
.End()
83+
.AddOutput("Out")
84+
.IsTensor()
85+
.End();
86+
}
2787

2888
void ConvConcatReLUFusePass::FindConcatWithConvs(
2989
ir::Graph* graph,

paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@
1818
#include <unordered_map>
1919

2020
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
21-
#include "paddle/fluid/framework/ir/graph.h"
22-
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
23-
#include "paddle/fluid/framework/ir/pass.h"
2421

2522
namespace paddle {
2623
namespace framework {
@@ -31,10 +28,10 @@ namespace ir {
3128
* to a:
3229
* (multi ConvReLU) -> Concat -> next_op.
3330
*/
34-
class Graph;
3531

3632
class ConvConcatReLUFusePass : public FusePassBase {
3733
public:
34+
ConvConcatReLUFusePass();
3835
virtual ~ConvConcatReLUFusePass() {}
3936

4037
protected:

0 commit comments

Comments
 (0)