@@ -59,7 +59,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
5959 const int64_t num_heads = dims[1 ];
6060 const int64_t head_size_og = dout.dims ()[2 ];
6161 const int64_t head_size = dims[2 ];
62- const int64_t total_k = k.dims [0 ];
62+ const int64_t total_k = k.dims () [0 ];
6363 const int64_t num_heads_k = k.dims ()[1 ];
6464
6565 bool is_mha = (num_heads == num_heads_k);
@@ -80,7 +80,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
8080 total_k, num_heads_k, num_heads / num_heads_k, head_size};
8181
8282 DenseTensor dk_tmp;
83- if (dk) {
83+ if (dk && is_mha ) {
8484 ctx.template Alloc <T>(dk);
8585 dk_ptr = dk->data ();
8686 } else {
@@ -89,7 +89,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
8989 }
9090
9191 DenseTensor dv_tmp;
92- if (dv) {
92+ if (dv && is_mha ) {
9393 ctx.template Alloc <T>(dv);
9494 dv_ptr = dv->data ();
9595 } else {
@@ -219,7 +219,7 @@ void FlashAttnGradKernel(const Context& ctx,
219219 DenseTensor dk_tmp;
220220 std::initializer_list<int64_t > dk_dv_shape = {
221221 batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size};
222- if (dk) {
222+ if (dk && is_mha ) {
223223 ctx.template Alloc <T>(dk);
224224 dk_ptr = dk->data ();
225225 } else {
@@ -228,7 +228,7 @@ void FlashAttnGradKernel(const Context& ctx,
228228 }
229229
230230 DenseTensor dv_tmp;
231- if (dv) {
231+ if (dv && is_mha ) {
232232 ctx.template Alloc <T>(dv);
233233 dv_ptr = dv->data ();
234234 } else {
0 commit comments