Skip to content

Commit 54af52b

Browse files
author
feng_shuai
authored
Shuffle channel detect pass (#33814)
1 parent cc5d4b1 commit 54af52b

2 files changed

Lines changed: 43 additions & 1 deletion

File tree

paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,44 @@ namespace ir {
3030
GET_IR_NODE(reshape2_op); \
3131
GET_IR_NODE(reshape2_out);
3232

33+
ShuffleChannelDetectPass::ShuffleChannelDetectPass() {
34+
AddOpCompat(OpCompat("reshape2"))
35+
.AddInput("X")
36+
.IsTensor()
37+
.End()
38+
.AddInput("Shape")
39+
.IsOptional()
40+
.IsTensor()
41+
.End()
42+
.AddInput("ShapeTensor")
43+
.IsOptional()
44+
.IsTensor()
45+
.End()
46+
.AddOutput("XShape")
47+
.IsTensor()
48+
.End()
49+
.AddOutput("Out")
50+
.IsTensor()
51+
.End()
52+
.AddAttr("shape")
53+
.IsType<std::vector<int>>()
54+
.End();
55+
56+
AddOpCompat(OpCompat("transpose2"))
57+
.AddInput("X")
58+
.IsTensor()
59+
.End()
60+
.AddOutput("XShape")
61+
.IsTensor()
62+
.End()
63+
.AddOutput("Out")
64+
.IsTensor()
65+
.End()
66+
.AddAttr("axis")
67+
.IsType<std::vector<int>>()
68+
.End();
69+
}
70+
3371
void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const {
3472
const std::string pattern_name = "shufflechannel_pattern";
3573
FusePassBase::Init(pattern_name, graph);
@@ -46,7 +84,10 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const {
4684
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
4785
Graph* g) {
4886
GET_NODES;
49-
87+
if (!IsCompat(subgraph, g)) {
88+
LOG(WARNING) << "The Pass in op compat failed.";
89+
return;
90+
}
5091
PADDLE_ENFORCE_GT(
5192
subgraph.count(x), 0,
5293
platform::errors::NotFound("Detector did not find input X."));

paddle/fluid/framework/ir/shuffle_channel_detect_pass.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class Graph;
2626

2727
class ShuffleChannelDetectPass : public FusePassBase {
2828
public:
29+
ShuffleChannelDetectPass();
2930
virtual ~ShuffleChannelDetectPass() {}
3031

3132
protected:

0 commit comments

Comments
 (0)