@@ -48,18 +48,22 @@ class XPUSingleEncoderFuser : public FuseBase {
4848 const std::string& input_pos = " Y" ,
4949 const std::string& qkv_ln_2_out_pos = " Y" ,
5050 const std::string& matmul_type = " matmul" ,
51+ const std::string& matmul2_type = " matmul_v2" ,
5152 const std::string& mul_type = " mul" ,
5253 bool with_q_scale = true ,
5354 bool norm_before = false ,
54- const std::string& relative_type = " " )
55+ const std::string& relative_type = " " ,
56+ bool with_mask = true )
5557 : act_type_(act_type),
5658 input_pos_(input_pos),
5759 qkv_ln_2_out_pos_(qkv_ln_2_out_pos),
5860 matmul_type_(matmul_type),
61+ matmul2_type_(matmul2_type),
5962 mul_type_(mul_type),
6063 with_q_scale_(with_q_scale),
6164 norm_before_(norm_before),
62- relative_emb_type_(relative_type) {}
65+ relative_emb_type_(relative_type),
66+ with_mask_(with_mask) {}
6367
6468 void BuildPattern () override {
6569 auto * input = VarNode (" input" )
@@ -213,18 +217,25 @@ class XPUSingleEncoderFuser : public FuseBase {
213217 ->AsIntermediate ();
214218
215219 auto * qk_matmul = OpNode (" qk_matmul" , matmul_type_)->AsIntermediate ();
220+ std::string op_after_qk_matmul = with_mask_ ? " elementwise_add" : " softmax" ;
216221 auto * qk_matmul_out = VarNode (" qk_matmul_out" )
217222 ->assert_is_op_output (matmul_type_, " Out" )
218- ->assert_is_op_input (" elementwise_add " , " X" )
223+ ->assert_is_op_input (op_after_qk_matmul , " X" )
219224 ->AsIntermediate ();
220- auto * qk_mask = VarNode (" qk_mask" )
221- ->assert_is_op_input (" elementwise_add" , " Y" )
222- ->AsInput ();
223- auto * qk_add = OpNode (" qk_add" , " elementwise_add" )->AsIntermediate ();
224- auto * qk_add_out = VarNode (" qk_add_out" )
225- ->assert_is_op_output (" elementwise_add" , " Out" )
226- ->assert_is_op_input (" softmax" , " X" )
227- ->AsIntermediate ();
225+ PMNode* qk_mask = nullptr ;
226+ PMNode* qk_add = nullptr ;
227+ PMNode* qk_add_out = nullptr ;
228+ if (with_mask_) {
229+ qk_mask = VarNode (" qk_mask" )
230+ ->assert_is_op_input (" elementwise_add" , " Y" )
231+ ->AsInput ();
232+ qk_add = OpNode (" qk_add" , " elementwise_add" )->AsIntermediate ();
233+ qk_add_out = VarNode (" qk_add_out" )
234+ ->assert_is_op_output (" elementwise_add" , " Out" )
235+ ->assert_is_op_input (" softmax" , " X" )
236+ ->AsIntermediate ();
237+ }
238+
228239 auto * qk_softmax = OpNode (" qk_softmax" , " softmax" )->AsIntermediate ();
229240 auto * qk_softmax_out = VarNode (" qk_softmax_out" )
230241 ->assert_is_op_output (" softmax" , " Out" )
@@ -256,16 +267,16 @@ class XPUSingleEncoderFuser : public FuseBase {
256267 auto * v_transpose2 = OpNode (" v_transpose2" , " transpose2" )->AsIntermediate ();
257268 auto * v_transpose2_out = VarNode (" v_transpose2_out" )
258269 ->assert_is_op_output (" transpose2" , " Out" )
259- ->assert_is_op_input (matmul_type_ , " Y" )
270+ ->assert_is_op_input (matmul2_type_ , " Y" )
260271 ->AsIntermediate ();
261272 auto * v_transpose2_xshape =
262273 VarNode (" v_transpose2_xshape" )
263274 ->assert_is_op_output (" transpose2" , " XShape" )
264275 ->AsIntermediate ();
265276
266- auto * qkv_matmul = OpNode (" qkv_matmul" , matmul_type_ )->AsIntermediate ();
277+ auto * qkv_matmul = OpNode (" qkv_matmul" , matmul2_type_ )->AsIntermediate ();
267278 auto * qkv_matmul_out = VarNode (" qkv_matmul_out" )
268- ->assert_is_op_output (matmul_type_ , " Out" )
279+ ->assert_is_op_output (matmul2_type_ , " Out" )
269280 ->assert_is_op_input (" transpose2" , " X" )
270281 ->AsIntermediate ();
271282 auto * qkv_transpose2 =
@@ -459,9 +470,14 @@ class XPUSingleEncoderFuser : public FuseBase {
459470 *k_reshape2 >> *k_reshape2_xshape;
460471 *k_transpose2 >> *k_transpose2_xshape;
461472
462- *qk_matmul >> *qk_matmul_out >> *qk_add >> *qk_add_out >> *qk_softmax >>
463- *qk_softmax_out >> *qkv_matmul;
464- *qk_mask >> *qk_add;
473+ if (with_mask_) {
474+ *qk_matmul >> *qk_matmul_out >> *qk_add >> *qk_add_out >> *qk_softmax >>
475+ *qk_softmax_out >> *qkv_matmul;
476+ *qk_mask >> *qk_add;
477+ } else {
478+ *qk_matmul >> *qk_matmul_out >> *qk_softmax >> *qk_softmax_out >>
479+ *qkv_matmul;
480+ }
465481
466482 if (norm_before_) {
467483 *ln_before_out >> *v_mul;
@@ -513,7 +529,9 @@ class XPUSingleEncoderFuser : public FuseBase {
513529 cpp::OpDesc op_desc;
514530 op_desc.SetType (" single_encoder" );
515531 op_desc.SetInput (" Inputs" , {matched.at (" input" )->arg ()->name });
516- op_desc.SetInput (" Mask" , {matched.at (" qk_mask" )->arg ()->name });
532+ if (with_mask_) {
533+ op_desc.SetInput (" Mask" , {matched.at (" qk_mask" )->arg ()->name });
534+ }
517535 op_desc.SetInput (" FCWeight" ,
518536 {
519537 matched.at (" q_mul_y" )->arg ()->name ,
@@ -645,7 +663,6 @@ class XPUSingleEncoderFuser : public FuseBase {
645663 single_encoder_stmt->SetOp (fake_subgraph_op);
646664
647665 std::vector<std::string> froms = {
648- " qk_mask" ,
649666 " k_mul_y" ,
650667 " v_mul_y" ,
651668 " qkv_mul_y" ,
@@ -660,6 +677,9 @@ class XPUSingleEncoderFuser : public FuseBase {
660677 " qkv_ln_2_scale" ,
661678 " qkv_ln_2_bias" ,
662679 };
680+ if (with_mask_) {
681+ froms.push_back (" qk_mask" );
682+ }
663683 if (relative_emb_type_ == " __xpu__roformer_relative_embedding" ) {
664684 froms.push_back (" q_cos_embedding" );
665685 froms.push_back (" q_sin_embedding" );
@@ -687,10 +707,12 @@ class XPUSingleEncoderFuser : public FuseBase {
687707 std::string input_pos_;
688708 std::string qkv_ln_2_out_pos_;
689709 std::string matmul_type_;
710+ std::string matmul2_type_;
690711 std::string mul_type_;
691712 bool with_q_scale_;
692713 bool norm_before_;
693714 const std::string relative_emb_type_;
715+ bool with_mask_;
694716 // quant_info: mul input_max, output_max * 6 + matmul x_max:y_max, output_max
695717 // * 2
696718 void set_quant_info (Scope* scope,
@@ -955,7 +977,7 @@ class XPUMultiEncoderFuser {
955977 std::string mask_name;
956978 for (auto * encoder : all_encoders) {
957979 auto * op_info = encoder->stmt ()->op_info ();
958- if (mask_name.empty ()) {
980+ if (mask_name.empty () && op_info-> HasInput ( " Mask " ) ) {
959981 mask_name = op_info->Input (" Mask" ).front ();
960982 } else {
961983 // CHECK(mask_name == op_info->Input("Mask").front());
@@ -1026,13 +1048,11 @@ class XPUMultiEncoderFuser {
10261048 if (all_encoders.size () == 1 ) {
10271049 // take care of only one encoder
10281050 in_name = op_info->Input (" Inputs" ).front ();
1029- mask_name = op_info->Input (" Mask" ).front ();
10301051 out_name = op_info->Output (" Outputs" ).front ();
10311052 } else if (i == 0 ) {
10321053 // first encoder
10331054 to_remove.insert (cur_out);
10341055 in_name = op_info->Input (" Inputs" ).front ();
1035- mask_name = op_info->Input (" Mask" ).front ();
10361056 } else if (i == all_encoders.size () - 1 ) {
10371057 // last encoder
10381058 to_remove.insert (cur_encoder);
@@ -1051,7 +1071,9 @@ class XPUMultiEncoderFuser {
10511071 for (auto kv : arg_map) {
10521072 op_desc.SetInput (kv.first , kv.second );
10531073 }
1054- op_desc.SetInput (" Mask" , {mask_name});
1074+ if (!mask_name.empty ()) {
1075+ op_desc.SetInput (" Mask" , {mask_name});
1076+ }
10551077 op_desc.SetOutput (" Output" , {out_name});
10561078 op_desc.SetAttr <int >(" xpu" , 1 );
10571079 op_desc.SetAttr <int >(
@@ -1382,9 +1404,11 @@ class XPUMultiEncoderFusePass : public ProgramPass {
13821404 std::vector<std::string> input_poss{" X" , " Y" };
13831405 std::vector<std::string> qkv_ln_2_out_poss{" X" , " Y" };
13841406 std::vector<std::string> matmul_types{" matmul" , " matmul_v2" };
1407+ std::vector<std::string> matmul2_types{" matmul" , " matmul_v2" };
13851408 std::vector<std::string> mul_types{" mul" , " matmul" , " matmul_v2" };
13861409 std::vector<bool > with_q_scales{true , false };
13871410 std::vector<bool > norm_befores{true , false };
1411+ std::vector<bool > with_mask{true , false };
13881412 std::vector<std::string> relative_embedding_type{
13891413 " " , " __xpu__roformer_relative_embedding" };
13901414
@@ -1423,23 +1447,29 @@ class XPUMultiEncoderFusePass : public ProgramPass {
14231447 for (auto & input_pos : input_poss) {
14241448 for (auto & qkv_ln_2_out_pos : qkv_ln_2_out_poss) {
14251449 for (auto & matmul_type : matmul_types) {
1426- for (auto & mul_type : mul_types) {
1427- for (auto with_q_scale : with_q_scales) {
1428- for (auto norm_before : norm_befores) {
1429- for (auto relative_type : relative_embedding_type) {
1430- fusion::XPUSingleEncoderFuser single_encoder_fuser (
1431- act_type,
1432- input_pos,
1433- qkv_ln_2_out_pos,
1434- matmul_type,
1435- mul_type,
1436- with_q_scale,
1437- norm_before,
1438- relative_type);
1439- single_encoder_fuser (graph.get ());
1440- fusion::XPUMultiEncoderFuser multi_encoder_fuser (
1441- fc_precision, adaptive_seqlen);
1442- multi_encoder_fuser (graph.get ());
1450+ for (auto & matmul2_type : matmul2_types) {
1451+ for (auto & mul_type : mul_types) {
1452+ for (auto with_q_scale : with_q_scales) {
1453+ for (auto norm_before : norm_befores) {
1454+ for (auto relative_type : relative_embedding_type) {
1455+ for (auto mask : with_mask) {
1456+ fusion::XPUSingleEncoderFuser single_encoder_fuser (
1457+ act_type,
1458+ input_pos,
1459+ qkv_ln_2_out_pos,
1460+ matmul_type,
1461+ matmul2_type,
1462+ mul_type,
1463+ with_q_scale,
1464+ norm_before,
1465+ relative_type,
1466+ mask);
1467+ single_encoder_fuser (graph.get ());
1468+ fusion::XPUMultiEncoderFuser multi_encoder_fuser (
1469+ fc_precision, adaptive_seqlen);
1470+ multi_encoder_fuser (graph.get ());
1471+ }
1472+ }
14431473 }
14441474 }
14451475 }
0 commit comments