diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index ba1d7379c56d95..a26732926c2c6e 100644
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -70,6 +70,7 @@ pass_library(sync_batch_norm_pass base)
 pass_library(runtime_context_cache_pass base)
 pass_library(quant_conv2d_dequant_fuse_pass inference)
 pass_library(fillconstant_elementwisemul_fuse inference)
+pass_library(shuffle_channel_detect_pass inference)
 
 if(ANAKIN_FOUND)
 pass_library(simplify_anakin_priorbox_detection_out_pass inference)
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 77f50e914b668e..0dcf064902d1c1 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -1706,6 +1706,43 @@ void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
   }
 }
 
+// Describes the op chain
+//   reshape1_in -> reshape2 -> transpose2 -> reshape2 -> reshape2_out
+// in the pattern. `reshape1_in` is created by the caller; no extra
+// constraints are placed on it here.
+void patterns::ShuffleChannelPattern::operator()(PDNode *reshape1_in) {
+  // Leading reshape2 op and its "Out" variable, which must feed a transpose2.
+  auto *head_reshape =
+      pattern->NewNode(reshape1_op_repr())->assert_is_op("reshape2");
+  auto *head_reshape_out = pattern->NewNode(reshape1_out_repr())
+                               ->assert_is_op_output("reshape2", "Out")
+                               ->assert_is_op_input("transpose2")
+                               ->AsIntermediate();
+
+  // transpose2 op and its "Out" variable, which must feed a reshape2.
+  auto *transpose = pattern->NewNode(transpose_op_repr())
+                        ->assert_is_op("transpose2");
+  auto *transpose_result = pattern->NewNode(transpose_out_repr())
+                               ->assert_is_op_output("transpose2", "Out")
+                               ->assert_is_op_input("reshape2")
+                               ->AsIntermediate();
+
+  // Trailing reshape2 op; its "Out" variable is the pattern's output node.
+  auto *tail_reshape =
+      pattern->NewNode(reshape2_op_repr())->assert_is_op("reshape2");
+  auto *tail_reshape_out = pattern->NewNode(reshape2_out_repr())
+                               ->assert_is_op_output("reshape2", "Out")
+                               ->AsOutput();
+
+  // Wire the nodes into a single chain, front to back.
+  head_reshape->LinksFrom({reshape1_in});
+  head_reshape_out->LinksFrom({head_reshape});
+  transpose->LinksFrom({head_reshape_out});
+  transpose_result->LinksFrom({transpose});
+  tail_reshape->LinksFrom({transpose_result});
+  tail_reshape_out->LinksFrom({tail_reshape});
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index 525987e0072cb0..907371b56b06dc 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -892,6 +892,26 @@ struct QuantDequantOpFuse : public PatternBase {
   }
 };
 
+// Pattern descriptor for a reshape2 -> transpose2 -> reshape2 chain.
+// operator() registers the nodes declared below and links them starting
+// from the caller-provided input node.
+struct ShuffleChannelPattern : public PatternBase {
+  ShuffleChannelPattern(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "shufflechannel_pattern") {}
+
+  void operator()(PDNode* reshape1_in);
+
+  // First reshape2 op and its output.
+  PATTERN_DECL_NODE(reshape1_op);
+  PATTERN_DECL_NODE(reshape1_out);
+  // transpose2 op and its output.
+  PATTERN_DECL_NODE(transpose_op);
+  PATTERN_DECL_NODE(transpose_out);
+  // Second reshape2 op and its output.
+  PATTERN_DECL_NODE(reshape2_op);
+  PATTERN_DECL_NODE(reshape2_out);
+};
+
 }  // namespace patterns
 
 // Link two ir::Nodes from each other.
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index fea291c5528a11..ab347b85885fe3 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -79,7 +79,12 @@ const std::vector kAnakinSubgraphPasses({
     "fc_fuse_pass",                                 //
     "conv_elementwise_add_fuse_pass",               //
     "fc_gru_fuse_pass",                             //
+    // NOTE(review): the graph_viz_pass entries added around the new pass look
+    // like debugging aids -- confirm they are meant to ship.
+    "graph_viz_pass",                               //
+    "shuffle_channel_detect_pass",                  //
+    "graph_viz_pass",                               //
     "anakin_subgraph_pass",                         //
+    "graph_viz_pass",                               //
     "fc_gru_fuse_pass",                             //
 });
 
diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc
index 8385e6331d757b..b650225c64a9a3 100644
--- a/paddle/fluid/pybind/inference_api.cc
+++ b/paddle/fluid/pybind/inference_api.cc
@@ -236,6 +236,7 @@ void BindAnalysisConfig(py::module *m) {
                std::map>(),
            py::arg("min_subgraph_size") = 6,
            py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32,
+           py::arg("auto_config_layout") = false,
            py::arg("passes_filter") = std::vector(),
            py::arg("ops_filter") = std::vector())
       .def("tensorrt_engine_enabled", &AnalysisConfig::tensorrt_engine_enabled)