Skip to content

Commit a1d200a

Browse files
committed
cherry-pick from feature/anakin-engine: Anakin support facebox PaddlePaddle#16111
1 parent a32d420 commit a1d200a

25 files changed

Lines changed: 765 additions & 28 deletions

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ pass_library(transpose_flatten_concat_fuse_pass inference)
7171
pass_library(identity_scale_op_clean_pass base)
7272
pass_library(sync_batch_norm_pass base)
7373
pass_library(runtime_context_cache_pass base)
74+
pass_library(simplify_anakin_detection_pattern_pass inference)
7475

7576
# There may be many transpose-flatten structures in a model, and the output of
7677
# these structures will be used as inputs to the concat Op. This pattern will
@@ -81,6 +82,10 @@ foreach (index RANGE 3 6)
8182
file(APPEND ${pass_file} "USE_PASS(transpose_flatten${index}_concat_fuse_pass);\n")
8283
endforeach()
8384

85+
# Append one USE_PASS(simplify_anakin_detection_pattern_pass<index>) line to
# ${pass_file} for every index in the inclusive range 3..6.
foreach (index RANGE 3 6)
86+
file(APPEND ${pass_file} "USE_PASS(simplify_anakin_detection_pattern_pass${index});\n")
87+
endforeach()
88+
8489
if(WITH_MKLDNN)
8590
pass_library(mkldnn_placement_pass base mkldnn)
8691
pass_library(depthwise_conv_mkldnn_pass base mkldnn)

paddle/fluid/framework/ir/graph_pattern_detector.cc

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1454,6 +1454,136 @@ PDNode *patterns::TransposeFlattenConcat::operator()(
14541454
return concat_out;
14551455
}
14561456

