
@griii commented Sep 7, 2025

Background

During vLLM inference, some features such as Cascade Attention require the log-sum-exp (LSE) output of the attention kernel.

When the FlashAttention-2 kernel runs with seqlenq_ngroups_swapped = True (the case for a decode-only batch with GQA, where the query-group dimension is folded into the query-sequence dimension), it performs the attention computation with an internal shape of (b, ngroups, nheads_kv, d). After the computation, it restores the softmax_lse tensor to the expected layout using the following transformation:

// In flash_api.cpp (seqlenq_ngroups_swapped path), after the kernel has run:
int64_t lse_size_before[] = {num_heads, batch_size, max_seqlen_q};
int64_t lse_size_after[] = {num_heads * max_seqlen_q, batch_size};
softmax_lse = softmax_lse.reshape(lse_size_before).transpose(1, 2).reshape(lse_size_after);
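For intuition, here is a minimal PyTorch sketch of that transformation on a toy tensor (the sizes are made up; only the (num_heads, batch, seqlen_q) -> (num_heads * seqlen_q, batch) remapping matters):

import torch

# Hypothetical toy sizes, chosen only for illustration.
num_heads, batch_size, max_seqlen_q = 4, 2, 8

# Flat LSE buffer standing in for the kernel output (assumption for illustration).
softmax_lse = torch.randn(num_heads * batch_size * max_seqlen_q)

# Same reshape / transpose / reshape as in the snippet above.
restored = (softmax_lse.reshape(num_heads, batch_size, max_seqlen_q)
                       .transpose(1, 2)
                       .reshape(num_heads * max_seqlen_q, batch_size))
print(restored.shape)  # torch.Size([32, 2])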

The Problem with split_kv

The issue arises when the split_kv path is triggered in FlashAttention-2's flash_api.cpp. The decision to use this path is made by the num_splits_heuristic function. When split_kv is enabled, the kernel partitions the K/V sequence across multiple splits, stores the partial LSE results in softmax_lse_accum, and finally launches the combine_attn_seqk_parallel kernel to reduce these partial results into the final output and LSE (a conceptual sketch of this reduction follows).
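For readers unfamiliar with the split-KV path, here is a conceptual PyTorch sketch of the reduction the combine step performs (this is not the actual CUDA kernel; tensor names and shapes are assumptions for illustration): each split contributes a partial output and a partial LSE over its slice of the keys/values, the final LSE is the logsumexp of the partial LSEs, and the partial outputs are re-weighted accordingly.

import torch

# Hypothetical toy sizes for illustration only.
num_splits, rows, head_dim = 4, 16, 64

out_accum = torch.randn(num_splits, rows, head_dim)  # per-split partial attention outputs
lse_accum = torch.randn(num_splits, rows)            # per-split partial log-sum-exp values

# The final LSE is the logsumexp across splits ...
lse_final = torch.logsumexp(lse_accum, dim=0)                # (rows,)
# ... and each partial output is re-weighted by exp(lse_split - lse_final).
weights = torch.exp(lse_accum - lse_final.unsqueeze(0))      # (num_splits, rows)
out_final = (weights.unsqueeze(-1) * out_accum).sum(dim=0)   # (rows, head_dim)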

The root cause of the bug lies in the memory layout defined within the combine_attn_seqk_parallel kernel. Specifically, when seqlenq_ngroups_swapped = True and unpadded_lse = True, the layout for the output tensor gLSE_unpadded is constructed as follows:

// Note: this code is inside the combine_attn_seqk_parallel kernel.
// When seqlenq_ngroups_swapped is true, unpadded_lse is always true as well.
// (orig_layout and flat_layout are defined earlier in the kernel and omitted here.)

auto transposed_stride = params.seqlenq_ngroups_swapped ? make_stride(params.b, params.seqlen_q * params.b, 1) : make_stride(1, params.seqlen_q * params.b, params.seqlen_q);

Layout remapped_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b), transposed_stride);
Layout final_layout = cute::composition(remapped_layout, cute::composition(orig_layout, flat_layout));

Tensor gLSE_unpadded = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr)), final_layout);

Because of this final_layout, the combine kernel already writes the LSE values in the transposed order: the physical memory behind gLSE_unpadded is laid out exactly as the desired output, and it can be interpreted correctly with a plain reshape such as softmax_lse.reshape({num_heads * max_seqlen_q, batch_size}).
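As a sanity check of that claim (plain Python, illustrative toy sizes), the offset implied by transposed_stride for a logical coordinate (q, h, b) is q*B + h*seqlen_q*B + b, which is exactly the row-major offset of element (h*seqlen_q + q, b) in a (num_heads * max_seqlen_q, batch_size) matrix:

# Hypothetical toy sizes for illustration only.
seqlen_q, num_heads, batch_size = 8, 4, 2

for q in range(seqlen_q):
    for h in range(num_heads):
        for b in range(batch_size):
            # Offset implied by transposed_stride = (B, seqlen_q * B, 1).
            offset_from_stride = q * batch_size + h * seqlen_q * batch_size + b
            # Row-major offset of element (h * seqlen_q + q, b) in a
            # (num_heads * seqlen_q, batch_size) matrix.
            offset_row_major = (h * seqlen_q + q) * batch_size + b
            assert offset_from_stride == offset_row_major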

However, the original code path unconditionally applies the .transpose(1, 2) step shown above, which is wrong for the split_kv case: the extra transpose scrambles the LSE layout, leading to a complete accuracy collapse in the downstream operations that consume the LSE.
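To make the corruption concrete, here is a simplified PyTorch model of it (toy sizes, illustrative only): if the buffer already holds the LSE in (num_heads * max_seqlen_q, batch_size) order, re-applying the reshape/transpose/reshape path from above moves values to the wrong (head, query, batch) positions.

import torch

# Hypothetical toy sizes for illustration only.
num_heads, batch_size, max_seqlen_q = 4, 2, 8

# Stand-in for what the split_kv combine kernel writes: already in the
# final (num_heads * max_seqlen_q, batch_size) order.
correct = torch.arange(num_heads * max_seqlen_q * batch_size, dtype=torch.float32)
correct = correct.reshape(num_heads * max_seqlen_q, batch_size)

# The unconditional post-processing path applies this on top of it:
corrupted = (correct.reshape(num_heads, batch_size, max_seqlen_q)
                    .transpose(1, 2)
                    .reshape(num_heads * max_seqlen_q, batch_size))

print(torch.equal(correct, corrupted))  # False: the values have been scrambled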

Reproducibility

This bug is hard to reproduce because it only manifests under a specific combination of conditions:

  • A feature that consumes the LSE output is in use (e.g., Cascade Attention).
  • The num_splits value computed by num_splits_heuristic is greater than 1. This value is sensitive to:
    • the GPU architecture (e.g., the number of SMs),
    • the batch size,
    • the KV-cache sequence lengths of the inference requests, and
    • the model being served (its attention-head configuration feeds into the heuristic).

Related Issues
vllm-project/vllm#17580
vllm-project/vllm#17652
vllm-project/vllm#17886
vllm-project/vllm#18345
vllm-project/vllm#22103

I successfully reproduced the issue using the sample program provided in vllm-project/vllm#22103 on an L20 machine, and the changes in this commit fix the problem.

@LucasWilkinson (Collaborator)

Amazing! Thank you, this is much appreciated!

@danielhanchen commented Nov 2, 2025

@LucasWilkinson Is it possible to do a vLLM FA2 release on PyPI that includes this PR? That would be extremely helpful for the community, thank you! I think the last PyPI release was 2.6.2 (Sep 5): https://pypi.org/project/vllm-flash-attn

I just noticed vLLM 0.10.2 (Sept 12) actually reports FA2 version 2.7.2.post1, so I'm hoping vLLM 0.10.2 and onwards already package this fix? I found it via import vllm.vllm_flash_attn; print(vllm.vllm_flash_attn.__version__)

Actually, I found out 2.7.2.post1 is also reported for vLLM 0.9.2, so I'm unsure :(

According to https://x.com/vllm_project/status/1985023958371184836, the latest vLLM has the fix. Nice work!

@LucasWilkinson (Collaborator)

@danielhanchen Yes, vllm-flash-attn ships inside vLLM now 👍 (the PyPI wheel is no longer used); this mainly makes it easier to keep torch versions in sync. It's a bit hacky though, so we may go back to a dedicated wheel in the future. We don't actually rev vllm.vllm_flash_attn.__version__, although we potentially should (it just represents the upstream version we are currently based off). The commit hash vLLM is pinned to can be found here: https://github.com/vllm-project/vllm/blob/380ba6816d4646be99d9b6d207ba7bc7fce8290e/cmake/external_projects/vllm_flash_attn.cmake#L41

Hope that helps!
