@@ -34,15 +34,19 @@ namespace patterns {
3434static PDNode* create_emb_vars (PDPattern* pattern, const std::string& name,
3535 const std::string& arg,
3636 bool is_persist = false ) {
37+ std::unordered_set<std::string> embedding_ops{" lookup_table" ,
38+ " lookup_table_v2" };
3739 PDNode* node =
38- pattern->NewNode (name)->assert_is_op_input ( " lookup_table " , arg);
40+ pattern->NewNode (name)->assert_is_ops_input (embedding_ops , arg);
3941 if (is_persist) return node->assert_is_persistable_var ();
4042 return node;
4143}
4244static PDNode* create_emb_out_vars (PDPattern* pattern, const std::string& name,
4345 const std::string& arg) {
46+ std::unordered_set<std::string> embedding_ops{" lookup_table" ,
47+ " lookup_table_v2" };
4448 PDNode* node = pattern->NewNode (name)
45- ->assert_is_only_output_of_op ( " lookup_table " )
49+ ->assert_is_only_output_of_ops (embedding_ops )
4650 ->assert_is_op_input (" elementwise_add" , arg)
4751 ->AsIntermediate ();
4852 return node;
@@ -56,10 +60,12 @@ void Embedding2Eltwise1Pattern::operator()() {
5660 create_emb_vars (pattern, lookup_table1_w_repr (), " W" , true );
5761 auto * lookup_table2_w =
5862 create_emb_vars (pattern, lookup_table2_w_repr (), " W" , true );
63+ std::unordered_set<std::string> embedding_ops{" lookup_table" ,
64+ " lookup_table_v2" };
5965 auto * lookup_table1 =
60- pattern->NewNode (lookup_table1_repr ())->assert_is_op ( " lookup_table " );
66+ pattern->NewNode (lookup_table1_repr ())->assert_is_ops (embedding_ops );
6167 auto * lookup_table2 =
62- pattern->NewNode (lookup_table2_repr ())->assert_is_op ( " lookup_table " );
68+ pattern->NewNode (lookup_table2_repr ())->assert_is_ops (embedding_ops );
6369 auto * lookup_table1_out =
6470 create_emb_out_vars (pattern, lookup_table1_out_repr (), " X" );
6571 auto * lookup_table2_out =
@@ -80,8 +86,10 @@ void Embedding1Eltwise1Pattern::operator()() {
8086 create_emb_vars (pattern, lookup_table1_x_repr (), " Ids" );
8187 auto * lookup_table1_w =
8288 create_emb_vars (pattern, lookup_table1_w_repr (), " W" , true );
89+ std::unordered_set<std::string> embedding_ops{" lookup_table" ,
90+ " lookup_table_v2" };
8391 auto * lookup_table1 =
84- pattern->NewNode (lookup_table1_repr ())->assert_is_op ( " lookup_table " );
92+ pattern->NewNode (lookup_table1_repr ())->assert_is_ops (embedding_ops );
8593 auto * lookup_table1_out =
8694 create_emb_out_vars (pattern, lookup_table1_out_repr (), " Y" );
8795 auto * eltwise_add =
@@ -347,4 +355,5 @@ REGISTER_PASS_CAPABILITY(embedding_eltwise_layernorm_fuse_pass)
347355 .AddCombination(
348356 paddle::framework::compatible::OpVersionComparatorCombination ()
349357 .EQ(" lookup_table" , 0 )
358+ .LE(" lookup_table_v2" , 1 )
350359 .EQ(" elementweise_add" , 0 ));
0 commit comments