[Perf] Refactor cudagraph_support to enable full CUDA graphs for spec decoding with FlashInfer #28479
Conversation
Signed-off-by: Benjamin Chislett <[email protected]>
Documentation preview: https://vllm--28479.org.readthedocs.build/en/28479/
Code Review
This pull request introduces a well-designed refactoring to enable more flexible and dynamic CUDA graph support for attention backends. By making `cudagraph_support` a private member and introducing a new `get_cudagraph_support` method, the code now determines CUDA graph capability dynamically on a per-backend, per-KV-group basis. This change is crucial for enabling full CUDA graph support for speculative decoding with FlashInfer on specific hardware like Blackwell. The updates to `_check_and_update_cudagraph_mode` and the corresponding documentation changes are clear and correct. Overall, this is a solid performance enhancement with a clean implementation.
LucasWilkinson
left a comment
LGTM
I'd like to work towards reverting #27427 (and moving back to this being an instance property) in the future, but we need broader cudagraph refactors to get there.
vadiklyutiy
left a comment
Previously we used `use_trtllm_attention` to check both prefill and decode. Right now it seems `use_trtllm_attention` is used for checking prefill only, alongside `can_use_trtllm_attention`.
Could we refactor to:
- use proper names like `use_trtllm_prefill_attn` and `use_trtllm_decode_attn`
- remove the decode-case handling from `use_trtllm_attention`

A sketch of the suggested split is shown below. Maybe it's worth doing in a separate PR.
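For illustration, a minimal sketch of the split this comment proposes. The function names follow the suggestion, but the signatures and checks here are hypothetical placeholders, not vLLM's real `use_trtllm_attention` logic.

```python
# Hypothetical sketch of the proposed refactor; signatures and checks are
# placeholders, not the actual vLLM helpers.
def use_trtllm_prefill_attn(num_prefill_tokens: int, on_blackwell: bool) -> bool:
    # Prefill-only decision, so callers no longer reuse a combined helper.
    return on_blackwell and num_prefill_tokens > 0


def use_trtllm_decode_attn(num_decode_tokens: int, on_blackwell: bool) -> bool:
    # Decode-only decision, pulled out of use_trtllm_attention as suggested.
    return on_blackwell and num_decode_tokens > 0
```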
vadiklyutiy
left a comment
One more style question:
Is there a reason to hold `[_]cudagraph_support` and `get_cudagraph_support` in the `*MetadataBuilder` classes? Maybe `*Backend(AttentionBackend)` would be a better place?
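A hedged sketch of the placement being asked about, using hypothetical class names; it only illustrates moving the query onto a backend class, not the actual vLLM class hierarchy.

```python
# Hypothetical illustration only: expose cudagraph support from a *Backend
# class instead of its metadata builder. All names are placeholders.
from enum import Enum, auto


class AttentionCGSupport(Enum):
    NEVER = auto()
    UNIFORM_BATCH = auto()


class SomeAttentionBackend:
    @classmethod
    def get_cudagraph_support(cls, vllm_config, kv_cache_spec) -> AttentionCGSupport:
        # If this lived on the backend, callers could query it without first
        # resolving the backend's metadata-builder class.
        return AttentionCGSupport.UNIFORM_BATCH
```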
… decoding with FlashInfer (vllm-project#28479) Signed-off-by: Benjamin Chislett <[email protected]> Signed-off-by: George D. Torres <[email protected]>
… decoding with FlashInfer (vllm-project#28479) Signed-off-by: Benjamin Chislett <[email protected]> Signed-off-by: Bram Wasti <[email protected]>
… decoding with FlashInfer (vllm-project#28479) Signed-off-by: Benjamin Chislett <[email protected]>
Purpose
Revised implementation of #26937
This PR makes `_cudagraph_support` a private member and adds `get_cudagraph_support(vllm_config, kv_cache_spec)`. It also updates `_check_and_update_cudagraph_mode` to consider support per-backend, per-KV-cache group. TRTLLM-gen kernels support full CUDA graphs, but they are only used with FlashInfer on Blackwell under certain conditions.
It might not be safe to always set FlashInfer's cudagraph_support to UNIFORM_BATCH, but we can still set it when we know the TRTLLM-gen backend will be used.
Also updates the docs to reflect FlashInfer's CUDA graph compatibility and fills in the missing entry for FlashInferMLA.
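As a minimal sketch of the pattern described above (not the actual vLLM implementation): a class-level default stays in a private `_cudagraph_support`, while `get_cudagraph_support(vllm_config, kv_cache_spec)` can upgrade it dynamically. The `trtllm_gen_will_be_used` helper and class names below are hypothetical, standing in for the real Blackwell/TRTLLM eligibility checks.

```python
# Minimal sketch, not vLLM's real code: assumes an AttentionCGSupport enum and
# a hypothetical trtllm_gen_will_be_used() eligibility check.
from enum import Enum, auto


class AttentionCGSupport(Enum):
    NEVER = auto()
    UNIFORM_SINGLE_TOKEN_DECODE = auto()
    UNIFORM_BATCH = auto()
    ALWAYS = auto()


def trtllm_gen_will_be_used(vllm_config, kv_cache_spec) -> bool:
    # Placeholder for the real checks (Blackwell GPU, supported head size,
    # dtype, etc.) that decide whether TRTLLM-gen kernels are selected.
    return False


class AttentionMetadataBuilderSketch:
    # Static default, now private; callers should use get_cudagraph_support().
    _cudagraph_support = AttentionCGSupport.NEVER

    @classmethod
    def get_cudagraph_support(cls, vllm_config, kv_cache_spec) -> AttentionCGSupport:
        # Default behavior: fall back to the static class-level value.
        return cls._cudagraph_support


class FlashInferLikeBuilderSketch(AttentionMetadataBuilderSketch):
    _cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

    @classmethod
    def get_cudagraph_support(cls, vllm_config, kv_cache_spec) -> AttentionCGSupport:
        # Advertise UNIFORM_BATCH (full cudagraph for uniform decode batches)
        # only when we know the TRTLLM-gen kernels will actually be used,
        # since it is not safe to claim that support unconditionally.
        if trtllm_gen_will_be_used(vllm_config, kv_cache_spec):
            return AttentionCGSupport.UNIFORM_BATCH
        return cls._cudagraph_support
```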
FIX #26856
Test Plan
See #26937 for functional correctness testing / benchmarking. Rerunning on this branch gives the same results.
The local test run passes for `tests/v1/attention`.