Fix LSE output error in FA2 kv-split #87
Merged
Background
During vLLM inference, some features, such as Cascade Attention, require the log-sum-exp (LSE) output from the attention kernel.
When the FlashAttention-2 kernel operates with `seqlenq_ngroups_swapped = True` (the case for a decode-only batch with GQA), it performs the attention computation with an internal shape of `(b, ngroups, nheads_kv, d)`. After the computation, it restores the `softmax_lse` tensor to the expected layout using a transformation of roughly the following form:
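(A sketch against the PyTorch C++ API, reconstructed from the description here; the helper name is hypothetical and the exact flash_api.cpp code may differ. Note that after the swap, `num_heads` holds `nheads_kv` and `max_seqlen_q` holds `ngroups`, so their product equals the original head count.)

```cpp
#include <ATen/ATen.h>

// Reconstruction sketch (not the verbatim flash_api.cpp code): undo the
// seqlenq_ngroups_swapped layout of the unpadded LSE tensor on the host side.
at::Tensor restore_unpadded_lse(const at::Tensor &softmax_lse,
                                int64_t num_heads, int64_t max_seqlen_q,
                                int64_t batch_size) {
  return softmax_lse
      .reshape({num_heads, batch_size, max_seqlen_q})   // (h_kv, b, ngroups)
      .transpose(1, 2)                                  // (h_kv, ngroups, b)
      .reshape({num_heads * max_seqlen_q, batch_size}); // rows = original heads
}
```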
The Problem with split_kv

The issue arises when the `split_kv` path is triggered in FlashAttention-2's flash_api.cpp. The decision to use this path is made by the `num_splits_heuristic` function. When `split_kv` is enabled, the kernel partitions the K/V tensors along the sequence dimension, stores partial LSE results in `softmax_lse_accum`, and finally launches a `combine_attn_seqk_parallel` kernel to reduce these partial results into the final LSE.

The root cause of the bug lies in the memory layout defined within the `combine_attn_seqk_parallel` kernel. Specifically, when `seqlenq_ngroups_swapped = True` and `unpadded_lse = True`, the layout for the output tensor `gLSE_unpadded` is constructed as follows:
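(What follows is a simplified reconstruction: the real kernel builds a CUTE layout named `final_layout`, and the hypothetical function below only models the index mapping that the description implies.)

```cpp
#include <cstdint>

// Hypothetical model of the combine kernel's output indexing when both
// seqlenq_ngroups_swapped and unpadded_lse are set. For the LSE element of
// (query position q, head h, batch b), final_layout places it at:
//
//   offset = (h * seqlen_q + q) * batch_size + b
//
// i.e. rows are ordered (h, q) and the batch index moves fastest, which is
// exactly the (num_heads * max_seqlen_q, batch_size) row-major layout.
inline int64_t lse_offset_swapped(int64_t q, int64_t h, int64_t b,
                                  int64_t seqlen_q, int64_t batch_size) {
  return (h * seqlen_q + q) * batch_size + b;
}
```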
Because of this `final_layout`, the physical memory layout of `gLSE_unpadded` is already aligned with the desired output shape and can be correctly interpreted with a simple reshape like `softmax_lse.reshape(num_heads * max_seqlen_q, batch_size)`. However, the original code path unconditionally applies the `.transpose(1, 2)` operation, which is incorrect for the `split_kv` case. This erroneous transpose corrupts the LSE layout, leading to a complete precision collapse in downstream operations.
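Accordingly, the fix is to apply the transpose only on the non-split path. A minimal sketch of the corrected logic (hypothetical helper, not the verbatim diff; I assume the chosen split count is available as `num_splits`):

```cpp
#include <ATen/ATen.h>

// Sketch of the fixed restoration logic for the swapped case.
at::Tensor restore_unpadded_lse_fixed(const at::Tensor &softmax_lse,
                                      int64_t num_heads, int64_t max_seqlen_q,
                                      int64_t batch_size, int num_splits) {
  if (num_splits > 1) {
    // combine_attn_seqk_parallel already wrote (h, q)-major rows with the
    // batch index fastest, so a plain reshape is the correct interpretation.
    return softmax_lse.reshape({num_heads * max_seqlen_q, batch_size});
  }
  // Non-split path: the kernel's layout still needs the transpose.
  return softmax_lse.reshape({num_heads, batch_size, max_seqlen_q})
      .transpose(1, 2)
      .reshape({num_heads * max_seqlen_q, batch_size});
}
```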
Reproducibility

This bug is hard to reproduce because it only manifests under a specific combination of conditions:
- `seqlenq_ngroups_swapped = True`, i.e. a GQA model serving a decode-only batch.
- The LSE output is requested in the unpadded layout (`unpadded_lse = True`), e.g. by Cascade Attention.
- `num_splits_heuristic` returns a value greater than 1. This value is sensitive to the GPU's SM count, the batch size, the number of heads, and the K/V sequence length; see the sketch after this list.
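To make that sensitivity concrete, here is a condensed sketch of the split decision (an approximation: the real `num_splits_heuristic` in flash_api.cpp also searches candidate split counts for the best SM occupancy, and the 0.8 fill threshold follows the upstream code):

```cpp
#include <algorithm>

// Condensed sketch of the split decision (approximation, not the real code).
int num_splits_sketch(int batch_nheads_mblocks, int num_SMs,
                      int num_n_blocks, int max_splits) {
  // If the grid of (batch x heads x q-blocks) already nearly fills the SMs,
  // no split is used and the buggy path cannot trigger.
  if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
  // Otherwise some split count in [2, max_splits] is typically chosen,
  // which routes the LSE through combine_attn_seqk_parallel.
  return std::min({max_splits, num_SMs, num_n_blocks});
}
```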
Related Issues

vllm-project/vllm#17580
vllm-project/vllm#17652
vllm-project/vllm#17886
vllm-project/vllm#18345
vllm-project/vllm#22103
I successfully reproduced the issue using the sample program provided in vllm-project/vllm#22103 on an L20 machine, and the changes in this commit fix the problem.