Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ pass_library(conv_elementwise_add_fuse_pass inference)
pass_library(conv_affine_channel_fuse_pass inference)
pass_library(transpose_flatten_concat_fuse_pass inference)
pass_library(identity_scale_op_clean_pass base)
pass_library(simplify_anakin_detection_pattern_pass inference)

# There may be many transpose-flatten structures in a model, and the output of
# these structures will be used as inputs to the concat Op. This pattern will
Expand All @@ -76,6 +77,10 @@ foreach (index RANGE 3 6)
file(APPEND ${pass_file} "USE_PASS(transpose_flatten${index}_concat_fuse_pass);\n")
endforeach()

foreach (index RANGE 3 6)
file(APPEND ${pass_file} "USE_PASS(simplify_anakin_detection_pattern_pass${index});\n")
endforeach()

if(WITH_MKLDNN)
pass_library(mkldnn_placement_pass base mkldnn)
pass_library(depthwise_conv_mkldnn_pass base mkldnn)
Expand Down
130 changes: 130 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1364,6 +1364,136 @@ PDNode *patterns::TransposeFlattenConcat::operator()(
return concat_out;
}

PDNode *patterns::AnakinDetectionPattern::operator()(
std::vector<PDNode *> conv_in, int times) {
// The times represents the repeat times of the
// {prior_box, prior_box_loc_out, flatten, prior_box_var_out, reshape}
const int kNumFields = 7;
const int kPriorBoxLocOffset = 1;
const int kReshape1Offset = 2;
const int kReshape1OutOffset = 3;
const int kPriorBoxVarOffset = 4;
const int kReshape2Offset = 5;
const int kReshape2OutOffset = 6;

const int kBoxCoderThirdInputOffset = times;
const int kMultiClassSecondInputNmsOffset = times + 1;

std::vector<PDNode *> nodes;

for (int i = 0; i < times; i++) {
nodes.push_back(
pattern->NewNode(GetNodeName("prior_box" + std::to_string(i)))
->assert_is_op("density_prior_box"));
nodes.push_back(pattern->NewNode(GetNodeName("box_out" + std::to_string(i)))
->assert_is_op_output("density_prior_box", "Boxes")
->assert_is_op_input("reshape2", "X")
->AsIntermediate());
nodes.push_back(
pattern->NewNode(GetNodeName("reshape1" + std::to_string(i)))
->assert_is_op("reshape2"));

nodes.push_back(
pattern->NewNode(GetNodeName("reshape1_out" + std::to_string(i)))
->assert_is_op_output("reshape2")
->assert_is_op_nth_input("concat", "X", i)
->AsIntermediate());

nodes.push_back(
pattern->NewNode(GetNodeName("box_var_out" + std::to_string(i)))
->assert_is_op_output("density_prior_box", "Variances")
->assert_is_op_input("reshape2", "X")
->AsIntermediate());
nodes.push_back(
pattern->NewNode(GetNodeName("reshape2" + std::to_string(i)))
->assert_is_op("reshape2"));

nodes.push_back(
pattern->NewNode(GetNodeName("reshape2_out" + std::to_string(i)))
->assert_is_op_output("reshape2")
->assert_is_op_nth_input("concat", "X", i)
->AsIntermediate());
}

auto concat_op1 = pattern->NewNode(GetNodeName("concat1"))
->assert_is_op("concat")
->assert_op_has_n_inputs("concat", times);
auto concat_out1 = pattern->NewNode(GetNodeName("concat1_out"))
->assert_is_op_output("concat")
->AsIntermediate();

auto concat_op2 = pattern->NewNode(GetNodeName("concat2"))
->assert_is_op("concat")
->assert_op_has_n_inputs("concat", times);
auto concat_out2 = pattern->NewNode(GetNodeName("concat2_out"))
->assert_is_op_output("concat")
->AsIntermediate();

auto box_coder_op = pattern->NewNode(GetNodeName("box_coder"))
->assert_is_op("box_coder")
->assert_op_has_n_inputs("box_coder", 3);

auto box_coder_out = pattern->NewNode(GetNodeName("box_coder_out"))
->assert_is_op_output("box_coder")
->AsIntermediate();

auto multiclass_nms_op = pattern->NewNode(GetNodeName("multiclass_nms"))
->assert_is_op("multiclass_nms")
->assert_op_has_n_inputs("multiclass_nms", 2);

auto multiclass_nms_out = pattern->NewNode(GetNodeName("multiclass_nms_out"))
->assert_is_op_output("multiclass_nms")
->AsOutput();

std::vector<PDNode *> reshape1_outs;
std::vector<PDNode *> reshape2_outs;

for (int i = 0; i < times; i++) {
conv_in[i]->AsInput();
// prior_box
nodes[i * kNumFields]->LinksFrom({conv_in[i]});
// prior_box box out
nodes[i * kNumFields + kPriorBoxLocOffset]->LinksFrom(
{nodes[i * kNumFields]});
// reshape
nodes[i * kNumFields + kReshape1Offset]->LinksFrom(
{nodes[i * kNumFields + kPriorBoxLocOffset]});
// reshape_out
nodes[i * kNumFields + kReshape1OutOffset]->LinksFrom(
{nodes[i * kNumFields + kReshape1Offset]});

nodes[i * kNumFields + kPriorBoxVarOffset]->LinksFrom(
{nodes[i * kNumFields]});
// reshape
nodes[i * kNumFields + kReshape2Offset]->LinksFrom(
{nodes[i * kNumFields + kPriorBoxVarOffset]});
// reshape_out
nodes[i * kNumFields + kReshape2OutOffset]->LinksFrom(
{nodes[i * kNumFields + kReshape2Offset]});

reshape1_outs.push_back(nodes[i * kNumFields + kReshape1OutOffset]);
reshape2_outs.push_back(nodes[i * kNumFields + kReshape2OutOffset]);
}

concat_op1->LinksFrom(reshape1_outs);
concat_op2->LinksFrom(reshape2_outs);
concat_out1->LinksFrom({concat_op1});
concat_out2->LinksFrom({concat_op2});

conv_in[kBoxCoderThirdInputOffset]->AsInput();
conv_in[kMultiClassSecondInputNmsOffset]->AsInput();

box_coder_op->LinksFrom(
{concat_out1, concat_out2, conv_in[kBoxCoderThirdInputOffset]});
box_coder_out->LinksFrom({box_coder_op});

multiclass_nms_op
->LinksFrom({box_coder_out, conv_in[kMultiClassSecondInputNmsOffset]})
.LinksTo({multiclass_nms_out});

return multiclass_nms_out;
}

} // namespace ir
} // namespace framework
} // namespace paddle
18 changes: 18 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@
#include <gtest/gtest_prod.h>
#endif

#include <memory>
#include <numeric>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
Expand Down Expand Up @@ -781,6 +784,21 @@ struct TransposeFlattenConcat : public PatternBase {
}
};

struct AnakinDetectionPattern : public PatternBase {
AnakinDetectionPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "anakin_detect_pattern") {}

PDNode* operator()(std::vector<PDNode*> conv_inputs, int times);

std::string GetNodeName(const std::string& op_type) {
return PDNodeName(name_scope_, repr_, id_, op_type);
}

PDNode* GetPDNode(const std::string& op_type) {
return pattern->RetrieveNode(GetNodeName(op_type));
}
};

} // namespace patterns

// Link two ir::Nodes from each other.
Expand Down
Loading