feat: Add TRTLLM-Gen Skip-Softmax kernels for prefill and decode #2477

yzh119 merged 13 commits into flashinfer-ai:main from
Conversation
Signed-off-by: Dom Brown <[email protected]>
Signed-off-by: Dom Brown <[email protected]>
📝 Walkthrough

Adds optional softmax-skipping controls (a bool and a threshold scale factor) across FMHA: launcher, kernel params, kernel hash, and the Python prefill/decode APIs; propagates the values through the paged attention paths, updates tests, and adjusts artifact metadata.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant App as Python API
    participant Launcher as trtllm_fmha_kernel_launcher
    participant Runner as FMHA Runner / Kernel Selector
    participant Kernel as CUDA FMHA Kernel
    App->>Launcher: call paged_attention_launch(..., skips_softmax, skip_softmax_threshold_scale_factor)
    Launcher->>Runner: build RunnerParams (mSkipsSoftmaxWhenPossible, mSkipSoftmaxThresholdScaleFactor)
    Runner->>Runner: compute kernel selection & hash (includes skipsSoftmax bit)
    Runner->>Kernel: launch kernel with KernelParams (ptrSkipSoftmaxStats, mSkipSoftmaxThresholdScaleFactor)
    Kernel-->>App: produce attention outputs
```
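At the kernel level, the decision this flow feeds into is a per-tile one: tiles of the score matrix whose contribution is negligible are dropped before their softmax terms are computed. The TRTLLM-Gen kernels implement this on-chip; the NumPy sketch below is only an illustrative model of the idea, with the tile size and the exact skip rule (comparing a tile's best score against the row maximum) assumed rather than taken from the kernels.

```python
import numpy as np

def attention_with_skip(scores, v, threshold_scale=0.0, tile=4):
    """Row-wise softmax(scores) @ v, optionally skipping low-scoring KV tiles.

    Illustrative rule (not the TRTLLM-Gen one): a tile is skipped when
    exp(tile_max - row_max) falls below threshold_scale.
    """
    n_q, n_k = scores.shape
    out = np.zeros((n_q, v.shape[1]))
    for i in range(n_q):
        row = scores[i]
        row_max = row.max()
        num = np.zeros(v.shape[1])
        den = 0.0
        for start in range(0, n_k, tile):
            blk = row[start:start + tile]
            # Skip the whole tile if its best score is negligible vs. the row max.
            if np.exp(blk.max() - row_max) < threshold_scale:
                continue
            w = np.exp(blk - row_max)
            num += w @ v[start:start + tile]
            den += w.sum()
        out[i] = num / den  # den > 0: the row-max tile is never skipped
    return out
```

With `threshold_scale=0.0` no tile is ever skipped (the exponential is always positive), so the result matches ordinary softmax attention exactly; this is the property the PR's tests rely on when parameterising the existing tests with a zero threshold.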
Summary of Changes

This pull request introduces a performance optimization by integrating NVIDIA's Skip-Softmax attention sparsity into the FlashInfer library, specifically for Blackwell GPUs. The feature aims to accelerate long-context inference by dynamically skipping computationally intensive softmax calculations based on a configurable threshold. The changes modify low-level C++ kernels, update the Python API interfaces for prefill and decode operations, and expand test coverage to validate the new functionality.
Code Review
This pull request introduces support for Skip-Softmax attention for Blackwell GPUs in TRTLLM-Gen kernels, which is a valuable performance feature. The changes are well-integrated, with new optional parameters plumbed from the Python API down to the CUDA kernel launchers for both prefill and decode paths. The corresponding kernel selection logic and parameter structs have been updated correctly. The test suite has also been extended to cover the new skips_softmax functionality, including checks for unsupported backend and data type combinations. I have one suggestion to expand test coverage for long sequences with skip-softmax enabled to ensure its correctness in that key scenario.
PerkzZheng
left a comment
The changes LGTM. thanks!
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
flashinfer/prefill.py (1)
2061-2290: ⚠️ Potential issue | 🟡 Minor

Guard skip-softmax against unsupported backends.

`skip_softmax_threshold_scale_factor` is accepted by the public API but ignored for non-TRTLLM backends. Consider an explicit guard to avoid silent no-ops.

💡 Suggested guard

```diff
@@
         if self._backend != "trtllm-gen":
             # NOTE(Siyuan): since window_left is appeared in the plan function, we need to make sure it is the same as the one in the plan function.
             # Remove this check if the backend supports dynamic window_left.
             assert window_left == self._window_left
+        if (
+            skip_softmax_threshold_scale_factor is not None
+            and self._backend != "trtllm-gen"
+        ):
+            raise ValueError(
+                "skip_softmax_threshold_scale_factor is only supported for trtllm-gen backend."
+            )
```

flashinfer/decode.py (2)
1196-1326: ⚠️ Potential issue | 🟡 Minor

Validate q_len_per_req and skip-softmax support early.

`q_len_per_req` can be `None` and is used in a reshape for TRTLLM-gen, which will throw a TypeError. Also, `skip_softmax_threshold_scale_factor` is silently ignored for non-TRTLLM backends. Adding explicit guards makes the API safer and more predictable.

💡 Suggested validation

```diff
@@
         if enable_pdl is None:
             enable_pdl = device_support_pdl(q.device)
+        if (
+            skip_softmax_threshold_scale_factor is not None
+            and self._backend != "trtllm-gen"
+        ):
+            raise ValueError(
+                "skip_softmax_threshold_scale_factor is only supported for trtllm-gen backend."
+            )
@@
-        if self._backend == "trtllm-gen":
+        if self._backend == "trtllm-gen":
+            if q_len_per_req is None or q_len_per_req <= 0:
+                raise ValueError("q_len_per_req must be a positive int for trtllm-gen.")
+            if q.size(0) % q_len_per_req != 0:
+                raise ValueError("q.size(0) must be divisible by q_len_per_req.")
             q = q.view(q.size(0) // q_len_per_req, q_len_per_req, q.size(1), q.size(2))
```
2119-2410: ⚠️ Potential issue | 🟡 Minor

Disallow skip-softmax on non-TRTLLM backends.

`skip_softmax_threshold_scale_factor` is accepted but ignored when `backend="xqa"`. Consider a guard so callers don't assume it took effect.

💡 Suggested guard

```diff
@@
         if backend == "auto":
             backend = (
                 "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa"
             )
+        if skip_softmax_threshold_scale_factor is not None and backend != "trtllm-gen":
+            raise ValueError(
+                "skip_softmax_threshold_scale_factor is only supported for trtllm-gen backend."
+            )
```
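The same guard recurs in all three suggestions above. As an illustrative refactor (the helper name and error message are assumptions, not code from the PR), it could live in one place:

```python
def check_skip_softmax_backend(backend, skip_softmax_threshold_scale_factor):
    """Reject skip-softmax on backends that would silently ignore it.

    Hypothetical helper: only the trtllm-gen backend honors the
    parameter, so any other backend with a non-None value is an error.
    """
    if (
        skip_softmax_threshold_scale_factor is not None
        and backend != "trtllm-gen"
    ):
        raise ValueError(
            "skip_softmax_threshold_scale_factor is only supported for "
            "the trtllm-gen backend."
        )
```

Failing fast here turns a silent no-op (the caller believes sparsity is enabled but gets dense attention) into an immediate, debuggable error.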
/bot run

[FAILED] Pipeline #43282906: 1/20 passed

Looks like some failures occurred. I will look tomorrow -- looks like I missed a parameter somewhere.

Should be fixed now. If you could trigger CI, that would be great (I don't think I'm allowed yet).

/bot run

@flashinfer-bot run

[FAILED] Pipeline #43419335: 15/20 passed
⚠️ Outside diff range comments (4)
flashinfer/prefill.py (1)
2061-2077: ⚠️ Potential issue | 🟡 Minor

Document the new `sinks` and `skip_softmax_threshold_scale_factor` params in `run()`.

Line 2075–Line 2076 add public parameters, but the docstring below doesn't describe them yet. Please update the parameter docs for user-facing clarity.

flashinfer/decode.py (3)
1157-1256: ⚠️ Potential issue | 🟡 Minor

Document the new `sinks` parameter in the public docstring.

Line 1171 adds `sinks` to the public signature, but the docstring's Parameters section doesn't mention it yet.
2123-2418: ⚠️ Potential issue | 🟡 Minor

Validate skip-softmax for unsupported backends.

Line 2146 adds a public `skip_softmax_threshold_scale_factor`, but when `backend == "xqa"` it is ignored (Line 2271+). Please either validate that it is `None` for xqa or explicitly warn/raise to avoid silent no-ops.

Suggested guard

```diff
@@
-    if backend == "xqa":
+    if backend == "xqa":
+        if skip_softmax_threshold_scale_factor is not None:
+            raise ValueError(
+                "skip_softmax_threshold_scale_factor is only supported by the trtllm-gen backend."
+            )
         # xqa backend doesn't support nvfp4 output
```
1319-1330: ⚠️ Potential issue | 🟠 Major

Fix output shape mismatch when `q_len_per_req > 1` in the TRTLLM-gen code path.

Line 1328 reshapes `q` to 4D, but `out` is allocated as 3D (line 1323). When `q_len_per_req > 1`, the kernel receives a shape-mismatched output buffer. Reshape `out` to match the reshaped `q` before the kernel call, pass the 4D view to the kernel, and return the original 3D tensor.

🐛 Proposed fix

```diff
         if out is None:
             out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
             out = torch.empty(
                 q.shape[:-1] + v_cache.shape[-1:], dtype=out_dtype, device=q.device
             )
         else:
             out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
             check_shape_dtype_device(out, q.shape, out_dtype, q.device, "out")
+        out_view = out
         if self._backend == "trtllm-gen":
+            if q_len_per_req is None:
+                raise ValueError("q_len_per_req must be set for trtllm-gen decode.")
+            if q.size(0) % q_len_per_req != 0:
+                raise ValueError(
+                    f"q.size(0) ({q.size(0)}) must be divisible by q_len_per_req ({q_len_per_req})."
+                )
+            batch_size = q.size(0) // q_len_per_req
             q = q.view(q.size(0) // q_len_per_req, q_len_per_req, q.size(1), q.size(2))
+            out_view = out.view(batch_size, q_len_per_req, out.size(1), out.size(2))
         if self.use_tensor_cores:
             run_args = [
                 ...
-                out,
+                out_view,
                 ...
             ]
         else:
             ...
             run_args = [
                 ...
-                out,
+                out_view,
                 ...
             ]
```
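The fix works because `view` on a contiguous tensor returns an alias of the same storage, so the kernel's writes through the 4D `out_view` land in the caller's 3D `out`. The snippet below illustrates that aliasing with NumPy's `reshape`, which likewise returns a view of a contiguous array; the shapes here are made up for the example.

```python
import numpy as np

# A contiguous (batch * q_len, heads, dim) buffer, as `out` is allocated.
out = np.zeros((6, 2, 4))

# Reshape to (batch, q_len, heads, dim): a view sharing the same storage,
# analogous to torch.Tensor.view on a contiguous tensor.
out_view = out.reshape(3, 2, 2, 4)
out_view[1, 0, 0, 0] = 7.0  # "kernel" writes through the 4D view

# The write is visible through the original 3D buffer at row 1*2 + 0 == 2.
assert out[2, 0, 0] == 7.0
```

Because no copy is made, the caller can keep returning the original 3D tensor while the kernel operates on the 4D layout it expects.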
@yongwww can you take a look at failed UT? |
📌 Description

This PR is a follow-up to #2477, expanding support to MLA. It also modifies the runner slightly to 'short-circuit' to normal attention kernels if the threshold is zero, to reduce overhead. Tests are updated to use a very tiny threshold instead, so we still get the same result as normal attention without triggering the fallback.

🔍 Related Issues

#2306

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

Summary by CodeRabbit

- New Features
  - Added optional `skip_softmax_threshold_scale_factor` and `skips_softmax` controls to ragged/paged attention and batch-decode flows to enable conditional softmax skipping.
- Documentation
  - Documented the new parameter, its formula, accuracy/performance tradeoffs, and backend-specific validation rules.
- Chores
  - Threaded the parameter through public APIs and updated signatures; added runtime validation for incompatible backends.
- Tests
  - Expanded attention tests to cover skip-softmax scenarios across DeepSeek, paged, and MLA paths.
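The 'short-circuit' behaviour described in the follow-up reduces to a small dispatch rule. The helper below is an illustrative sketch (the name and return values are assumptions, not the runner's actual API):

```python
def select_attention_mode(skip_softmax_threshold_scale_factor):
    """Illustrative dispatch: None disables the feature entirely, and a
    zero threshold falls back to the normal attention kernels, since a
    zero threshold can never skip a block but would still pay the
    skip-softmax bookkeeping overhead."""
    t = skip_softmax_threshold_scale_factor
    if t is None or t == 0.0:
        return "normal"
    return "skip-softmax"
```

This is also why the follow-up switches the tests from a zero threshold to a very tiny one: a tiny threshold still matches normal attention numerically, but does not trigger the fallback, so the skip-softmax code path is actually exercised.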
📌 Description
This PR adds initial Blackwell support for Skip-Softmax attention sparsity. See:
The skip-softmax threshold is added as an optional parameter. If it is `None`, 'normal' attention is used; otherwise we use skip-softmax with the specified threshold.

I have added unit tests by parameterising the existing tests with a skip threshold of zero. This allows us to compare against a 'normal' attention reference kernel.
🔍 Related Issues
#2306 is closed by this PR.
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

Reviewer Notes
I'm new to FlashInfer so may have missed an API update somewhere. I will address this if any issues are found :)
Summary by CodeRabbit
New Features
Chores
Tests