Skip to content

Commit 899792c

Browse files
authored
pass enhance (#33710)
1 parent 9bf00cd commit 899792c

File tree

4 files changed

+110
-0
lines changed

4 files changed

+110
-0
lines changed

paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,60 @@ framework::proto::OpDesc PrepareOpDesc(
4848
return *desc.Proto();
4949
}
5050

51+
ConvElementwiseAddActFusePass::ConvElementwiseAddActFusePass() {
52+
AddOpCompat(OpCompat("conv2d"))
53+
.AddInput("Input")
54+
.IsTensor()
55+
.End()
56+
.AddInput("Filter")
57+
.IsTensor()
58+
.End()
59+
.AddInput("ResidualData")
60+
.IsOptional()
61+
.End()
62+
.AddOutput("Output")
63+
.IsTensor()
64+
.End()
65+
.AddAttr("strides")
66+
.End()
67+
.AddAttr("paddings")
68+
.End()
69+
.AddAttr("padding_algorithm")
70+
.IsOptional()
71+
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
72+
.End()
73+
.AddAttr("groups")
74+
.IsNumGE(1)
75+
.End()
76+
.AddAttr("dilations")
77+
.End()
78+
.AddAttr("data_format")
79+
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
80+
.End();
81+
82+
AddOpCompat(OpCompat("elementwise_add"))
83+
.AddInput("X")
84+
.IsTensor()
85+
.End()
86+
.AddInput("Y")
87+
.IsTensor()
88+
.End()
89+
.AddOutput("Out")
90+
.IsTensor()
91+
.End()
92+
.AddAttr("axis")
93+
.IsNumEQ(1)
94+
.End();
95+
96+
AddOpCompat(OpCompat("relu"))
97+
.AddInput("X")
98+
.IsTensor()
99+
.End()
100+
.AddOutput("Out")
101+
.IsTensor()
102+
.End();
103+
}
104+
51105
void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
52106
const std::string pattern_name = "conv_elementwise_add_act_fuse";
53107
FusePassBase::Init(pattern_name, graph);
@@ -63,6 +117,10 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
63117

64118
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
65119
Graph* g) {
120+
if (!IsCompat(subgraph, g)) {
121+
LOG(WARNING) << "Pass in op compat failed.";
122+
return;
123+
}
66124
GET_NODES;
67125

68126
auto base_op_desc = *conv_op->Op()->Proto();

paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class Graph;
2424

2525
class ConvElementwiseAddActFusePass : public FusePassBase {
2626
public:
27+
ConvElementwiseAddActFusePass();
2728
virtual ~ConvElementwiseAddActFusePass() {}
2829

2930
protected:

paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,52 @@ namespace ir {
2929
GET_IR_NODE(elementwise_add_in_y); \
3030
GET_IR_NODE(elementwise_add_out);
3131

32+
ConvElementwiseAddFusePass::ConvElementwiseAddFusePass() {
33+
AddOpCompat(OpCompat("conv2d"))
34+
.AddInput("Input")
35+
.IsTensor()
36+
.End()
37+
.AddInput("Filter")
38+
.IsTensor()
39+
.End()
40+
.AddInput("ResidualData")
41+
.IsOptional()
42+
.End()
43+
.AddOutput("Output")
44+
.IsTensor()
45+
.End()
46+
.AddAttr("strides")
47+
.End()
48+
.AddAttr("paddings")
49+
.End()
50+
.AddAttr("padding_algorithm")
51+
.IsOptional()
52+
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
53+
.End()
54+
.AddAttr("groups")
55+
.IsNumGE(1)
56+
.End()
57+
.AddAttr("dilations")
58+
.End()
59+
.AddAttr("data_format")
60+
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
61+
.End();
62+
63+
AddOpCompat(OpCompat("elementwise_add"))
64+
.AddInput("X")
65+
.IsTensor()
66+
.End()
67+
.AddInput("Y")
68+
.IsTensor()
69+
.End()
70+
.AddOutput("Out")
71+
.IsTensor()
72+
.End()
73+
.AddAttr("axis")
74+
.IsNumEQ(1)
75+
.End();
76+
}
77+
3278
void ConvElementwiseAddFusePass::ApplyImpl(ir::Graph* graph) const {
3379
const std::string pattern_name = "conv_elementwise_add_fuse";
3480
FusePassBase::Init(pattern_name, graph);
@@ -44,6 +90,10 @@ void ConvElementwiseAddFusePass::ApplyImpl(ir::Graph* graph) const {
4490

4591
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
4692
Graph* g) {
93+
if (!IsCompat(subgraph, g)) {
94+
LOG(WARNING) << "Pass in op compat failed.";
95+
return;
96+
}
4797
GET_NODES;
4898

4999
auto base_op_desc = *conv_op->Op()->Proto();

paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class Graph;
2424

2525
class ConvElementwiseAddFusePass : public FusePassBase {
2626
public:
27+
ConvElementwiseAddFusePass();
2728
virtual ~ConvElementwiseAddFusePass() {}
2829

2930
protected:

0 commit comments

Comments
 (0)