
Conversation

@benchislett benchislett commented Sep 30, 2025

Purpose

This PR refactors the MLACommonMetadataBuilder to easily support spec decode kernel optimization in MLA implementations. This is used to enable FlashInfer-MLA support using the trtllm-gen kernels which have explicit support for spec-as-decode.

Test Plan

I ran a suite of evals over nvidia/DeepSeek-R1-FP4 and deepseek-ai/DeepSeek-R1-0528 on 4xB200 and 8xB200 respectively, using Cutlass-MLA and FlashInfer-MLA backends. Running MTP with FP4 on B200 requires the fix in #25987.

lm_eval \
  --model local-completions \
  --tasks gsm8k \
  --model_args base_url=http://0.0.0.0:8049/v1/completions,model=nvidia/DeepSeek-R1-FP4,tokenized_requests=False,tokenizer_backend=None,num_concurrent=128,timeout=120,max_retries=5

Known issues

The Cutlass-MLA backend produces incorrect output when using speculative decoding. It is not clear to me why this happens; I have debugged with enforce-eager but did not identify any issues beyond the incorrect model output. I have not verified whether this also occurs on H200, but I believe FLASH_ATTN_MLA is also an option on Hopper, so it may be sufficient to deprecate Cutlass-MLA when speculative decoding is enabled.

See #26042 for tracking on this correctness issue, which seems to indicate the root cause is MLA chunked prefill.

The fix is in #26063. I will rerun the experiments for a better baseline, but the correctness of this branch for MTP is still valid.

Test Result

4xB200 nvidia/DeepSeek-R1-FP4 FlashInfer-MLA MTP=3

VLLM_ATTENTION_BACKEND=FLASHINFER_MLA VLLM_USE_FLASHINFER_MOE_FP4=1 vllm serve nvidia/DeepSeek-R1-FP4 -tp 4 --max-model-len 8192 --no-enable-prefix-caching --port 8049 --speculative-config '{"method": "mtp", "num_speculative_tokens": 3}'

Finish: 1319/1319 [01:07<00:00, 19.62it/s]

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9500|±  | 0.006|
|     |       |strict-match    |     5|exact_match|↑  |0.9507|±  | 0.006|

4xB200 nvidia/DeepSeek-R1-FP4 Cutlass-MLA MTP=3

VLLM_ATTENTION_BACKEND=CUTLASS_MLA VLLM_USE_FLASHINFER_MOE_FP4=1 vllm serve nvidia/DeepSeek-R1-FP4 -tp 4 --max-model-len 8192 --no-enable-prefix-caching --port 8049 --speculative-config '{"method": "mtp", "num_speculative_tokens": 3}'

FAILED. Also fails with --enforce-eager
Finish: 1319/1319 [05:55<00:00,  3.71it/s]

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |    0|±  |     0|
|     |       |strict-match    |     5|exact_match|↑  |    0|±  |     0|

4xB200 nvidia/DeepSeek-R1-FP4 FlashInfer-MLA No-Spec

VLLM_ATTENTION_BACKEND=FLASHINFER_MLA VLLM_USE_FLASHINFER_MOE_FP4=1 vllm serve nvidia/DeepSeek-R1-FP4 -tp 4 --max-model-len 8192 --no-enable-prefix-caching --port 8049

Finish: 1319/1319 [01:30<00:00, 14.62it/s]

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9447|±  |0.0063|
|     |       |strict-match    |     5|exact_match|↑  |0.9454|±  |0.0063|

4xB200 nvidia/DeepSeek-R1-FP4 Cutlass-MLA No-Spec

VLLM_ATTENTION_BACKEND=CUTLASS_MLA VLLM_USE_FLASHINFER_MOE_FP4=1 vllm serve nvidia/DeepSeek-R1-FP4 -tp 4 --max-model-len 8192 --no-enable-prefix-caching --port 8049

Finish: 1319/1319 [01:29<00:00, 14.70it/s]

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9492|±  |0.0060|
|     |       |strict-match    |     5|exact_match|↑  |0.9477|±  |0.0061|

8xB200 deepseek-ai/DeepSeek-R1-0528 FlashInfer-MLA MTP=3

VLLM_ATTENTION_BACKEND=FLASHINFER_MLA vllm serve deepseek-ai/DeepSeek-R1-0528 -tp 8 --max-model-len 8192 --no-enable-prefix-caching --port 8049 --speculative-config '{"method": "mtp", "num_speculative_tokens": 3}'

Finish: 1319/1319 [01:24<00:00, 15.67it/s]

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9568|±  |0.0056|
|     |       |strict-match    |     5|exact_match|↑  |0.9530|±  |0.0058|

8xB200 deepseek-ai/DeepSeek-R1-0528 Cutlass-MLA MTP=3

VLLM_ATTENTION_BACKEND=CUTLASS_MLA vllm serve deepseek-ai/DeepSeek-R1-0528 -tp 8 --max-model-len 8192 --no-enable-prefix-caching --port 8049 --speculative-config '{"method": "mtp", "num_speculative_tokens": 3}'

FAIL

8xB200 deepseek-ai/DeepSeek-R1-0528 FlashInfer-MLA No-Spec

VLLM_ATTENTION_BACKEND=FLASHINFER_MLA vllm serve deepseek-ai/DeepSeek-R1-0528 -tp 8 --max-model-len 8192 --no-enable-prefix-caching --port 8049

Finish: 1319/1319 [01:44<00:00, 12.62it/s]

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9545|±  |0.0057|
|     |       |strict-match    |     5|exact_match|↑  |0.9522|±  |0.0059|

8xB200 deepseek-ai/DeepSeek-R1-0528 Cutlass-MLA No-Spec

VLLM_ATTENTION_BACKEND=CUTLASS_MLA vllm serve deepseek-ai/DeepSeek-R1-0528 -tp 8 --max-model-len 8192 --no-enable-prefix-caching --port 8049

Finish: 1319/1319 [01:42<00:00, 12.86it/s]

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9560|±  |0.0056|
|     |       |strict-match    |     5|exact_match|↑  |0.9538|±  |0.0058|

Signed-off-by: Benjamin Chislett <[email protected]>
@mergify mergify bot added the v1 label Sep 30, 2025
@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request refactors the MLA backend to support speculative decoding with FlashInfer, which is a great improvement. The changes are mostly well-structured. However, I found a critical issue in the fallback logic for handling non-uniform query lengths in FlashInferMLAImpl, which could lead to a runtime error. My review includes a suggestion to fix this.

@benchislett
Collaborator Author

Update: the failed baseline is most likely due to an unknown bug in MLA chunked prefill logic. See #26042

# When `reorder_batch_threshold > 1`, any decode requests which do not
# have the same query length as the first decode request will
# fall back to the prefill kernel.
supports_nonuniform_decode: ClassVar[bool] = False
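
The fallback this flag controls can be sketched roughly as follows; the function name and shape are illustrative assumptions, not vLLM's actual interface. Given a decode-first batch, requests past the threshold, or requests whose query length breaks uniformity when uniformity is required, are treated as prefills.

```python
def split_uniform_decodes(query_lens, threshold, require_uniform):
    """Return (num_decodes, num_prefills) for a batch sorted decode-first.

    A request counts as a decode if its query length is <= threshold and,
    when require_uniform is set, matches the first decode's query length.
    """
    num_decodes = 0
    first_decode_len = None
    for q_len in query_lens:
        if q_len > threshold:
            break  # prefill region begins (batch is reordered decode-first)
        if require_uniform:
            if first_decode_len is None:
                first_decode_len = q_len
            elif q_len != first_decode_len:
                break  # nonuniform decode falls back to the prefill kernel
        num_decodes += 1
    return num_decodes, len(query_lens) - num_decodes
```

For example, with threshold 4 and uniformity required, a batch with query lengths [4, 2, 4] keeps only the first request as a decode; without the uniformity requirement all three qualify.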
Collaborator

nit: is this needed if it's always set to False? (I think we should set this for FlashAttnMLA, since it does support nonuniform decode.)

Collaborator

I think we can maybe just unify supports_spec_as_decode and supports_nonuniform_decode into supports_only_uniform_spec_decode, and when that's False we just leave reorder_batch_threshold untouched and set require_uniform = False

Collaborator Author

@LucasWilkinson I'm pretty sure there can be a full matrix of options here, and that different combinations are useful. For example:

  • supports_spec_as_decode and supports_nonuniform_decode: FlashAttnMLA, where require_uniform=False is correct (it can handle varlen), and the long reorder_batch_threshold allows it to handle spec requests.
  • supports_spec_as_decode and not supports_nonuniform_decode, where require_uniform=True is required to function correctly, but reorder_batch_threshold can be overridden to 1 + num_spec_tokens to handle spec decoding.
  • not supports_spec_as_decode and not supports_nonuniform_decode is the default for the backends which require q_len == 1.
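
A rough sketch of how those three combinations could map to (reorder_batch_threshold, require_uniform); the function name and the exact resolution logic here are illustrative assumptions, not the merged implementation:

```python
def resolve_decode_config(supports_spec_as_decode: bool,
                          supports_nonuniform_decode: bool,
                          base_threshold: int,
                          num_spec_tokens: int):
    """Return (reorder_batch_threshold, require_uniform) for a backend."""
    if supports_spec_as_decode and supports_nonuniform_decode:
        # e.g. FlashAttnMLA: varlen decode is fine, keep its long threshold.
        return base_threshold, False
    if supports_spec_as_decode:
        # Uniform-only spec decode: widen the threshold to cover the spec
        # tokens, but require all decode query lengths to match.
        return 1 + num_spec_tokens, True
    # Default: decode strictly means q_len == 1.
    return 1, True
```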

Collaborator Author

I will update FlashAttnMLA to reflect the correct defaults, but I don't know how to support each of these 3 cases cleanly with only a single flag. Let me know if you would still prefer a different interface.

Collaborator

I think for the FlashAttnMLA case the reorder threshold is already high enough that we don't need to adjust reorder_batch_threshold when spec decoding is turned on; my suspicion is that if a backend supports_nonuniform_decode we should just set reorder_batch_threshold >= ~8 so that we capture spec decode naturally (FlashAttnMLA is really the only example of this currently)

I think if a backend comes along that supports_nonuniform_decode but also benefits from dynamically adjusting reorder_batch_threshold, we could add this flag back; but it just seems like unnecessary complexity currently (imo)


@LucasWilkinson LucasWilkinson left a comment

@mgoin mgoin added speculative-decoding ready ONLY add when PR is ready to merge/full CI is needed labels Oct 6, 2025

mergify bot commented Oct 6, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @benchislett.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 6, 2025
Signed-off-by: Benjamin Chislett <[email protected]>
@mergify mergify bot removed the needs-rebase label Oct 7, 2025
@benchislett benchislett merged commit 3d1f676 into vllm-project:main Oct 7, 2025
47 checks passed
@benchislett benchislett deleted the flashinfer-mla-spec-compat branch October 7, 2025 20:06
mrasquinha-g pushed a commit to mrasquinha-g/vllm that referenced this pull request Oct 9, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
Dhruvilbhatt pushed a commit to Dhruvilbhatt/vllm that referenced this pull request Oct 14, 2025
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
alhridoy pushed a commit to alhridoy/vllm that referenced this pull request Oct 24, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025