Skip to content

Commit badd90a

Browse files
committed
fix GQA bug
1 parent 114c152 commit badd90a

1 file changed

Lines changed: 5 additions & 5 deletions

File tree

paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
5959
const int64_t num_heads = dims[1];
6060
const int64_t head_size_og = dout.dims()[2];
6161
const int64_t head_size = dims[2];
62-
const int64_t total_k = k.dims[0];
62+
const int64_t total_k = k.dims()[0];
6363
const int64_t num_heads_k = k.dims()[1];
6464

6565
bool is_mha = (num_heads == num_heads_k);
@@ -80,7 +80,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
8080
total_k, num_heads_k, num_heads / num_heads_k, head_size};
8181

8282
DenseTensor dk_tmp;
83-
if (dk) {
83+
if (dk && is_mha) {
8484
ctx.template Alloc<T>(dk);
8585
dk_ptr = dk->data();
8686
} else {
@@ -89,7 +89,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
8989
}
9090

9191
DenseTensor dv_tmp;
92-
if (dv) {
92+
if (dv && is_mha) {
9393
ctx.template Alloc<T>(dv);
9494
dv_ptr = dv->data();
9595
} else {
@@ -219,7 +219,7 @@ void FlashAttnGradKernel(const Context& ctx,
219219
DenseTensor dk_tmp;
220220
std::initializer_list<int64_t> dk_dv_shape = {
221221
batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size};
222-
if (dk) {
222+
if (dk && is_mha) {
223223
ctx.template Alloc<T>(dk);
224224
dk_ptr = dk->data();
225225
} else {
@@ -228,7 +228,7 @@ void FlashAttnGradKernel(const Context& ctx,
228228
}
229229

230230
DenseTensor dv_tmp;
231-
if (dv) {
231+
if (dv && is_mha) {
232232
ctx.template Alloc<T>(dv);
233233
dv_ptr = dv->data();
234234
} else {

0 commit comments

Comments
 (0)