
Commit fcec564

Fused attn pass single ut (#50227)
1 parent 8fb2dce commit fcec564

2 files changed: +71 additions, -57 deletions

paddle/fluid/framework/ir/fused_attention_pass.cc

Lines changed: 29 additions & 29 deletions
@@ -123,23 +123,23 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
                    fuse_qkv_split_out_v_node});
 
   // core attention pattern
+  auto* qk_scale_node =
+      pattern->NewNode(qk_scale_op_repr())->assert_is_op("scale");
+  auto* qk_scale_out_node =
+      pattern->NewNode(qk_scale_out_repr())->assert_is_op_output("scale");
+  fuse_qkv_split_out_q_node->assert_is_op_input("scale", "X");
+  qk_scale_node->LinksFrom({fuse_qkv_split_out_q_node})
+      .LinksTo({qk_scale_out_node});
+
   auto* qk_matmul_node =
       pattern->NewNode(qk_matmul_op_repr())->assert_is_op("matmul_v2");
   auto* qk_matmul_out_node =
       pattern->NewNode(qk_matmul_out_repr())->assert_is_op_output("matmul_v2");
-  fuse_qkv_split_out_q_node->assert_is_op_input("matmul_v2", "X");
+  qk_scale_out_node->assert_is_op_input("matmul_v2", "X");
   fuse_qkv_split_out_k_node->assert_is_op_input("matmul_v2", "Y");
-  qk_matmul_node
-      ->LinksFrom({fuse_qkv_split_out_q_node, fuse_qkv_split_out_k_node})
+  qk_matmul_node->LinksFrom({qk_scale_out_node, fuse_qkv_split_out_k_node})
       .LinksTo({qk_matmul_out_node});
 
-  auto* qk_scale_node =
-      pattern->NewNode(qk_scale_op_repr())->assert_is_op("scale");
-  auto* qk_scale_out_node =
-      pattern->NewNode(qk_scale_out_repr())->assert_is_op_output("scale");
-  qk_matmul_out_node->assert_is_op_input("scale", "X");
-  qk_scale_node->LinksFrom({qk_matmul_out_node}).LinksTo({qk_scale_out_node});
-
   PDNode* add_mask_ele_add_out_node{nullptr};
   if (has_attn_mask) {
     auto* add_mask_ele_add_node = pattern->NewNode(add_mask_ele_add_op_repr())
@@ -149,9 +149,9 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
                                            ->assert_is_op_input("elementwise_add", "Y");
     add_mask_ele_add_out_node = pattern->NewNode(add_mask_ele_add_out_repr())
                                     ->assert_is_op_output("elementwise_add");
-    qk_scale_out_node->assert_is_op_input("elementwise_add", "X");
+    qk_matmul_out_node->assert_is_op_input("elementwise_add", "X");
     add_mask_ele_add_node
-        ->LinksFrom({qk_scale_out_node, add_mask_ele_add_mask_node})
+        ->LinksFrom({qk_matmul_out_node, add_mask_ele_add_mask_node})
         .LinksTo({add_mask_ele_add_out_node});
   }
 
@@ -164,8 +164,8 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
     qk_softmax_node->LinksFrom({add_mask_ele_add_out_node})
         .LinksTo({qk_softmax_out_node});
   } else {
-    qk_scale_out_node->assert_is_op_input("softmax", "X");
-    qk_softmax_node->LinksFrom({qk_scale_out_node})
+    qk_matmul_out_node->assert_is_op_input("softmax", "X");
+    qk_softmax_node->LinksFrom({qk_matmul_out_node})
         .LinksTo({qk_softmax_out_node});
   }
 
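The hunks above change the matched forward subgraph from matmul_v2 -> scale to scale -> matmul_v2: the 1/sqrt(head_dim) scale is now applied to Q before the QK matmul, and the softmax/mask branch consumes the matmul output directly. Because matmul is linear in Q, both orderings give identical attention logits. A minimal numpy sketch (not part of this commit; shapes are arbitrary) illustrating the equivalence:

# Scaling Q before QK^T equals scaling the product afterwards,
# which is why the pattern can be reordered without changing results.
import numpy as np

np.random.seed(0)
head_dim = 64
q = np.random.rand(2, 12, 8, head_dim).astype('float32')
k = np.random.rand(2, 12, 8, head_dim).astype('float32')

scale = head_dim ** -0.5
logits_scale_first = np.matmul(q * scale, k.transpose(0, 1, 3, 2))
logits_scale_last = scale * np.matmul(q, k.transpose(0, 1, 3, 2))
assert np.allclose(logits_scale_first, logits_scale_last, atol=1e-5)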
@@ -575,16 +575,8 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
         .LinksTo({add_mask_ele_add_grad_x_grad_node});
   }
 
-  PDNode* qk_scale_grad_input_node =
+  PDNode* qk_matmul_grad_input_node =
       has_attn_mask ? add_mask_ele_add_grad_x_grad_node : qk_softmax_grad_out;
-  auto* qk_scale_grad_node =
-      pattern->NewNode(qk_scale_grad_op_repr())->assert_is_op("scale");
-  auto* qk_scale_grad_out_node =
-      pattern->NewNode(qk_scale_grad_out_repr())->assert_is_op_output("scale");
-  qk_scale_grad_input_node->assert_is_op_input("scale", "X");
-  qk_scale_grad_node->LinksFrom({qk_scale_grad_input_node})
-      .LinksTo({qk_scale_grad_out_node});
-
   auto* qk_matmul_grad_node = pattern->NewNode(qk_matmul_grad_op_repr())
                                   ->assert_is_op("matmul_v2_grad");
   auto* qk_matmul_grad_x_node = pattern->NewNode(qk_matmul_grad_x_repr())
@@ -597,24 +589,32 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
   auto* qk_matmul_grad_w_grad_node =
       pattern->NewNode(qk_matmul_grad_w_grad_repr())
           ->assert_is_op_output("matmul_v2_grad", "Y@GRAD");
-  qk_scale_grad_out_node->assert_is_op_input("matmul_v2_grad", "Out@GRAD");
+  qk_matmul_grad_input_node->assert_is_op_input("matmul_v2_grad", "Out@GRAD");
   qk_matmul_grad_node
-      ->LinksFrom({qk_scale_grad_out_node,
+      ->LinksFrom({qk_matmul_grad_input_node,
                    qk_matmul_grad_x_node,
                    qk_matmul_grad_w_node})
       .LinksTo({qk_matmul_grad_x_grad_node, qk_matmul_grad_w_grad_node});
 
+  auto* qk_scale_grad_node =
+      pattern->NewNode(qk_scale_grad_op_repr())->assert_is_op("scale");
+  auto* qk_scale_grad_out_node =
+      pattern->NewNode(qk_scale_grad_out_repr())->assert_is_op_output("scale");
+  qk_matmul_grad_x_grad_node->assert_is_op_input("scale", "X");
+  qk_scale_grad_node->LinksFrom({qk_matmul_grad_x_grad_node})
+      .LinksTo({qk_scale_grad_out_node});
+
   // fuse qkv projection
   auto* fuse_qkv_split_grad_node =
       pattern->NewNode(fuse_qkv_split_grad_op_repr())->assert_is_op("concat");
   auto* fuse_qkv_split_grad_out_node =
       pattern->NewNode(fuse_qkv_split_grad_out_repr())
           ->assert_is_op_output("concat");
-  qk_matmul_grad_x_grad_node->assert_is_op_input("concat");  // q grad
+  qk_scale_grad_out_node->assert_is_op_input("concat");  // q grad
   qk_matmul_grad_w_grad_node->assert_is_op_input("concat");  // k grad
   qkv_matmul_grad_w_grad_node->assert_is_op_input("concat");  // v grad
   fuse_qkv_split_grad_node
-      ->LinksFrom({qk_matmul_grad_x_grad_node,
+      ->LinksFrom({qk_scale_grad_out_node,
                    qk_matmul_grad_w_grad_node,
                    qkv_matmul_grad_w_grad_node})
       .LinksTo({fuse_qkv_split_grad_out_node});
@@ -894,7 +894,7 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(
     fused_attention_op_desc.SetAttr("transpose_qkv_wb", true);
     std::vector<int> shape = PADDLE_GET_CONST(
         std::vector<int>, fuse_qkv_reshape_op_node->Op()->GetAttr("shape"));
-    fused_attention_op_desc.SetAttr("num_heads", shape[2]);
+    fused_attention_op_desc.SetAttr("num_heads", shape[2] / 3);
     GET_IR_NODE_FROM_SUBGRAPH(
         fuse_qkv_matmul_out_node, fuse_qkv_matmul_out, fused_attention_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_ele_add_bias_node,
@@ -1337,7 +1337,7 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResBwd(
     std::vector<int> shape =
         PADDLE_GET_CONST(std::vector<int>,
                          fuse_qkv_reshape_grad_op_node->Op()->GetAttr("shape"));
-    fused_attention_grad_op_desc.SetAttr("num_heads", shape[2]);
+    fused_attention_grad_op_desc.SetAttr("num_heads", shape[2] / 3);
     fused_attention_grad_op_desc.SetAttr("pre_layer_norm", true);
     fused_attention_grad_op_desc.SetAttr("transpose_qkv_wb", true);
 
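The num_heads fix in the two hunks above follows from the new reshape layout: the matched fuse_qkv reshape now emits [0, 0, 3 * num_heads, head_dim] (see the unit-test change below), so shape[2] carries 3 * num_heads and the attribute needs the division by 3. A small sketch of the arithmetic, using the sizes from the unit test:

# shape[2] of the reshape is 3 * num_heads, hence shape[2] / 3 in the pass.
hidden_size, num_heads = 768, 12
head_dim = hidden_size // num_heads              # 64
reshape_shape = [0, 0, 3 * num_heads, head_dim]
assert reshape_shape[2] // 3 == num_heads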
python/paddle/fluid/tests/unittests/test_fused_attention_pass.py

Lines changed: 42 additions & 28 deletions
@@ -53,7 +53,7 @@ def __init__(
 
         self.qkv_proj = paddle.nn.Linear(embed_dim, 3 * embed_dim)
         self.out_proj = paddle.nn.Linear(embed_dim, embed_dim)
-        self.dropout = paddle.nn.Dropout(0.1, mode="upscale_in_train")
+        self.dropout = paddle.nn.Dropout(1e-10, mode="upscale_in_train")
 
     def forward(self, x, attn_mask=None):
         residual = x
@@ -64,13 +64,13 @@ def forward(self, x, attn_mask=None):
 
         # compute qkv
         qkv = self.qkv_proj(x)
-        qkv = paddle.reshape(qkv, [0, 0, self.num_heads, 3 * self.head_dim])
+        qkv = paddle.reshape(qkv, [0, 0, 3 * self.num_heads, self.head_dim])
         qkv = paddle.transpose(qkv, [0, 2, 1, 3])
-        q, k, v = paddle.split(qkv, num_or_sections=3, axis=-1)
+        q, k, v = paddle.split(qkv, num_or_sections=3, axis=1)
 
         # compute core attention
+        q = paddle.scale(q, scale=self.head_dim**-0.5)
         product = paddle.matmul(x=q, y=k, transpose_y=True)
-        product = paddle.scale(product, scale=self.head_dim**-0.5)
         if attn_mask is not None:
            product = product + attn_mask
         weights = F.softmax(product)
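The reshape/split change above switches the reference model to the blocked q/k/v layout: after reshaping to [B, S, 3*num_heads, head_dim], transposing, and splitting along the head axis, q, k and v come from the first, second and third thirds of the qkv_proj output columns, which is the layout the pass assumes when it sets transpose_qkv_wb. A minimal numpy sketch (not part of this commit; head count and head size taken from the test, smaller batch/sequence for brevity) of what the new indexing does:

import numpy as np

B, S, H, D = 2, 4, 12, 64
E = H * D
qkv = np.random.rand(B, S, 3 * E).astype('float32')

# new layout: [B, S, 3H, D] -> [B, 3H, S, D] -> split along the head axis
blocked = qkv.reshape(B, S, 3 * H, D).transpose(0, 2, 1, 3)
q, k, v = np.split(blocked, 3, axis=1)        # each [B, H, S, D]

# q is exactly the first E columns of the projection output, per head
q_ref = qkv[..., :E].reshape(B, S, H, D).transpose(0, 2, 1, 3)
assert np.allclose(q, q_ref)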
@@ -104,21 +104,28 @@ def setUp(self):
         self.pre_ln = True
         self.attn_dropout = True
         self.add_mask = True
+        self.x_data = None
+        self.mask_data = None
 
-    def test_pass(self):
+    def get_rst(self, use_pass=False):
         batch_size = 2
         seq_len = 1024
         hidden_size = 768
         num_heads = 12
 
-        x_data = np.random.rand(batch_size, seq_len, seq_len).astype('float32')
-        mask_data = np.random.rand(
-            batch_size, num_heads, seq_len, seq_len
-        ).astype('float32')
+        np.random.seed(1234)
+        if self.x_data is None:
+            self.x_data = np.random.rand(batch_size, seq_len, seq_len).astype(
+                'float32'
+            )
+            self.mask_data = np.random.rand(
+                batch_size, num_heads, seq_len, seq_len
+            ).astype('float32')
 
         main_prog = paddle.static.Program()
         main_prog.random_seed = 1234
         startup_prog = paddle.static.Program()
+        startup_prog.random_seed = 1234
 
         with paddle.static.program_guard(main_prog, startup_prog):
             data = paddle.static.data(
@@ -150,29 +157,36 @@ def test_pass(self):
             sgd_optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.001)
             sgd_optimizer.minimize(loss)
 
-        pass_manager = PassManager([new_pass("fused_attention")])
-        pass_manager.apply([main_prog], [startup_prog])
-
-        ops = main_prog.global_block().ops
-        assert ops[2].type == 'fused_attention'
-        assert ops[3].type == 'reduce_mean'
-        assert ops[5].type == 'reduce_mean_grad'
-        assert ops[6].type == 'fused_attention_grad'
-        # two ops for linear, one op for reduce mean
-        # one fill constant
-        # one op for reduce mean grad, two ops for linear bwd
-        # the eighth op should be the optimizer
-        assert ops[9].type == 'sgd'
+        if use_pass:
+            pass_manager = PassManager([new_pass("fused_attention")])
+            pass_manager.apply([main_prog], [startup_prog])
+
+            ops = main_prog.global_block().ops
+            assert ops[2].type == 'fused_attention'
+            assert ops[3].type == 'reduce_mean'
+            assert ops[5].type == 'reduce_mean_grad'
+            assert ops[6].type == 'fused_attention_grad'
+            # two ops for linear, one op for reduce mean
+            # one fill constant
+            # one op for reduce mean grad, two ops for linear bwd
+            # the eighth op should be the optimizer
+            assert ops[9].type == 'sgd'
 
         exe = paddle.static.Executor()
         exe.run(startup_prog)
-        rst = exe.run(
-            main_prog,
-            feed={'x': x_data, 'attn_mask': mask_data},
-            fetch_list=[loss],
-        )
+        for i in range(2):
+            rst = exe.run(
+                main_prog,
+                feed={'x': self.x_data, 'attn_mask': self.mask_data},
+                fetch_list=[loss],
+            )
+        return rst
+
+    def test_pass(self):
+        fused_rst = self.get_rst(use_pass=True)
+        non_fused_rst = self.get_rst()
+        assert np.allclose(fused_rst, non_fused_rst)
 
 
 if __name__ == "__main__":
-    np.random.seed(0)
     unittest.main()
