@@ -1454,6 +1454,136 @@ PDNode *patterns::TransposeFlattenConcat::operator()(
14541454 return concat_out;
14551455}
14561456
1457+ PDNode *patterns::AnakinDetectionPattern::operator ()(
1458+ std::vector<PDNode *> conv_in, int times) {
1459+ // The times represents the repeat times of the
1460+ // {prior_box, prior_box_loc_out, flatten, prior_box_var_out, reshape}
1461+ const int kNumFields = 7 ;
1462+ const int kPriorBoxLocOffset = 1 ;
1463+ const int kReshape1Offset = 2 ;
1464+ const int kReshape1OutOffset = 3 ;
1465+ const int kPriorBoxVarOffset = 4 ;
1466+ const int kReshape2Offset = 5 ;
1467+ const int kReshape2OutOffset = 6 ;
1468+
1469+ const int kBoxCoderThirdInputOffset = times;
1470+ const int kMultiClassSecondInputNmsOffset = times + 1 ;
1471+
1472+ std::vector<PDNode *> nodes;
1473+
1474+ for (int i = 0 ; i < times; i++) {
1475+ nodes.push_back (
1476+ pattern->NewNode (GetNodeName (" prior_box" + std::to_string (i)))
1477+ ->assert_is_op (" density_prior_box" ));
1478+ nodes.push_back (pattern->NewNode (GetNodeName (" box_out" + std::to_string (i)))
1479+ ->assert_is_op_output (" density_prior_box" , " Boxes" )
1480+ ->assert_is_op_input (" reshape2" , " X" )
1481+ ->AsIntermediate ());
1482+ nodes.push_back (
1483+ pattern->NewNode (GetNodeName (" reshape1" + std::to_string (i)))
1484+ ->assert_is_op (" reshape2" ));
1485+
1486+ nodes.push_back (
1487+ pattern->NewNode (GetNodeName (" reshape1_out" + std::to_string (i)))
1488+ ->assert_is_op_output (" reshape2" )
1489+ ->assert_is_op_nth_input (" concat" , " X" , i)
1490+ ->AsIntermediate ());
1491+
1492+ nodes.push_back (
1493+ pattern->NewNode (GetNodeName (" box_var_out" + std::to_string (i)))
1494+ ->assert_is_op_output (" density_prior_box" , " Variances" )
1495+ ->assert_is_op_input (" reshape2" , " X" )
1496+ ->AsIntermediate ());
1497+ nodes.push_back (
1498+ pattern->NewNode (GetNodeName (" reshape2" + std::to_string (i)))
1499+ ->assert_is_op (" reshape2" ));
1500+
1501+ nodes.push_back (
1502+ pattern->NewNode (GetNodeName (" reshape2_out" + std::to_string (i)))
1503+ ->assert_is_op_output (" reshape2" )
1504+ ->assert_is_op_nth_input (" concat" , " X" , i)
1505+ ->AsIntermediate ());
1506+ }
1507+
1508+ auto concat_op1 = pattern->NewNode (GetNodeName (" concat1" ))
1509+ ->assert_is_op (" concat" )
1510+ ->assert_op_has_n_inputs (" concat" , times);
1511+ auto concat_out1 = pattern->NewNode (GetNodeName (" concat1_out" ))
1512+ ->assert_is_op_output (" concat" )
1513+ ->AsIntermediate ();
1514+
1515+ auto concat_op2 = pattern->NewNode (GetNodeName (" concat2" ))
1516+ ->assert_is_op (" concat" )
1517+ ->assert_op_has_n_inputs (" concat" , times);
1518+ auto concat_out2 = pattern->NewNode (GetNodeName (" concat2_out" ))
1519+ ->assert_is_op_output (" concat" )
1520+ ->AsIntermediate ();
1521+
1522+ auto box_coder_op = pattern->NewNode (GetNodeName (" box_coder" ))
1523+ ->assert_is_op (" box_coder" )
1524+ ->assert_op_has_n_inputs (" box_coder" , 3 );
1525+
1526+ auto box_coder_out = pattern->NewNode (GetNodeName (" box_coder_out" ))
1527+ ->assert_is_op_output (" box_coder" )
1528+ ->AsIntermediate ();
1529+
1530+ auto multiclass_nms_op = pattern->NewNode (GetNodeName (" multiclass_nms" ))
1531+ ->assert_is_op (" multiclass_nms" )
1532+ ->assert_op_has_n_inputs (" multiclass_nms" , 2 );
1533+
1534+ auto multiclass_nms_out = pattern->NewNode (GetNodeName (" multiclass_nms_out" ))
1535+ ->assert_is_op_output (" multiclass_nms" )
1536+ ->AsOutput ();
1537+
1538+ std::vector<PDNode *> reshape1_outs;
1539+ std::vector<PDNode *> reshape2_outs;
1540+
1541+ for (int i = 0 ; i < times; i++) {
1542+ conv_in[i]->AsInput ();
1543+ // prior_box
1544+ nodes[i * kNumFields ]->LinksFrom ({conv_in[i]});
1545+ // prior_box box out
1546+ nodes[i * kNumFields + kPriorBoxLocOffset ]->LinksFrom (
1547+ {nodes[i * kNumFields ]});
1548+ // reshape
1549+ nodes[i * kNumFields + kReshape1Offset ]->LinksFrom (
1550+ {nodes[i * kNumFields + kPriorBoxLocOffset ]});
1551+ // reshape_out
1552+ nodes[i * kNumFields + kReshape1OutOffset ]->LinksFrom (
1553+ {nodes[i * kNumFields + kReshape1Offset ]});
1554+
1555+ nodes[i * kNumFields + kPriorBoxVarOffset ]->LinksFrom (
1556+ {nodes[i * kNumFields ]});
1557+ // reshape
1558+ nodes[i * kNumFields + kReshape2Offset ]->LinksFrom (
1559+ {nodes[i * kNumFields + kPriorBoxVarOffset ]});
1560+ // reshape_out
1561+ nodes[i * kNumFields + kReshape2OutOffset ]->LinksFrom (
1562+ {nodes[i * kNumFields + kReshape2Offset ]});
1563+
1564+ reshape1_outs.push_back (nodes[i * kNumFields + kReshape1OutOffset ]);
1565+ reshape2_outs.push_back (nodes[i * kNumFields + kReshape2OutOffset ]);
1566+ }
1567+
1568+ concat_op1->LinksFrom (reshape1_outs);
1569+ concat_op2->LinksFrom (reshape2_outs);
1570+ concat_out1->LinksFrom ({concat_op1});
1571+ concat_out2->LinksFrom ({concat_op2});
1572+
1573+ conv_in[kBoxCoderThirdInputOffset ]->AsInput ();
1574+ conv_in[kMultiClassSecondInputNmsOffset ]->AsInput ();
1575+
1576+ box_coder_op->LinksFrom (
1577+ {concat_out1, concat_out2, conv_in[kBoxCoderThirdInputOffset ]});
1578+ box_coder_out->LinksFrom ({box_coder_op});
1579+
1580+ multiclass_nms_op
1581+ ->LinksFrom ({box_coder_out, conv_in[kMultiClassSecondInputNmsOffset ]})
1582+ .LinksTo ({multiclass_nms_out});
1583+
1584+ return multiclass_nms_out;
1585+ }
1586+
14571587} // namespace ir
14581588} // namespace framework
14591589} // namespace paddle
0 commit comments