1457+
// Builds the Anakin detection pattern on `pattern`.
//
// For every i in [0, times) one branch is created:
//
//   conv_in[i] -> density_prior_box -+-> Boxes     -> reshape2 -> reshape1_out
//                                    +-> Variances -> reshape2 -> reshape2_out
//
// Both reshape outputs are asserted to be the i-th "X" input of a concat.
// All reshape1_out nodes are linked into concat1 and all reshape2_out nodes
// into concat2. The two concat outputs together with conv_in[times] feed
// box_coder; box_coder's output together with conv_in[times + 1] feeds
// multiclass_nms.
//
// Parameters:
//   conv_in  Input nodes. Entries [0, times) feed the density_prior_box ops,
//            entry `times` is the third box_coder input and entry `times + 1`
//            is the second multiclass_nms input, so at least `times + 2`
//            entries are required.
//   times    Number of repeated prior-box branches.
//
// Returns the PDNode describing the multiclass_nms output.
//
// Throws std::out_of_range (from std::vector::at) when `times` is negative or
// `conv_in` holds fewer than `times + 2` entries.
PDNode *patterns::AnakinDetectionPattern::operator()(
    std::vector<PDNode *> conv_in, int times) {
  // Each branch contributes kNumFields consecutive entries to `nodes`, in this
  // order:
  //   {prior_box, box_out, reshape1, reshape1_out,
  //    box_var_out, reshape2, reshape2_out}
  // The offsets below index into one such group.
  const int kNumFields = 7;
  const int kPriorBoxLocOffset = 1;
  const int kReshape1Offset = 2;
  const int kReshape1OutOffset = 3;
  const int kPriorBoxVarOffset = 4;
  const int kReshape2Offset = 5;
  const int kReshape2OutOffset = 6;

  // Positions of the two extra inputs that follow the `times` branch inputs.
  const int kBoxCoderThirdInputOffset = times;
  const int kMultiClassSecondInputNmsOffset = times + 1;

  // Bounds-checked reads performed before any node is added to the pattern:
  // if both succeed, every conv_in[i] with 0 <= i < times is valid as well, so
  // invalid arguments fail here instead of causing out-of-range access later.
  PDNode *box_coder_third_input = conv_in.at(kBoxCoderThirdInputOffset);
  PDNode *nms_second_input = conv_in.at(kMultiClassSecondInputNmsOffset);

  std::vector<PDNode *> nodes;

  // Create the nodes of every branch. Creation order is kept stable because
  // the link step below addresses them by position.
  for (int i = 0; i < times; i++) {
    nodes.push_back(
        pattern->NewNode(GetNodeName("prior_box" + std::to_string(i)))
            ->assert_is_op("density_prior_box"));
    nodes.push_back(pattern->NewNode(GetNodeName("box_out" + std::to_string(i)))
                        ->assert_is_op_output("density_prior_box", "Boxes")
                        ->assert_is_op_input("reshape2", "X")
                        ->AsIntermediate());
    nodes.push_back(
        pattern->NewNode(GetNodeName("reshape1" + std::to_string(i)))
            ->assert_is_op("reshape2"));

    nodes.push_back(
        pattern->NewNode(GetNodeName("reshape1_out" + std::to_string(i)))
            ->assert_is_op_output("reshape2")
            ->assert_is_op_nth_input("concat", "X", i)
            ->AsIntermediate());

    nodes.push_back(
        pattern->NewNode(GetNodeName("box_var_out" + std::to_string(i)))
            ->assert_is_op_output("density_prior_box", "Variances")
            ->assert_is_op_input("reshape2", "X")
            ->AsIntermediate());
    nodes.push_back(
        pattern->NewNode(GetNodeName("reshape2" + std::to_string(i)))
            ->assert_is_op("reshape2"));

    nodes.push_back(
        pattern->NewNode(GetNodeName("reshape2_out" + std::to_string(i)))
            ->assert_is_op_output("reshape2")
            ->assert_is_op_nth_input("concat", "X", i)
            ->AsIntermediate());
  }

  // concat1 gathers the reshaped Boxes, concat2 the reshaped Variances.
  auto concat_op1 = pattern->NewNode(GetNodeName("concat1"))
                        ->assert_is_op("concat")
                        ->assert_op_has_n_inputs("concat", times);
  auto concat_out1 = pattern->NewNode(GetNodeName("concat1_out"))
                         ->assert_is_op_output("concat")
                         ->AsIntermediate();

  auto concat_op2 = pattern->NewNode(GetNodeName("concat2"))
                        ->assert_is_op("concat")
                        ->assert_op_has_n_inputs("concat", times);
  auto concat_out2 = pattern->NewNode(GetNodeName("concat2_out"))
                         ->assert_is_op_output("concat")
                         ->AsIntermediate();

  auto box_coder_op = pattern->NewNode(GetNodeName("box_coder"))
                          ->assert_is_op("box_coder")
                          ->assert_op_has_n_inputs("box_coder", 3);

  auto box_coder_out = pattern->NewNode(GetNodeName("box_coder_out"))
                           ->assert_is_op_output("box_coder")
                           ->AsIntermediate();

  auto multiclass_nms_op = pattern->NewNode(GetNodeName("multiclass_nms"))
                               ->assert_is_op("multiclass_nms")
                               ->assert_op_has_n_inputs("multiclass_nms", 2);

  auto multiclass_nms_out = pattern->NewNode(GetNodeName("multiclass_nms_out"))
                                ->assert_is_op_output("multiclass_nms")
                                ->AsOutput();

  std::vector<PDNode *> reshape1_outs;
  std::vector<PDNode *> reshape2_outs;

  // Wire up every branch and collect its two reshape outputs.
  for (int i = 0; i < times; i++) {
    conv_in[i]->AsInput();
    // prior_box
    nodes[i * kNumFields]->LinksFrom({conv_in[i]});
    // prior_box box out
    nodes[i * kNumFields + kPriorBoxLocOffset]->LinksFrom(
        {nodes[i * kNumFields]});
    // reshape of the boxes
    nodes[i * kNumFields + kReshape1Offset]->LinksFrom(
        {nodes[i * kNumFields + kPriorBoxLocOffset]});
    // reshape_out of the boxes
    nodes[i * kNumFields + kReshape1OutOffset]->LinksFrom(
        {nodes[i * kNumFields + kReshape1Offset]});

    // prior_box variances out
    nodes[i * kNumFields + kPriorBoxVarOffset]->LinksFrom(
        {nodes[i * kNumFields]});
    // reshape of the variances
    nodes[i * kNumFields + kReshape2Offset]->LinksFrom(
        {nodes[i * kNumFields + kPriorBoxVarOffset]});
    // reshape_out of the variances
    nodes[i * kNumFields + kReshape2OutOffset]->LinksFrom(
        {nodes[i * kNumFields + kReshape2Offset]});

    reshape1_outs.push_back(nodes[i * kNumFields + kReshape1OutOffset]);
    reshape2_outs.push_back(nodes[i * kNumFields + kReshape2OutOffset]);
  }

  concat_op1->LinksFrom(reshape1_outs);
  concat_op2->LinksFrom(reshape2_outs);
  concat_out1->LinksFrom({concat_op1});
  concat_out2->LinksFrom({concat_op2});

  box_coder_third_input->AsInput();
  nms_second_input->AsInput();

  box_coder_op->LinksFrom({concat_out1, concat_out2, box_coder_third_input});
  box_coder_out->LinksFrom({box_coder_op});

  multiclass_nms_op->LinksFrom({box_coder_out, nms_second_input})
      .LinksTo({multiclass_nms_out});

  return multiclass_nms_out;
}
1586+
14571587
} // namespace ir
14581588
} // namespace framework
14591589
} // namespace paddle

paddle/fluid/framework/ir/graph_pattern_detector.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -841,6 +841,21 @@ struct TransposeFlattenConcat : public PatternBase {
841841
}
842842
};
843843

844+
struct AnakinDetectionPattern : public PatternBase {
845+
AnakinDetectionPattern(PDPattern* pattern, const std::string& name_scope)
846+
: PatternBase(pattern, name_scope, "anakin_detect_pattern") {}
847+
848+
PDNode* operator()(std::vector<PDNode*> conv_inputs, int times);
849+
850+
std::string GetNodeName(const std::string& op_type) {
851+
return PDNodeName(name_scope_, repr_, id_, op_type);
852+
}
853+
854+
PDNode* GetPDNode(const std::string& op_type) {
855+
return pattern->RetrieveNode(GetNodeName(op_type));
856+
}
857+
};
858+
844859
} // namespace patterns
845860

846861
// Link two ir::Nodes from each other.

0 commit comments

Comments
 (0)