File tree Expand file tree Collapse file tree
paddle/fluid/framework/ir Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -30,6 +30,44 @@ namespace ir {
3030 GET_IR_NODE (reshape2_op); \
3131 GET_IR_NODE (reshape2_out);
3232
33+ ShuffleChannelDetectPass::ShuffleChannelDetectPass () {
34+ AddOpCompat (OpCompat (" reshape2" ))
35+ .AddInput (" X" )
36+ .IsTensor ()
37+ .End ()
38+ .AddInput (" Shape" )
39+ .IsOptional ()
40+ .IsTensor ()
41+ .End ()
42+ .AddInput (" ShapeTensor" )
43+ .IsOptional ()
44+ .IsTensor ()
45+ .End ()
46+ .AddOutput (" XShape" )
47+ .IsTensor ()
48+ .End ()
49+ .AddOutput (" Out" )
50+ .IsTensor ()
51+ .End ()
52+ .AddAttr (" shape" )
53+ .IsType <std::vector<int >>()
54+ .End ();
55+
56+ AddOpCompat (OpCompat (" transpose2" ))
57+ .AddInput (" X" )
58+ .IsTensor ()
59+ .End ()
60+ .AddOutput (" XShape" )
61+ .IsTensor ()
62+ .End ()
63+ .AddOutput (" Out" )
64+ .IsTensor ()
65+ .End ()
66+ .AddAttr (" axis" )
67+ .IsType <std::vector<int >>()
68+ .End ();
69+ }
70+
3371void ShuffleChannelDetectPass::ApplyImpl (ir::Graph* graph) const {
3472 const std::string pattern_name = " shufflechannel_pattern" ;
3573 FusePassBase::Init (pattern_name, graph);
@@ -46,7 +84,10 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const {
4684 auto handler = [&](const GraphPatternDetector::subgraph_t & subgraph,
4785 Graph* g) {
4886 GET_NODES;
49-
87+ if (!IsCompat (subgraph, g)) {
88+ LOG (WARNING) << " The Pass in op compat failed." ;
89+ return ;
90+ }
5091 PADDLE_ENFORCE_GT (
5192 subgraph.count (x), 0 ,
5293 platform::errors::NotFound (" Detector did not find input X." ));
Original file line number Diff line number Diff line change @@ -26,6 +26,7 @@ class Graph;
2626
2727class ShuffleChannelDetectPass : public FusePassBase {
2828 public:
29+ ShuffleChannelDetectPass ();
2930 virtual ~ShuffleChannelDetectPass () {}
3031
3132 protected:
You can’t perform that action at this time.
0 commit comments