File tree Expand file tree Collapse file tree 2 files changed +62
-5
lines changed
paddle/fluid/framework/ir/mkldnn Expand file tree Collapse file tree 2 files changed +62
-5
lines changed Original file line number Diff line number Diff line change @@ -23,7 +23,67 @@ namespace paddle {
2323namespace framework {
2424namespace ir {
2525
26- class Graph ;
26+ ConvConcatReLUFusePass::ConvConcatReLUFusePass () {
27+ AddOpCompat (OpCompat (" conv2d" ))
28+ .AddInput (" Input" )
29+ .IsTensor ()
30+ .End ()
31+ .AddInput (" Filter" )
32+ .IsTensor ()
33+ .End ()
34+ .AddInput (" Bias" )
35+ .IsTensor ()
36+ .IsOptional ()
37+ .End ()
38+ .AddInput (" ResidualData" )
39+ .IsTensor ()
40+ .IsOptional ()
41+ .End ()
42+ .AddOutput (" Output" )
43+ .IsTensor ()
44+ .End ()
45+ .AddAttr (" strides" )
46+ .IsType <std::vector<int >>()
47+ .End ()
48+ .AddAttr (" paddings" )
49+ .IsType <std::vector<int >>()
50+ .End ()
51+ .AddAttr (" padding_algorithm" )
52+ .IsOptional ()
53+ .IsStringIn ({" EXPLICIT" , " SAME" , " VALID" })
54+ .End ()
55+ .AddAttr (" groups" )
56+ .IsNumGE (1 )
57+ .End ()
58+ .AddAttr (" dilations" )
59+ .IsType <std::vector<int >>()
60+ .End ()
61+ .AddAttr (" data_format" )
62+ .IsStringIn ({" NCHW" , " NHWC" , " AnyLayout" })
63+ .End ();
64+
65+ AddOpCompat (OpCompat (" concat" ))
66+ .AddInput (" X" ) // Input("X"): vector<tensors>
67+ .End ()
68+ .AddInput (" AxisTensor" )
69+ .IsTensor ()
70+ .IsOptional ()
71+ .End ()
72+ .AddOutput (" Out" )
73+ .IsTensor ()
74+ .End ()
75+ .AddAttr (" axis" )
76+ .IsNumGE (0 )
77+ .End ();
78+
79+ AddOpCompat (OpCompat (" relu" ))
80+ .AddInput (" X" )
81+ .IsTensor ()
82+ .End ()
83+ .AddOutput (" Out" )
84+ .IsTensor ()
85+ .End ();
86+ }
2787
2888void ConvConcatReLUFusePass::FindConcatWithConvs (
2989 ir::Graph* graph,
Original file line number Diff line number Diff line change 1818#include < unordered_map>
1919
2020#include " paddle/fluid/framework/ir/fuse_pass_base.h"
21- #include " paddle/fluid/framework/ir/graph.h"
22- #include " paddle/fluid/framework/ir/graph_pattern_detector.h"
23- #include " paddle/fluid/framework/ir/pass.h"
2421
2522namespace paddle {
2623namespace framework {
@@ -31,10 +28,10 @@ namespace ir {
3128 * to a:
3229 * (multi ConvReLU) -> Concat -> next_op.
3330 */
34- class Graph ;
3531
3632class ConvConcatReLUFusePass : public FusePassBase {
3733 public:
34+ ConvConcatReLUFusePass ();
3835 virtual ~ConvConcatReLUFusePass () {}
3936
4037 protected:
You can’t perform that action at this time.
0 commit comments