1 parent f0edf24 commit 5ff4bbf
1 file changed
csrc/flash_attn/flash_attn.cpp
@@ -426,7 +426,7 @@ bool flash_attn_bwd(
   // 1) num_splits == 1
   // 2) num_splits == 0 for auto calculation, result to num_splits == 1
   // we do allocation for case 2 for simplicity
-  if (num_splits == 1) {
+  if (num_splits == 1 && !loop) {
     *workspace_size = 0;
   } else {
     *workspace_size = uint64_t(total_q) * num_heads * head_size * sizeof(float);
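
The patched condition can be read in isolation as a small sizing rule. The sketch below is a standalone restatement of it, not the code in `flash_attn.cpp`: the function name `bwd_workspace_size` and its parameter types are hypothetical, and `loop` is treated as an opaque flag because its definition is not part of this hunk. It shows that after the change a workspace is skipped only when `num_splits == 1` and `loop` is false; every other combination reserves one `float` per element of a `total_q x num_heads x head_size` buffer.

```cpp
#include <cstdint>

// Hypothetical helper (not part of the original file): restates the
// workspace sizing rule from the patched hunk as a pure function.
// `loop` is an opaque flag here; its meaning is defined elsewhere in
// flash_attn_bwd and is not visible in this diff.
uint64_t bwd_workspace_size(int num_splits, bool loop,
                            int total_q, int num_heads, int head_size) {
    if (num_splits == 1 && !loop) {
        // Single split and no looping: no float scratch buffer requested.
        return 0;
    }
    // Casting total_q to uint64_t first makes the whole product 64-bit,
    // so large shapes do not overflow 32-bit int arithmetic.
    return uint64_t(total_q) * num_heads * head_size * sizeof(float);
}
```

Under these assumptions, a call such as `bwd_workspace_size(1, true, 1024, 16, 64)` now requests a buffer, where the pre-patch condition would have returned 0 for the same arguments.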