[Spec Decode] Enable efficient speculative decoding with FlashInfer-MLA #25984
Conversation
Signed-off-by: Benjamin Chislett <[email protected]>
Code Review
This pull request refactors the MLA backend to support speculative decoding with FlashInfer, which is a great improvement. The changes are mostly well-structured. However, I found a critical issue in the fallback logic for handling non-uniform query lengths in FlashInferMLAImpl, which could lead to a runtime error. My review includes a suggestion to fix this.
Update: the failed baseline is most likely due to an unknown bug in the MLA chunked-prefill logic. See #26042
```python
# `reorder_batch_threshold > 1`, any decode requests which do not
# have the same query length as the first decode request will
# fall back to the prefill kernel.
supports_nonuniform_decode: ClassVar[bool] = False
```
nit: is this needed if it's always set to False? (I think we should set this for FlashAttnMLA, since it does support `supports_nonuniform_decode`)
I think we can actually just unify `supports_spec_as_decode` and `supports_nonuniform_decode` into `supports_only_uniform_spec_decode`; when that's False, we just leave `reorder_batch_threshold` untouched and set `require_uniform = False`.
@LucasWilkinson I'm pretty sure there can be a full matrix of options here, and that different combinations are useful. For example:
- `supports_spec_as_decode and supports_nonuniform_decode`: FlashAttnMLA, where `require_uniform=False` is correct (it can handle varlen), and the long `reorder_batch_threshold` allows it to handle spec requests.
- `supports_spec_as_decode and not supports_nonuniform_decode`: where `require_uniform=True` is required to function correctly, but `reorder_batch_threshold` can be overridden to `1 + num_spec_tokens` to handle spec decoding.
- `not supports_spec_as_decode and not supports_nonuniform_decode`: the default for the backends which require `q_len == 1`.
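The three combinations above could be sketched as follows. This is a hypothetical illustration, not the actual vLLM interface: `BackendCaps` and `resolve` are invented names, and the derived `(reorder_batch_threshold, require_uniform)` pairs just mirror the discussion.

```python
# Hypothetical sketch of how the two capability flags discussed above
# could map to builder settings; names are illustrative, not vLLM's API.
from dataclasses import dataclass


@dataclass
class BackendCaps:
    supports_spec_as_decode: bool
    supports_nonuniform_decode: bool


def resolve(caps: BackendCaps, num_spec_tokens: int) -> tuple[int, bool]:
    """Return (reorder_batch_threshold, require_uniform)."""
    if caps.supports_nonuniform_decode:
        # e.g. FlashAttnMLA: varlen decode is fine, so no uniformity needed;
        # the threshold still covers spec requests.
        return 1 + num_spec_tokens, False
    if caps.supports_spec_as_decode:
        # Uniform-only spec decode: bump the threshold but require that all
        # decode requests share the same query length.
        return 1 + num_spec_tokens, True
    # Default: only q_len == 1 is treated as decode.
    return 1, False


print(resolve(BackendCaps(True, True), 3))    # (4, False)
print(resolve(BackendCaps(True, False), 3))   # (4, True)
print(resolve(BackendCaps(False, False), 3))  # (1, False)
```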
I will update FlashAttnMLA to reflect the correct defaults, but I don't know how to support each of these 3 cases cleanly with only a single flag. Let me know if you would still prefer a different interface.
I think for the FlashAttnMLA case the reorder threshold is already high enough that we don't need to adjust `reorder_batch_threshold` when spec decoding is turned on; my suspicion is that if a backend `supports_nonuniform_decode`, we should just set `reorder_batch_threshold >= 8` or so, so that we capture spec decode naturally (FlashAttnMLA is really the only example of this currently).
I think if a backend that `supports_nonuniform_decode` but also benefits from dynamically adjusting `reorder_batch_threshold` comes along, then we could add this flag back; but it just seems like unnecessary complexity currently (imo)
LucasWilkinson
left a comment
Overall looks good to me 👍 left one follow-up: https://github.com/vllm-project/vllm/pull/25984/files#r2408516263
This pull request has merge conflicts that must be resolved before it can be merged.
…LA (vllm-project#25984) Signed-off-by: Benjamin Chislett <[email protected]>
Purpose
This PR refactors the `MLACommonMetadataBuilder` to easily support spec-decode kernel optimization in MLA implementations. This is used to enable FlashInfer-MLA support using the trtllm-gen kernels, which have explicit support for spec-as-decode.
Test Plan
I ran a suite of evals over `nvidia/DeepSeek-R1-FP4` and `deepseek-ai/DeepSeek-R1-0528` on 4xB200 and 8xB200 respectively, using the `Cutlass-MLA` and `FlashInfer-MLA` backends. Running MTP with FP4 on B200 requires the fix in #25987.
Known issues
The `Cutlass-MLA` backend produces incorrect output when using speculative decoding. It is not clear to me why this happens; I have debugged with enforce-eager but did not identify any issues other than incorrect model output. I have not verified whether this also occurs on H200, but I believe `FLASH_ATTN_MLA` is also an option on Hopper, so it may be sufficient to deprecate `Cutlass-MLA` when speculative decoding is enabled. See #26042 for tracking of this correctness issue, which seems to indicate the root cause is MLA chunked prefill.
The fix is in #26063. I will rerun the experiments for a better baseline, but the correctness of this branch for MTP is still valid.
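For context on the review discussion above, the non-uniform-decode fallback can be sketched roughly as follows. This is a simplified, hypothetical helper (`split_decode_prefill` is an invented name; the real logic lives in the `MLACommonMetadataBuilder` refactor): decode requests whose query length differs from the first decode request are routed to the prefill path on backends that only support uniform decode.

```python
# Simplified, hypothetical illustration of the fallback discussed in the
# review: on a uniform-only backend, any decode request whose query length
# differs from the first decode request falls back to the prefill kernel.
def split_decode_prefill(
    query_lens: list[int], supports_nonuniform: bool
) -> tuple[list[int], list[int]]:
    """Return (decode_indices, prefill_indices) for a batch of requests."""
    if not query_lens:
        return [], []
    uniform_len = query_lens[0]
    decode, prefill = [], []
    for i, qlen in enumerate(query_lens):
        if supports_nonuniform or qlen == uniform_len:
            decode.append(i)
        else:
            prefill.append(i)
    return decode, prefill


# With MTP=3, a fully accepted spec request has q_len == 4; a request with
# fewer accepted tokens (q_len == 2 here) falls back on a uniform-only backend.
print(split_decode_prefill([4, 4, 2, 4], supports_nonuniform=False))
# ([0, 1, 3], [2])
```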
Test Result
4xB200 nvidia/DeepSeek-R1-FP4 FlashInfer-MLA MTP=3
4xB200 nvidia/DeepSeek-R1-FP4 Cutlass-MLA MTP=3
4xB200 nvidia/DeepSeek-R1-FP4 FlashInfer-MLA No-Spec
4xB200 nvidia/DeepSeek-R1-FP4 Cutlass-MLA No-Spec
8xB200 deepseek-ai/DeepSeek-R1-0528 FlashInfer-MLA MTP=3
8xB200 deepseek-ai/DeepSeek-R1-0528 Cutlass-MLA MTP=3
8xB200 deepseek-ai/DeepSeek-R1-0528 FlashInfer-MLA No-Spec
8xB200 deepseek-ai/DeepSeek-R1-0528 Cutlass-MLA No-Spec