@@ -49,6 +49,11 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
4949 auto handler = [&](const GraphPatternDetector::subgraph_t & subgraph,
5050 Graph* g) {
5151 VLOG (4 ) << " handle " + conv_type () + " +" + activation_type () + " fuse" ;
52+
53+ if (!IsCompat (subgraph, g)) {
54+ LOG (WARNING) << " Pass op compat failed." ;
55+ return ;
56+ }
5257 GET_IR_NODE_FROM_SUBGRAPH (conv_weight, conv_weight,
5358 conv_activation_pattern); // Filter
5459 GET_IR_NODE_FROM_SUBGRAPH (conv_out, conv_out,
@@ -97,6 +102,113 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
97102 AddStatis (found_conv_activation_count);
98103}
99104
105+ ConvActivationFusePass::ConvActivationFusePass () {
106+ AddOpCompat (OpCompat (" conv2d" ))
107+ .AddInput (" Input" )
108+ .IsTensor ()
109+ .End ()
110+ .AddInput (" Filter" )
111+ .IsTensor ()
112+ .End ()
113+ .AddInput (" Bias" )
114+ .IsOptional ()
115+ .IsTensor ()
116+ .End ()
117+ .AddOutput (" Output" )
118+ .IsTensor ()
119+ .End ()
120+ .AddAttr (" strides" )
121+ .IsType <std::vector<int >>()
122+ .End ()
123+ .AddAttr (" paddings" )
124+ .IsType <std::vector<int >>()
125+ .End ()
126+ // IsStringIn({"EXPLICIT", "SAME", "VALID"}), MobileNetV2 has no this
127+ // attribute
128+ .AddAttr (" padding_algorithm" )
129+ .IsOptional ()
130+ .IsStringIn ({" EXPLICIT" , " SAME" , " VALID" })
131+ .End ()
132+ .AddAttr (" groups" )
133+ .IsNumGE (1 )
134+ .End ()
135+ .AddAttr (" dilations" )
136+ .IsType <std::vector<int >>()
137+ .End ()
138+ // IsStringIn({"NHWC", "NCHW"}) MobileNetV2 has no this attribute
139+ .AddAttr (" data_format" )
140+ .IsOptional ()
141+ .IsStringIn ({" NHWC" , " NCHW" , " AnyLayout" })
142+ .End ();
143+
144+ AddOpCompat (OpCompat (" relu" ))
145+ .AddInput (" X" )
146+ .IsTensor ()
147+ .End ()
148+ .AddOutput (" Out" )
149+ .IsTensor ()
150+ .End ();
151+ }
152+ Conv2DLeakyReLUFusePass::Conv2DLeakyReLUFusePass () {
153+ AddOpCompat (OpCompat (" leaky_relu" ))
154+ .AddInput (" X" )
155+ .IsTensor ()
156+ .End ()
157+ .AddOutput (" Out" )
158+ .IsTensor ()
159+ .End ()
160+ // float, default=0.02
161+ .AddAttr (" alpha" )
162+ .IsType <float >()
163+ .End ();
164+ }
165+ Conv2DReLU6FusePass::Conv2DReLU6FusePass () {
166+ AddOpCompat (OpCompat (" relu6" ))
167+ .AddInput (" X" )
168+ .IsTensor ()
169+ .End ()
170+ .AddOutput (" Out" )
171+ .IsTensor ()
172+ .End ()
173+ // default = 6.0f
174+ .AddAttr (" threshold" )
175+ .IsType <float >()
176+ .End ();
177+ }
178+ Conv2DSwishFusePass::Conv2DSwishFusePass () {
179+ AddOpCompat (OpCompat (" swish" ))
180+ .AddInput (" X" )
181+ .IsTensor ()
182+ .End ()
183+ .AddOutput (" Out" )
184+ .IsTensor ()
185+ .End ();
186+ }
187+ Conv2DHardSwishFusePass::Conv2DHardSwishFusePass () {
188+ AddOpCompat (OpCompat (" hard_swish" ))
189+ .AddInput (" X" )
190+ .IsTensor ()
191+ .End ()
192+ .AddOutput (" Out" )
193+ .IsTensor ()
194+ .End ()
195+ // float, optional, default=6.0
196+ .AddAttr (" threshold" )
197+ .IsOptional ()
198+ .IsType <float >()
199+ .End ()
200+ // float, optional, default=6.0
201+ .AddAttr (" scale" )
202+ .IsOptional ()
203+ .IsType <float >()
204+ .End ()
205+ // float, optional, default=3.0
206+ .AddAttr (" offset" )
207+ .IsOptional ()
208+ .IsType <float >()
209+ .End ();
210+ }
211+
100212} // namespace ir
101213} // namespace framework
102214} // namespace paddle
0 commit comments