Skip to content

Commit 14b7e3c

Browse files
author
Pei Yang
authored
[Paddle-TRT] TRT inference support for BERT/Transformer in paddle 2.0 api (#31744)
* support multihead_matmul_fuse_pass_v3 * fix compile problems * embedding_eltwise_ln pass support lookup_table_v2 * suppoort matmul and matmul_v2 in qkv matmul
1 parent 245252b commit 14b7e3c

File tree

6 files changed

+585
-8
lines changed

6 files changed

+585
-8
lines changed

paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.cc

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,19 @@ namespace patterns {
3434
static PDNode* create_emb_vars(PDPattern* pattern, const std::string& name,
3535
const std::string& arg,
3636
bool is_persist = false) {
37+
std::unordered_set<std::string> embedding_ops{"lookup_table",
38+
"lookup_table_v2"};
3739
PDNode* node =
38-
pattern->NewNode(name)->assert_is_op_input("lookup_table", arg);
40+
pattern->NewNode(name)->assert_is_ops_input(embedding_ops, arg);
3941
if (is_persist) return node->assert_is_persistable_var();
4042
return node;
4143
}
4244
static PDNode* create_emb_out_vars(PDPattern* pattern, const std::string& name,
4345
const std::string& arg) {
46+
std::unordered_set<std::string> embedding_ops{"lookup_table",
47+
"lookup_table_v2"};
4448
PDNode* node = pattern->NewNode(name)
45-
->assert_is_only_output_of_op("lookup_table")
49+
->assert_is_only_output_of_ops(embedding_ops)
4650
->assert_is_op_input("elementwise_add", arg)
4751
->AsIntermediate();
4852
return node;
@@ -56,10 +60,12 @@ void Embedding2Eltwise1Pattern::operator()() {
5660
create_emb_vars(pattern, lookup_table1_w_repr(), "W", true);
5761
auto* lookup_table2_w =
5862
create_emb_vars(pattern, lookup_table2_w_repr(), "W", true);
63+
std::unordered_set<std::string> embedding_ops{"lookup_table",
64+
"lookup_table_v2"};
5965
auto* lookup_table1 =
60-
pattern->NewNode(lookup_table1_repr())->assert_is_op("lookup_table");
66+
pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops);
6167
auto* lookup_table2 =
62-
pattern->NewNode(lookup_table2_repr())->assert_is_op("lookup_table");
68+
pattern->NewNode(lookup_table2_repr())->assert_is_ops(embedding_ops);
6369
auto* lookup_table1_out =
6470
create_emb_out_vars(pattern, lookup_table1_out_repr(), "X");
6571
auto* lookup_table2_out =
@@ -80,8 +86,10 @@ void Embedding1Eltwise1Pattern::operator()() {
8086
create_emb_vars(pattern, lookup_table1_x_repr(), "Ids");
8187
auto* lookup_table1_w =
8288
create_emb_vars(pattern, lookup_table1_w_repr(), "W", true);
89+
std::unordered_set<std::string> embedding_ops{"lookup_table",
90+
"lookup_table_v2"};
8391
auto* lookup_table1 =
84-
pattern->NewNode(lookup_table1_repr())->assert_is_op("lookup_table");
92+
pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops);
8593
auto* lookup_table1_out =
8694
create_emb_out_vars(pattern, lookup_table1_out_repr(), "Y");
8795
auto* eltwise_add =
@@ -347,4 +355,5 @@ REGISTER_PASS_CAPABILITY(embedding_eltwise_layernorm_fuse_pass)
347355
.AddCombination(
348356
paddle::framework::compatible::OpVersionComparatorCombination()
349357
.EQ("lookup_table", 0)
358+
.LE("lookup_table_v2", 1)
350359
.EQ("elementweise_add", 0));

paddle/fluid/framework/ir/graph_pattern_detector.cc

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,36 @@ PDNode *PDNode::assert_is_ops_input(
652652
return this;
653653
}
654654

655+
PDNode *PDNode::assert_is_only_input_of_ops(
656+
const std::unordered_set<std::string> &op_types) {
657+
assert_is_var();
658+
asserts_.emplace_back([=](Node *x) {
659+
for (auto *op : x->outputs) {
660+
if (op && op->IsOp() && op->Op() && op_types.count(op->Op()->Type()) &&
661+
op->inputs.size() == 1) {
662+
return true;
663+
}
664+
}
665+
return false;
666+
});
667+
return this;
668+
}
669+
670+
PDNode *PDNode::assert_is_only_output_of_ops(
671+
const std::unordered_set<std::string> &op_types) {
672+
assert_is_var();
673+
asserts_.emplace_back([=](Node *x) {
674+
for (auto *op : x->inputs) {
675+
if (op && op->IsOp() && op->Op() && op_types.count(op->Op()->Type()) &&
676+
op->outputs.size() == 1) {
677+
return true;
678+
}
679+
}
680+
return false;
681+
});
682+
return this;
683+
}
684+
655685
bool VarLinksToOp(Node *node, const std::string &op_type) {
656686
for (auto *out : node->outputs) {
657687
if (out->IsOp() && out->Op()->Type() == op_type) {

paddle/fluid/framework/ir/graph_pattern_detector.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,11 @@ struct PDNode {
145145
const std::unordered_set<std::string>& op_types,
146146
const std::string& argument, int nth);
147147

148+
PDNode* assert_is_only_input_of_ops(
149+
const std::unordered_set<std::string>& op_types);
150+
PDNode* assert_is_only_output_of_ops(
151+
const std::unordered_set<std::string>& op_types);
152+
148153
PDNode* assert_has_n_inputs(size_t n);
149154
PDNode* assert_has_n_outputs(size_t n);
150155

0 commit comments

Comments
 (0)