feat: Add TRTLLM-Gen Skip-Softmax kernels for prefill and decode#2477

Merged
yzh119 merged 13 commits into flashinfer-ai:main from DomBrown:dev/skip_softmax_prefill
Feb 11, 2026

Conversation

@DomBrown
Contributor

@DomBrown DomBrown commented Feb 3, 2026

📌 Description

This PR adds initial Blackwell support for Skip-Softmax attention sparsity.

The skip-softmax threshold is added as an optional parameter. If it is None, 'normal' attention is used; otherwise skip-softmax is used with the specified threshold.

I have added unit tests by parameterising the existing tests with a skip threshold of zero. This allows us to compare against a 'normal' attention reference kernel.
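The None-vs-threshold behaviour described above can be sketched as follows. This is an illustrative helper, not the real launcher code; the actual dispatch happens in the C++ kernel launcher:

```python
from typing import Optional

# Hypothetical sketch of the dispatch rule: None selects normal attention,
# any other value enables skip-softmax with that threshold. The unit tests
# use 0.0 so results can be compared against a normal-attention reference.
def select_attention_mode(
    skip_softmax_threshold_scale_factor: Optional[float],
) -> str:
    if skip_softmax_threshold_scale_factor is None:
        return "normal"
    return "skip-softmax"

print(select_attention_mode(None))  # normal
print(select_attention_mode(0.0))   # skip-softmax
```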

🔍 Related Issues

#2306 is closed by this PR.

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

I'm new to FlashInfer, so I may have missed an API update somewhere. I will address any issues that are found :)

Summary by CodeRabbit

  • New Features

    • Added optional softmax-skipping controls (enable flag and threshold scale factor) across prefill and decode public APIs to allow selective skip-softmax behavior while preserving defaults.
  • Chores

    • Updated artifact repository paths and checksums for FMHA artifacts.
  • Tests

    • Extended and parameterized attention tests and helpers to cover and gate the new softmax-skipping paths.

@coderabbitai
Contributor

coderabbitai bot commented Feb 3, 2026

📝 Walkthrough

Adds optional softmax-skipping controls (bool and threshold scale factor) across FMHA: launcher, kernel params, kernel hash, and Python prefill/decode APIs; propagates values through paged attention paths, updates tests, and adjusts artifact metadata.

Changes

  • FMHA Launcher (csrc/trtllm_fmha_kernel_launcher.cu): launcher signature extended to accept skip_softmax_threshold_scale_factor and skips_softmax; values propagated into FMHA runner setup and kernel launch parameters.
  • FMHA Kernel Selection & Params (include/flashinfer/trtllm/fmha/fmhaKernels.cuh, include/flashinfer/trtllm/fmha/fmhaRunnerParams.h, include/flashinfer/trtllm/fmha/kernelParams.h): added runner fields mSkipsSoftmaxWhenPossible and mSkipSoftmaxThresholdScaleFactor; kernel hashID now encodes skipsSoftmax (new bit); replaced reserved kernel fields with skip-softmax fields and propagated initializers.
  • Python Decode API (flashinfer/decode.py, flashinfer/mla.py): public and internal decode/paged_run signatures updated to accept and thread skip_softmax_threshold_scale_factor (plus sinks, q_len_per_req) through to kernel call sites; MLA caller updated to pass the new arg.
  • Python Prefill / Context API (flashinfer/prefill.py): prefill/context paged_run and wrapper run signatures extended to accept skip_softmax_threshold_scale_factor (and sinks); values forwarded to the underlying trtllm-gen module and kernel paths.
  • Artifact Metadata (flashinfer/artifacts.py): updated ArtifactPath.TRTLLM_GEN_FMHA repository path and CheckSumHash.TRTLLM_GEN_FMHA SHA256 checksum.
  • Tests (tests/attention/test_trtllm_gen_attention.py): added skips_softmax parameter to tests and helpers; compute/propagate skip_softmax_threshold_scale_factor where applicable and guard incompatible dtype/backend combinations.
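The hashID change above can be illustrated with a small bit-packing sketch: a new bit encodes whether the kernel skips softmax, so kernels that differ only in that capability hash to different IDs. Field names and bit widths here are hypothetical; the real layout lives in include/flashinfer/trtllm/fmha/fmhaKernels.cuh:

```python
# Illustrative sketch of packing kernel-selection fields into an integer
# hash, with a dedicated bit for the new skip-softmax flag. Not the real
# encoding: field names and widths are made up for this example.
def hash_id(head_dim: int, num_stages: int, skips_softmax: bool) -> int:
    h = head_dim & 0x3FF           # bits 0-9: head dimension
    h |= (num_stages & 0xF) << 10  # bits 10-13: pipeline stages
    h |= int(skips_softmax) << 14  # bit 14: skip-softmax flag (new)
    return h

# Kernels that differ only in the flag must hash differently:
assert hash_id(128, 2, True) != hash_id(128, 2, False)
```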

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant App as Python API
    participant Launcher as trtllm_fmha_kernel_launcher
    participant Runner as FMHA Runner / Kernel Selector
    participant Kernel as CUDA FMHA Kernel

    App->>Launcher: call paged_attention_launch(..., skips_softmax, skip_softmax_threshold_scale_factor)
    Launcher->>Runner: build RunnerParams (mSkipsSoftmaxWhenPossible, mSkipSoftmaxThresholdScaleFactor)
    Runner->>Runner: compute kernel selection & hash (includes skipsSoftmax bit)
    Runner->>Kernel: launch kernel with KernelParams (ptrSkipSoftmaxStats, mSkipSoftmaxThresholdScaleFactor)
    Kernel-->>App: produce attention outputs

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

  • #2084 — Modifies the same TRTLLM launcher/decode/context function signatures with related kernel-control parameters.
  • #2138 — Changes FMHA launcher and runner-params interfaces; overlaps on adding kernel-selection fields.
  • #2254 — Also modifies trtllm_paged_attention_launcher signature to propagate additional runner metadata.

Suggested reviewers

  • joker-eph
  • aleozlx
  • cyx-6
  • wenscarl
  • nvmbreughe
  • yzh119
  • PerkzZheng

Poem

🐰 I hopped through code with a tiny flag,
Softmax may snooze at the threshold I wag.
Bits flipped in hashes, params tucked tight,
Kernels whisper softly through the night. 🥕

🚥 Pre-merge checks: 2 passed, 1 failed
❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: docstring coverage is 23.53%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.
✅ Passed checks (2 passed)
  • Title check ✅ Passed: the title clearly and concisely describes the main change, adding Skip-Softmax kernels for prefill and decode in TRTLLM-Gen, which aligns with the primary focus of the changeset.
  • Description check ✅ Passed: the description includes all required sections: a clear description of what the PR does with context (Skip-Softmax support with optional threshold parameter), the related issue reference (#2306), and a completed checklist confirming pre-commit hooks and tests.


@gemini-code-assist
Contributor

Summary of Changes

Hello @DomBrown, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant performance optimization by integrating NVIDIA's Skip-Softmax attention sparsity into the FlashInfer library, specifically for Blackwell GPUs. This feature aims to accelerate long-context inference by dynamically skipping computationally intensive softmax calculations based on a configurable threshold. The changes involve modifying low-level C++ kernels, updating Python API interfaces for prefill and decode operations, and expanding test coverage to validate the new functionality.

Highlights

  • Blackwell Skip-Softmax Integration: Implemented initial support for NVIDIA's Skip-Softmax attention sparsity, leveraging the Blackwell architecture for improved long-context inference.
  • New Parameter Introduction: Introduced skip_softmax_threshold_scale_factor as an optional parameter across C++ kernel launchers and Python API bindings, enabling conditional skipping of softmax operations.
  • Kernel Parameter Updates: Modified core kernel parameter structures (TllmGenFmhaRunnerParams, TllmGenSelectKernelParams, KernelParams) to accommodate the new skip-softmax logic and related statistics.
  • Expanded Test Coverage: Extended existing test suites for both prefill and decode operations to include scenarios with and without skip-softmax, ensuring functional correctness.
  • Artifact Updates: Updated artifact paths and checksums for TRTLLM_GEN_FMHA, reflecting the compilation of new kernel binaries that incorporate the skip-softmax feature.


Changelog
  • csrc/trtllm_fmha_kernel_launcher.cu
    • Added skip_softmax_threshold_scale_factor (float) and skips_softmax (bool) parameters to trtllm_paged_attention_launcher.
    • Updated trtllm_paged_attention_decode and trtllm_paged_attention_context function signatures to accept an optional skip_softmax_threshold_scale_factor.
    • Passed the new skip-softmax parameters to the fmha_runner and trtllm_paged_attention_launcher calls.
  • flashinfer/artifacts.py
    • Updated TRTLLM_GEN_FMHA artifact path and checksum, indicating new kernel binaries.
  • flashinfer/decode.py
    • Added skip_softmax_threshold_scale_factor as an optional float parameter to _paged_run and trtllm_batch_decode_with_kv_cache functions.
    • Passed this new parameter to the underlying C++ launcher.
    • Updated docstrings to describe the new parameter.
  • flashinfer/prefill.py
    • Added skip_softmax_threshold_scale_factor as an optional float parameter to _paged_run, paged_run, and run functions.
    • Passed this new parameter to the underlying C++ launcher.
    • Updated docstrings to describe the new parameter.
  • include/flashinfer/trtllm/fmha/fmhaKernels.cuh
    • Modified hashID function to include skipsSoftmax as a parameter for kernel identification.
    • Updated the bit allocation for hashID to include skipsSoftmax.
  • include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
    • Added mSkipsSoftmaxWhenPossible (bool) and mSkipSoftmaxThresholdScaleFactor (float) to TllmGenFmhaRunnerParams.
    • Added mSkipsSoftmaxWhenPossible (bool) to TllmGenSelectKernelParams and initialized it from TllmGenFmhaRunnerParams.
  • include/flashinfer/trtllm/fmha/kernelParams.h
    • Replaced ptrReservedMem with ptrSkipSoftmaxStats (int32_t array of size 4 for statistics).
    • Replaced mReservedParam with mSkipSoftmaxThresholdScaleFactor (float).
    • Updated KernelParams constructor to set mSkipSoftmaxThresholdScaleFactor.
  • tests/attention/test_trtllm_gen_attention.py
    • Added skips_softmax parameter to _test_trtllm_batch_prefill and _test_trtllm_batch_decode helper functions.
    • Extended various test functions with pytest.mark.parametrize('skips_softmax', [False, True]).
    • Introduced conditional skips for skips_softmax when q_dtype != kv_dtype or when using the xqa backend.
    • Set skip_softmax_threshold_scale_factor to 0.0 when skips_softmax is enabled in tests, for validation against normal attention.
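The dtype/backend gating in the test changes can be mirrored by a small helper. This is a hypothetical sketch, not the actual test code; the real guards live in tests/attention/test_trtllm_gen_attention.py:

```python
# Hypothetical helper mirroring the test gating described above: the
# skip-softmax parameterization is skipped when q and kv dtypes differ
# or when the xqa backend is used; normal attention runs everywhere.
def should_skip_case(
    skips_softmax: bool, q_dtype: str, kv_dtype: str, backend: str
) -> bool:
    if not skips_softmax:
        return False  # normal attention: no extra restrictions
    return q_dtype != kv_dtype or backend == "xqa"
```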

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for Skip-Softmax attention for Blackwell GPUs in TRTLLM-Gen kernels, which is a valuable performance feature. The changes are well-integrated, with new optional parameters plumbed from the Python API down to the CUDA kernel launchers for both prefill and decode paths. The corresponding kernel selection logic and parameter structs have been updated correctly. The test suite has also been extended to cover the new skips_softmax functionality, including checks for unsupported backend and data type combinations. I have one suggestion to expand test coverage for long sequences with skip-softmax enabled to ensure its correctness in that key scenario.

Contributor

@PerkzZheng PerkzZheng left a comment


The changes LGTM. thanks!

@DomBrown DomBrown marked this pull request as ready for review February 3, 2026 16:03
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
flashinfer/prefill.py (1)

2061-2290: ⚠️ Potential issue | 🟡 Minor

Guard skip-softmax against unsupported backends.

skip_softmax_threshold_scale_factor is accepted by the public API but ignored for non‑TRTLLM backends. Consider an explicit guard to avoid silent no‑ops.

💡 Suggested guard
@@
         if self._backend != "trtllm-gen":
             # NOTE(Siyuan): since window_left is appeared in the plan function, we need to make sure it is the same as the one in the plan function.
             # Remove this check if the backend supports dynamic window_left.
             assert window_left == self._window_left
+        if (
+            skip_softmax_threshold_scale_factor is not None
+            and self._backend != "trtllm-gen"
+        ):
+            raise ValueError(
+                "skip_softmax_threshold_scale_factor is only supported for trtllm-gen backend."
+            )
flashinfer/decode.py (2)

1196-1326: ⚠️ Potential issue | 🟡 Minor

Validate q_len_per_req and skip-softmax support early.

q_len_per_req can be None and is used in a reshape for TRTLLM-gen, which will throw a TypeError. Also, skip_softmax_threshold_scale_factor is silently ignored for non‑TRTLLM backends. Adding explicit guards makes the API safer and more predictable.

💡 Suggested validation
@@
         if enable_pdl is None:
             enable_pdl = device_support_pdl(q.device)
+        if (
+            skip_softmax_threshold_scale_factor is not None
+            and self._backend != "trtllm-gen"
+        ):
+            raise ValueError(
+                "skip_softmax_threshold_scale_factor is only supported for trtllm-gen backend."
+            )
@@
-        if self._backend == "trtllm-gen":
+        if self._backend == "trtllm-gen":
+            if q_len_per_req is None or q_len_per_req <= 0:
+                raise ValueError("q_len_per_req must be a positive int for trtllm-gen.")
+            if q.size(0) % q_len_per_req != 0:
+                raise ValueError("q.size(0) must be divisible by q_len_per_req.")
             q = q.view(q.size(0) // q_len_per_req, q_len_per_req, q.size(1), q.size(2))

2119-2410: ⚠️ Potential issue | 🟡 Minor

Disallow skip-softmax on non‑TRTLLM backends.

skip_softmax_threshold_scale_factor is accepted but ignored when backend="xqa". Consider a guard so callers don’t assume it took effect.

💡 Suggested guard
@@
     if backend == "auto":
         backend = (
             "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa"
         )
+    if skip_softmax_threshold_scale_factor is not None and backend != "trtllm-gen":
+        raise ValueError(
+            "skip_softmax_threshold_scale_factor is only supported for trtllm-gen backend."
+        )

@bkryu
Collaborator

bkryu commented Feb 4, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !295 has been created, and the CI pipeline #43282906 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #43282906: 1/20 passed

@DomBrown
Contributor Author

DomBrown commented Feb 4, 2026

Looks like some failures occurred. I will look tomorrow; it seems I missed a parameter somewhere.

@DomBrown
Contributor Author

DomBrown commented Feb 5, 2026

Should be fixed now. If you could trigger CI that would be great (I don't think I'm allowed yet)

@yzh119
Collaborator

yzh119 commented Feb 6, 2026

/bot run

@yzh119
Collaborator

yzh119 commented Feb 6, 2026

@flashinfer-bot run

@flashinfer-bot
Collaborator

GitLab MR !295 has been updated with latest changes, and the CI pipeline #43419335 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #43419335: 15/20 passed

Collaborator

@bkryu bkryu left a comment


Thank you @DomBrown. Confirmed that the CI failures are unrelated.

Left a minor comment about docstrings.

@DomBrown DomBrown requested a review from bkryu February 9, 2026 19:40
Contributor

@coderabbitai coderabbitai bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
flashinfer/prefill.py (1)

2061-2077: ⚠️ Potential issue | 🟡 Minor

Document new sinks and skip_softmax_threshold_scale_factor params in run().
Line 2075–Line 2076 add public parameters, but the docstring below doesn’t describe them yet. Please update the parameter docs for user-facing clarity.

flashinfer/decode.py (3)

1157-1256: ⚠️ Potential issue | 🟡 Minor

Document the new sinks parameter in the public docstring.
Line 1171 adds sinks to the public signature, but the docstring’s Parameters section doesn’t mention it yet.


2123-2418: ⚠️ Potential issue | 🟡 Minor

Validate skip-softmax for unsupported backends.
Line 2146 adds a public skip_softmax_threshold_scale_factor, but when backend == "xqa" it is ignored (Line 2271+). Please either validate that it is None for xqa or explicitly warn/raise to avoid silent no‑ops.

Suggested guard
@@
-    if backend == "xqa":
+    if backend == "xqa":
+        if skip_softmax_threshold_scale_factor is not None:
+            raise ValueError(
+                "skip_softmax_threshold_scale_factor is only supported by the trtllm-gen backend."
+            )
         # xqa backend doesn't support nvfp4 output

1319-1330: ⚠️ Potential issue | 🟠 Major

Fix output shape mismatch when q_len_per_req > 1 in the TRTLLM-gen code path.

Line 1328 reshapes q to 4D, but out is allocated as 3D (line 1323). When q_len_per_req > 1, the kernel receives a shape-mismatched output buffer. Reshape out to match the reshaped q before the kernel call, pass the 4D view to the kernel, and return the original 3D tensor.

🐛 Proposed fix
        if out is None:
            out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
            out = torch.empty(
                q.shape[:-1] + v_cache.shape[-1:], dtype=out_dtype, device=q.device
            )
        else:
            out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
            check_shape_dtype_device(out, q.shape, out_dtype, q.device, "out")

+       out_view = out
        if self._backend == "trtllm-gen":
+           if q_len_per_req is None:
+               raise ValueError("q_len_per_req must be set for trtllm-gen decode.")
+           if q.size(0) % q_len_per_req != 0:
+               raise ValueError(
+                   f"q.size(0) ({q.size(0)}) must be divisible by q_len_per_req ({q_len_per_req})."
+               )
+           batch_size = q.size(0) // q_len_per_req
            q = q.view(q.size(0) // q_len_per_req, q_len_per_req, q.size(1), q.size(2))
+           out_view = out.view(batch_size, q_len_per_req, out.size(1), out.size(2))

        if self.use_tensor_cores:
            run_args = [
                ...
-               out,
+               out_view,
                ...
            ]
        else:
            ...
            run_args = [
                ...
-               out,
+               out_view,
                ...
            ]

@yzh119
Collaborator

yzh119 commented Feb 9, 2026

@yongwww can you take a look at failed UT?

@yongwww
Member

yongwww commented Feb 11, 2026

> @yongwww can you take a look at failed UT?

PR #2500 broke the workflow; we might need to revert that change for now: #2524

@yzh119 yzh119 merged commit 09b3825 into flashinfer-ai:main Feb 11, 2026
15 of 17 checks passed
@DomBrown DomBrown deleted the dev/skip_softmax_prefill branch February 12, 2026 13:27
yzh119 pushed a commit that referenced this pull request Feb 19, 2026

## 📌 Description

This PR is a follow-up to #2477, expanding support to MLA.
It also modifies the runner slightly to 'short-circuit' to the normal
attention kernels if the threshold is zero, to reduce overhead. Tests were
updated to use a very tiny threshold instead, so we still get the same
result as normal attention without triggering the fallback.
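The short-circuit described above amounts to a simple predicate. This is a hedged sketch, not the runner's actual code:

```python
from typing import Optional

# Hypothetical sketch of the follow-up's short-circuit: a zero (or absent)
# threshold falls back to the normal attention kernels to avoid overhead,
# which is why the tests switched from a 0.0 threshold to a tiny positive
# one to keep exercising the skip-softmax path.
def uses_skip_softmax_kernel(threshold: Optional[float]) -> bool:
    return threshold is not None and threshold > 0.0
```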

## 🔍 Related Issues

#2306 

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes



## Summary by CodeRabbit

* **New Features**
* Added optional skip_softmax_threshold_scale_factor and skips_softmax
controls to ragged/paged attention and batch-decode flows to enable
conditional softmax skipping.

* **Documentation**
* Documented the new parameter, its formula, accuracy/performance
tradeoffs, and backend-specific validation rules.

* **Chores**
* Threaded parameter through public APIs and updated signatures; added
runtime validation for incompatible backends.

* **Tests**
* Expanded attention tests to cover skip-softmax scenarios across
DeepSeek, paged, and MLA paths.