@@ -123,23 +123,23 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
                 fuse_qkv_split_out_v_node});
 
   // core attention pattern
+  auto* qk_scale_node =
+      pattern->NewNode(qk_scale_op_repr())->assert_is_op("scale");
+  auto* qk_scale_out_node =
+      pattern->NewNode(qk_scale_out_repr())->assert_is_op_output("scale");
+  fuse_qkv_split_out_q_node->assert_is_op_input("scale", "X");
+  qk_scale_node->LinksFrom({fuse_qkv_split_out_q_node})
+      .LinksTo({qk_scale_out_node});
+
   auto* qk_matmul_node =
       pattern->NewNode(qk_matmul_op_repr())->assert_is_op("matmul_v2");
   auto* qk_matmul_out_node =
       pattern->NewNode(qk_matmul_out_repr())->assert_is_op_output("matmul_v2");
-  fuse_qkv_split_out_q_node->assert_is_op_input("matmul_v2", "X");
+  qk_scale_out_node->assert_is_op_input("matmul_v2", "X");
   fuse_qkv_split_out_k_node->assert_is_op_input("matmul_v2", "Y");
-  qk_matmul_node
-      ->LinksFrom({fuse_qkv_split_out_q_node, fuse_qkv_split_out_k_node})
+  qk_matmul_node->LinksFrom({qk_scale_out_node, fuse_qkv_split_out_k_node})
       .LinksTo({qk_matmul_out_node});
 
-  auto* qk_scale_node =
-      pattern->NewNode(qk_scale_op_repr())->assert_is_op("scale");
-  auto* qk_scale_out_node =
-      pattern->NewNode(qk_scale_out_repr())->assert_is_op_output("scale");
-  qk_matmul_out_node->assert_is_op_input("scale", "X");
-  qk_scale_node->LinksFrom({qk_matmul_out_node}).LinksTo({qk_scale_out_node});
-
   PDNode* add_mask_ele_add_out_node{nullptr};
   if (has_attn_mask) {
     auto* add_mask_ele_add_node = pattern->NewNode(add_mask_ele_add_op_repr())
@@ -149,9 +149,9 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
             ->assert_is_op_input("elementwise_add", "Y");
     add_mask_ele_add_out_node = pattern->NewNode(add_mask_ele_add_out_repr())
                                     ->assert_is_op_output("elementwise_add");
-    qk_scale_out_node->assert_is_op_input("elementwise_add", "X");
+    qk_matmul_out_node->assert_is_op_input("elementwise_add", "X");
     add_mask_ele_add_node
-        ->LinksFrom({qk_scale_out_node, add_mask_ele_add_mask_node})
+        ->LinksFrom({qk_matmul_out_node, add_mask_ele_add_mask_node})
         .LinksTo({add_mask_ele_add_out_node});
   }
 
@@ -164,8 +164,8 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
     qk_softmax_node->LinksFrom({add_mask_ele_add_out_node})
         .LinksTo({qk_softmax_out_node});
   } else {
-    qk_scale_out_node->assert_is_op_input("softmax", "X");
-    qk_softmax_node->LinksFrom({qk_scale_out_node})
+    qk_matmul_out_node->assert_is_op_input("softmax", "X");
+    qk_softmax_node->LinksFrom({qk_matmul_out_node})
         .LinksTo({qk_softmax_out_node});
   }
 
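The three hunks above move the `scale` op in the forward pattern so it is matched on the Q branch before the `matmul_v2`, with the mask add and softmax now consuming the matmul output directly. The reorder is numerically safe because the scale is an elementwise multiply by a constant, which commutes with the QK matmul. Below is a minimal standalone sketch of that equivalence, assuming the matched scale op has bias == 0; the sizes and values are illustrative and not taken from the pass.

// Standalone sketch (not part of the pass): scaling Q before the QK matmul
// yields the same logits as scaling Q*K^T afterwards, assuming the scale op
// has bias == 0, so the pattern may match the ops in either order.
#include <cassert>
#include <cmath>
#include <vector>

int main() {
  const int m = 2, d = 3;
  const float alpha = 1.0f / std::sqrt(static_cast<float>(d));
  std::vector<float> q = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};        // m x d
  std::vector<float> k = {0.5f, -1.f, 2.f, 1.5f, 0.f, -2.f};    // m x d

  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < m; ++j) {
      float scale_first = 0.f, scale_last = 0.f;
      for (int t = 0; t < d; ++t) {
        scale_first += (alpha * q[i * d + t]) * k[j * d + t];  // scale(Q) * K^T
        scale_last += q[i * d + t] * k[j * d + t];             // Q * K^T
      }
      scale_last *= alpha;                                     // scale afterwards
      assert(std::fabs(scale_first - scale_last) < 1e-5f);
    }
  }
  return 0;
}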
@@ -575,16 +575,8 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
         .LinksTo({add_mask_ele_add_grad_x_grad_node});
   }
 
-  PDNode* qk_scale_grad_input_node =
+  PDNode* qk_matmul_grad_input_node =
       has_attn_mask ? add_mask_ele_add_grad_x_grad_node : qk_softmax_grad_out;
-  auto* qk_scale_grad_node =
-      pattern->NewNode(qk_scale_grad_op_repr())->assert_is_op("scale");
-  auto* qk_scale_grad_out_node =
-      pattern->NewNode(qk_scale_grad_out_repr())->assert_is_op_output("scale");
-  qk_scale_grad_input_node->assert_is_op_input("scale", "X");
-  qk_scale_grad_node->LinksFrom({qk_scale_grad_input_node})
-      .LinksTo({qk_scale_grad_out_node});
-
   auto* qk_matmul_grad_node = pattern->NewNode(qk_matmul_grad_op_repr())
                                   ->assert_is_op("matmul_v2_grad");
   auto* qk_matmul_grad_x_node = pattern->NewNode(qk_matmul_grad_x_repr())
@@ -597,24 +589,32 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
   auto* qk_matmul_grad_w_grad_node =
       pattern->NewNode(qk_matmul_grad_w_grad_repr())
           ->assert_is_op_output("matmul_v2_grad", "Y@GRAD");
-  qk_scale_grad_out_node->assert_is_op_input("matmul_v2_grad", "Out@GRAD");
+  qk_matmul_grad_input_node->assert_is_op_input("matmul_v2_grad", "Out@GRAD");
   qk_matmul_grad_node
-      ->LinksFrom({qk_scale_grad_out_node,
+      ->LinksFrom({qk_matmul_grad_input_node,
                    qk_matmul_grad_x_node,
                    qk_matmul_grad_w_node})
      .LinksTo({qk_matmul_grad_x_grad_node, qk_matmul_grad_w_grad_node});
 
+  auto* qk_scale_grad_node =
+      pattern->NewNode(qk_scale_grad_op_repr())->assert_is_op("scale");
+  auto* qk_scale_grad_out_node =
+      pattern->NewNode(qk_scale_grad_out_repr())->assert_is_op_output("scale");
+  qk_matmul_grad_x_grad_node->assert_is_op_input("scale", "X");
+  qk_scale_grad_node->LinksFrom({qk_matmul_grad_x_grad_node})
+      .LinksTo({qk_scale_grad_out_node});
+
   // fuse qkv projection
   auto* fuse_qkv_split_grad_node =
       pattern->NewNode(fuse_qkv_split_grad_op_repr())->assert_is_op("concat");
   auto* fuse_qkv_split_grad_out_node =
       pattern->NewNode(fuse_qkv_split_grad_out_repr())
           ->assert_is_op_output("concat");
-  qk_matmul_grad_x_grad_node->assert_is_op_input("concat");  // q grad
+  qk_scale_grad_out_node->assert_is_op_input("concat");  // q grad
   qk_matmul_grad_w_grad_node->assert_is_op_input("concat");  // k grad
   qkv_matmul_grad_w_grad_node->assert_is_op_input("concat");  // v grad
   fuse_qkv_split_grad_node
-      ->LinksFrom({qk_matmul_grad_x_grad_node,
+      ->LinksFrom({qk_scale_grad_out_node,
                    qk_matmul_grad_w_grad_node,
                    qkv_matmul_grad_w_grad_node})
       .LinksTo({fuse_qkv_split_grad_out_node});
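The two hunks above mirror the same reorder in the gradient pattern: `matmul_v2_grad` is matched first and the `scale` op is matched on its X@GRAD output before feeding the concat, because for S = scale(Q) * K^T the Q gradient is the scale factor applied after the matmul gradient. A small finite-difference check of that identity, with illustrative numbers not taken from the pass:

// Standalone check (not part of the pass): for S = scale(Q) * K^T with factor
// alpha, dL/dQ equals alpha * (dL/dS) * K, i.e. the backward graph runs
// matmul_v2_grad first and the scale op afterwards -- the mirror image of the
// reordered forward pattern.
#include <cassert>
#include <cmath>

int main() {
  const float alpha = 0.5f, eps = 1e-3f;
  const float q[2] = {1.2f, -0.7f}, k[2] = {0.4f, 2.0f};
  const float d_s = 1.0f;  // upstream gradient of the single logit

  // Backward as matched by the pattern: matmul grad first, then scale.
  const float dq_pattern[2] = {alpha * d_s * k[0], alpha * d_s * k[1]};

  // Finite-difference reference on S(q) = alpha * (q . k).
  for (int i = 0; i < 2; ++i) {
    const float q_hi = q[i] + eps, q_lo = q[i] - eps;
    const float s_hi = alpha * (i == 0 ? q_hi * k[0] + q[1] * k[1]
                                       : q[0] * k[0] + q_hi * k[1]);
    const float s_lo = alpha * (i == 0 ? q_lo * k[0] + q[1] * k[1]
                                       : q[0] * k[0] + q_lo * k[1]);
    const float dq_fd = d_s * (s_hi - s_lo) / (2 * eps);
    assert(std::fabs(dq_pattern[i] - dq_fd) < 1e-3f);
  }
  return 0;
}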
@@ -894,7 +894,7 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(
     fused_attention_op_desc.SetAttr("transpose_qkv_wb", true);
     std::vector<int> shape = PADDLE_GET_CONST(
         std::vector<int>, fuse_qkv_reshape_op_node->Op()->GetAttr("shape"));
-    fused_attention_op_desc.SetAttr("num_heads", shape[2]);
+    fused_attention_op_desc.SetAttr("num_heads", shape[2] / 3);
     GET_IR_NODE_FROM_SUBGRAPH(
         fuse_qkv_matmul_out_node, fuse_qkv_matmul_out, fused_attention_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_ele_add_bias_node,
@@ -1337,7 +1337,7 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResBwd(
     std::vector<int> shape =
         PADDLE_GET_CONST(std::vector<int>,
                          fuse_qkv_reshape_grad_op_node->Op()->GetAttr("shape"));
-    fused_attention_grad_op_desc.SetAttr("num_heads", shape[2]);
+    fused_attention_grad_op_desc.SetAttr("num_heads", shape[2] / 3);
     fused_attention_grad_op_desc.SetAttr("pre_layer_norm", true);
     fused_attention_grad_op_desc.SetAttr("transpose_qkv_wb", true);
 
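The last two hunks divide the reshape attribute by three before storing `num_heads`. With `transpose_qkv_wb` set to true the pattern matches the fused QKV projection, whose reshape keeps Q, K and V packed along axis 2, so `shape[2]` holds three times the real head count. A minimal sketch of the fix follows; the concrete reshape target is an illustrative assumption, not a value read from the pass.

// Minimal sketch of the num_heads fix above: when transpose_qkv_wb is true
// the fused QKV reshape keeps Q, K and V packed along axis 2, so the reshape
// attribute stores 3 * num_heads there and the pass must divide it back out.
// The concrete shape below is an illustrative assumption.
#include <cassert>
#include <vector>

int main() {
  const int num_heads = 8;
  // Hypothetical reshape attr: [batch (0 = keep), seq_len (0 = keep),
  //                             3 * num_heads, head_dim]
  std::vector<int> shape = {0, 0, 3 * num_heads, 64};

  // Old behaviour: shape[2] was used directly and over-counted the heads.
  assert(shape[2] == 24);
  // Fixed behaviour: divide by 3 to recover the true head count.
  assert(shape[2] / 3 == num_heads);
  return 0;
}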