Commit 5ff4bbf

fix loop cond for num_split=1 (#5)
1 parent f0edf24 commit 5ff4bbf

1 file changed

Lines changed: 1 addition & 1 deletion

csrc/flash_attn/flash_attn.cpp

@@ -426,7 +426,7 @@ bool flash_attn_bwd(
     // 1) num_splits == 1
     // 2) num_splits == 0 for auto calculation, result to num_splits == 1
     // we do allocation for case 2 for simplicity
-    if (num_splits == 1) {
+    if (num_splits == 1 && !loop) {
       *workspace_size = 0;
     } else {
       *workspace_size = uint64_t(total_q) * num_heads * head_size * sizeof(float);
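
The branch changed in this hunk can be exercised in isolation. Below is a minimal sketch, assuming a hypothetical helper `bwd_workspace_size` that mirrors only the lines shown in the diff. The helper name, the parameter types, and the treatment of `loop` as a plain boolean are assumptions, since the rest of `flash_attn_bwd` is not visible here.

```cpp
#include <cstdint>

// Hypothetical helper (not part of the repository): reproduces only the
// post-commit branch from the hunk so the condition can be checked alone.
// Parameter types are assumed; the real function signature is not shown.
uint64_t bwd_workspace_size(int num_splits, bool loop,
                            int total_q, int num_heads, int head_size) {
    if (num_splits == 1 && !loop) {
        // Single split and no loop: no scratch buffer is requested.
        return 0;
    }
    // Otherwise request one float per (query token, head, head dimension).
    return uint64_t(total_q) * num_heads * head_size * sizeof(float);
}
```

In this sketch, `num_splits == 1` with `loop` set takes the allocation branch, whereas the pre-commit condition `num_splits == 1` alone would have returned 0 for that case.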
