[Bugfix] Fix KV scales inconsistency in fp8 MLA & FlashInfer kv_cache_dtype "auto" leading to gibberish #37054

Merged
MatthewBonanni merged 12 commits into vllm-project:main from andylolu2:andy/fi-bf16kv-bugfix
Mar 18, 2026

Conversation

@andylolu2 (Contributor) commented Mar 14, 2026

Purpose

This PR fixes the following issues:

  • FlashInfer + kv_cache_dtype "auto" generates gibberish when layer._[qkv]_scale != 1.0.
    • The bug: FlashInfer applies layer._[qkv]_scale unconditionally, even when the QKV values are unscaled bf16.
    • This affects both the normal and MLA attention paths.
  • KV cache scales are not properly handled when using MLA + fp8.
    • In MLA, the KV latents must use the same quantization scale for K and V, so only one of layer._k_scale or layer._v_scale should be used, not both. The current implementation sometimes assumes layer._k_scale is used, other times layer._v_scale or layer._k_scale * layer._v_scale, which is inconsistent and leads to bad generations.
    • In this PR I chose to use only layer._k_scale; layer._v_scale is completely ignored when using MLA.
  • The CUTLASS_MLA backend claims fp8 KV cache support but has no logic to handle the quantization scales, so this PR disables its fp8 support until that is implemented.
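The first fix boils down to gating scale application on the cache dtype. A minimal sketch of the idea (function and parameter names are hypothetical, not vLLM's actual API):

```python
def select_scales(
    q_scale: float, k_scale: float, v_scale: float, kv_cache_dtype: str
) -> tuple[float, float, float]:
    """Sketch of the fix: only forward the checkpoint's scales when the
    KV cache is actually fp8-quantized. Under kv_cache_dtype="auto" the
    cached values are unscaled bf16, so applying non-unit scales would
    rescale already-unscaled values and produce gibberish."""
    if kv_cache_dtype.startswith("fp8"):
        return q_scale, k_scale, v_scale
    # "auto": identity scales, regardless of what the checkpoint carries.
    return 1.0, 1.0, 1.0
```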

To summarize the situation of fp8 MLA scales:

  • q_scale -> Meant for quantizing q_mqa, not q_mha. q_mha currently has no corresponding scale and is naively cast to fp8 if use_fp8_prefill (code reference).
  • k_scale -> Meant for quantizing kv_latents, not k_mha or v_mha. k_mha and v_mha currently have no corresponding scales and are naively cast to fp8 if use_fp8_prefill (code reference).
  • v_scale -> Completely unused.
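To illustrate the single-scale convention the PR adopts, here is a toy roundtrip through the latent cache (scalars stand in for the fp8 tensors; fp8 rounding is elided and the names are hypothetical):

```python
def quantize_latent(x: float, k_scale: float) -> float:
    # Write path: the joint KV latent is divided by the single k_scale
    # before being stored in the fp8 cache.
    return x / k_scale

def dequantize_latent(cached: float, k_scale: float) -> float:
    # Read path: the SAME single scale is applied on the way out;
    # v_scale never enters the computation under this convention.
    return cached * k_scale
```

Because K and V share one latent cache, using two different scales on the write and read paths (e.g. k_scale in, v_scale out) would silently distort every cached value.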

Test Plan

The current tests pass only because the [qkv]_scales are mocked to 1.0, which silently masks this bug. Updated the tests to remove the assumption that [qkv]_scales are 1.0.

To assert that layer._v_scale is not used in MLA, I set it to NaN in the MLA tests, ensuring wrong results if it is ever used.
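The NaN-poisoning trick works because NaN propagates through arithmetic: any code path that multiplies by the poisoned v_scale taints the output. A small self-contained illustration (the attention functions are hypothetical stand-ins, not the real kernels):

```python
import math

def v_scale_is_unused(attention_fn, k_scale: float) -> bool:
    # Poison v_scale with NaN: if the implementation ever multiplies
    # by it, NaN propagates to the output and this check fails.
    out = attention_fn(k_scale=k_scale, v_scale=math.nan)
    return not math.isnan(out)

# A correct MLA path uses only k_scale; a buggy one touches v_scale.
correct = lambda k_scale, v_scale: 2.0 * k_scale
buggy = lambda k_scale, v_scale: 2.0 * k_scale * v_scale
```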

Test Result

Updated tests are passing.

@mergify mergify bot added nvidia v1 bug Something isn't working labels Mar 14, 2026
@gemini-code-assist bot left a comment


Code Review

This pull request addresses a critical bug where FlashInfer attention backends incorrectly applied quantization scales even when the KV cache was not FP8-quantized, resulting in corrupted output. The fix correctly restricts the application of these scales to cases where the KV cache data type is indeed FP8. The accompanying test modifications, which set mock layer scales to non-unity values, are appropriate for verifying the fix. The changes appear correct and effectively resolve the issue.

Signed-off-by: Andy Lo <andy@mistral.ai>
@andylolu2 andylolu2 force-pushed the andy/fi-bf16kv-bugfix branch from 223a57d to 6b89f72 Compare March 14, 2026 16:16
@andylolu2 andylolu2 force-pushed the andy/fi-bf16kv-bugfix branch from 0da3fd3 to 4787ef3 Compare March 14, 2026 16:20
@andylolu2 andylolu2 changed the title [Bugfix] FlashInfer kv_cache_dtype "auto" generates giberrish when layer._[qkv]_scale != 1.0 [Bugfix] Fix KV scales in fp8 MLA & FlashInfer kv_cache_dtype "auto" Mar 14, 2026
mergify bot commented Mar 14, 2026

Documentation preview: https://vllm--37054.org.readthedocs.build/en/37054/

@andylolu2 andylolu2 changed the title [Bugfix] Fix KV scales in fp8 MLA & FlashInfer kv_cache_dtype "auto" [Bugfix] Fix KV scales inconsistency in fp8 MLA & FlashInfer kv_cache_dtype "auto" Mar 14, 2026
@mergify mergify bot added the documentation Improvements or additions to documentation label Mar 14, 2026
@andylolu2 andylolu2 changed the title [Bugfix] Fix KV scales inconsistency in fp8 MLA & FlashInfer kv_cache_dtype "auto" [Bugfix] Fix KV scales inconsistency in fp8 MLA & FlashInfer kv_cache_dtype "auto" leading to gibberish Mar 14, 2026
@andylolu2 andylolu2 marked this pull request as ready for review March 14, 2026 16:33
@andylolu2 (Contributor, Author) commented:

@gemini review

@gemini-code-assist bot left a comment


Code Review

This pull request addresses several correctness issues related to FP8 quantization scales in FlashInfer and MLA backends. The changes correctly handle scales for decode paths and disable FP8 support for the broken CUTLASS MLA backend. However, the fix is incomplete as the MLA prefill path for FlashInfer and Triton backends still lacks proper FP8 scale handling, which is a critical issue. I've added a comment with details on the missing fix.

mergify bot commented Mar 16, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @andylolu2.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 16, 2026
@mergify mergify bot removed the needs-rebase label Mar 17, 2026
@MatthewBonanni (Collaborator) left a comment


Thanks for the fix! Just some small comments

@andylolu2 andylolu2 force-pushed the andy/fi-bf16kv-bugfix branch from 22131c1 to c5a76dd Compare March 18, 2026 18:09
@MatthewBonanni (Collaborator) left a comment


LGTM, thanks for the fix and thanks for improving the test coverage!

@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Mar 18, 2026
@MatthewBonanni MatthewBonanni added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 18, 2026
@MatthewBonanni MatthewBonanni enabled auto-merge (squash) March 18, 2026 20:14
@MatthewBonanni MatthewBonanni merged commit 577df69 into vllm-project:main Mar 18, 2026
58 checks passed
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Mar 18, 2026
This PR was subsequently picked into forks (ikaadil, fxdawnn, SouthWest7, khairulkabir1661, Monishver11, JiantaoXu, vrdn-23, EricccYang, liuchenbing2026) between Mar 19 and Apr 4, 2026, each commit referencing vllm-project#37054.
