
feat: Enable TRTLLM-Gen Skip-Softmax attention for MLA #2547

Merged
yzh119 merged 5 commits into flashinfer-ai:main from DomBrown:dev/skip_softmax_mla on Feb 19, 2026

Conversation

DomBrown (Contributor) commented Feb 12, 2026

📌 Description

This PR is a follow-up to #2477, expanding support to MLA.
It also modifies the runner slightly to 'short-circuit' to the normal attention kernels when the threshold is zero, to reduce overhead. The tests are updated to use a very tiny threshold instead, so we still get the same result as normal attention without triggering the fallback.
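For illustration, here is a minimal caller-side sketch of the two modes described above. The helper below is hypothetical (not part of FlashInfer); the real entry points are flashinfer.mla.trtllm_batch_decode_with_kv_cache_mla and flashinfer.prefill.trtllm_ragged_attention_deepseek, whose full argument lists are omitted here.

from typing import Optional

def skip_softmax_kwargs(skips_softmax: bool, scale_factor: float = 1e-9) -> dict:
    """Hypothetical helper: build the keyword argument added by this PR.

    None (or omitting the argument) makes the launcher short-circuit to the
    normal attention kernels with no extra overhead; a near-zero positive
    value exercises the skip-softmax path while matching normal attention
    numerically, which is the pattern the updated tests rely on.
    """
    threshold: Optional[float] = scale_factor if skips_softmax else None
    return {"skip_softmax_threshold_scale_factor": threshold}

A caller on the trtllm-gen backend could then forward it as, for example, trtllm_batch_decode_with_kv_cache_mla(..., **skip_softmax_kwargs(True)).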

🔍 Related Issues

#2306

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added optional skip_softmax_threshold_scale_factor and skips_softmax controls to ragged/paged attention and batch-decode flows to enable conditional softmax skipping.
  • Documentation

    • Documented the new parameter, its formula, accuracy/performance tradeoffs, and backend-specific validation rules.
  • Chores

    • Threaded parameter through public APIs and updated signatures; added runtime validation for incompatible backends.
  • Tests

    • Expanded attention tests to cover skip-softmax scenarios across DeepSeek, paged, and MLA paths.

coderabbitai bot commented Feb 12, 2026

📝 Walkthrough

Adds optional skip-softmax controls (skip_softmax_threshold_scale_factor, skips_softmax) and threads them from the Python APIs (prefill/MLA) through the TRTLLM wrappers into the FMHA kernel launcher/runner; tests are updated to exercise skip and non-skip cases and backend validation.

Changes

  • FMHA Launcher Controls (csrc/trtllm_fmha_kernel_launcher.cu)
    Added parameters float skip_softmax_threshold_scale_factor and bool skips_softmax to the trtllm_ragged_attention_launcher() and trtllm_ragged_attention() signatures; computes and propagates skip_softmax_threshold_scale_factor_value and skips_softmax, assigning them to the runner params (mSkipsSoftmaxWhenPossible, mSkipSoftmaxThresholdScaleFactor).
  • MLA Decode Integration (flashinfer/mla.py)
    Extended the trtllm_batch_decode_with_kv_cache_mla() signature to accept skip_softmax_threshold_scale_factor: Optional[float], added a docstring and backend validation (rejected for XQA and sparse MLA), and passed the value into the trtllm-gen backend path.
  • Prefill / Paged Attention (flashinfer/prefill.py)
    Added optional skip_softmax_threshold_scale_factor to trtllm_ragged_attention_deepseek() and the internal paged_run() wrapper, forwarding it into the paged attention/kernel call chain.
  • Tests: DeepSeek / Paged (tests/attention/test_trtllm_gen_attention.py)
    Parametrized tests with skips_softmax, computing and passing skip_softmax_threshold_scale_factor (near-zero vs. None) into the prefill/DeepSeek and other wrapper calls; adjusted expectations and tolerances where applicable.
  • Tests: MLA Decode (tests/attention/test_trtllm_gen_mla.py)
    Added a skips_softmax parameter to helpers and tests, derived skip_softmax_threshold_scale_factor and passed it into trtllm_batch_decode_with_kv_cache_mla(), and enforced backend compatibility for skipping behavior (see the sketch after this list).
  • Build (CMakeLists.txt)
    Minor edits to reflect changed/added source exports (small lines added/removed).
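As a rough sketch of the MLA test-side pattern summarized above (illustrative only; the real helper takes many more arguments, which are elided here):

import pytest

def run_mla_decode_case(backend: str, skips_softmax: bool):
    # Skip-softmax is only wired up for the trtllm-gen backend in this PR.
    if skips_softmax and backend != "trtllm-gen":
        pytest.skip("skip-softmax requires the trtllm-gen backend")
    # A near-zero threshold exercises the skip-softmax kernels while keeping
    # the output numerically identical to normal attention.
    skip_softmax_threshold_scale_factor = 1e-9 if skips_softmax else None
    # ... build q / kv_cache / block tables, then call (arguments elided):
    # trtllm_batch_decode_with_kv_cache_mla(
    #     ...,
    #     skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor,
    # )
    return skip_softmax_threshold_scale_factor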

Sequence Diagram(s)

sequenceDiagram
  participant Test as "Test"
  participant PythonAPI as "Python API\n(prefill / MLA)"
  participant TRTWrapper as "TRTLLM Wrapper\n(paged / ragged)"
  participant Launcher as "FMHA Launcher\n(csrc/..._launcher.cu)"
  participant Runner as "FMHA Runner / GPU"

  Test->>PythonAPI: call with skips_softmax / skip_softmax_threshold_scale_factor
  PythonAPI->>TRTWrapper: forward skip_softmax_threshold_scale_factor
  TRTWrapper->>Launcher: call trtllm_ragged_attention(..., skip_softmax_threshold_scale_factor, skips_softmax, ...)
  Launcher->>Runner: set runner_params.mSkipsSoftmaxWhenPossible / mSkipSoftmaxThresholdScaleFactor
  Runner->>Runner: select kernel path (skip softmax or full)
  Runner-->>Launcher: results
  Launcher-->>TRTWrapper: results
  TRTWrapper-->>PythonAPI: results
  PythonAPI-->>Test: return outputs

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

Suggested labels

run-ci

Suggested reviewers

  • nvmbreughe
  • cyx-6
  • aleozlx
  • djmmoss
  • jiahanc
  • yzh119
  • bkryu

Poem

🐰 I tunneled through code with a tiny cheer,

nudged thresholds close so softmax may veer.
From Python to CUDA the values hop,
kernels decide — skip or full stop.
A little rabbit clap for faster ops! ✨

🚥 Pre-merge checks | ✅ 2 | ❌ 2
❌ Failed checks (2 warnings)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 23.08%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions that are missing them to satisfy the coverage threshold.
  • Merge Conflict Detection (⚠️ Warning): ❌ Merge conflicts detected (23 files):

⚔️ benchmarks/routines/flashinfer_benchmark_utils.py (content)
⚔️ benchmarks/routines/gemm.py (content)
⚔️ csrc/flashinfer_sampling_binding.cu (content)
⚔️ csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu (content)
⚔️ csrc/sampling.cu (content)
⚔️ csrc/trtllm_fmha_kernel_launcher.cu (content)
⚔️ flashinfer/__init__.py (content)
⚔️ flashinfer/aot.py (content)
⚔️ flashinfer/gemm/__init__.py (content)
⚔️ flashinfer/gemm/gemm_base.py (content)
⚔️ flashinfer/jit/gemm/__init__.py (content)
⚔️ flashinfer/jit/gemm/core.py (content)
⚔️ flashinfer/mla.py (content)
⚔️ flashinfer/prefill.py (content)
⚔️ flashinfer/sampling.py (content)
⚔️ flashinfer/triton/__init__.py (content)
⚔️ flashinfer/utils.py (content)
⚔️ include/flashinfer/sampling.cuh (content)
⚔️ scripts/task_run_unit_tests.sh (content)
⚔️ scripts/test_utils.sh (content)
⚔️ tests/attention/test_trtllm_gen_attention.py (content)
⚔️ tests/attention/test_trtllm_gen_mla.py (content)
⚔️ tests/gemm/test_bmm_fp8.py (content)

These conflicts must be resolved before merging into main.
Resolve conflicts locally and push changes to this branch.
✅ Passed checks (2 passed)
  • Title check (✅ Passed): The title clearly describes the main feature being added: enabling TRTLLM-Gen Skip-Softmax attention support for MLA, which aligns with the primary changes across multiple files.
  • Description check (✅ Passed): The description provides context about being a follow-up to #2477, mentions the MLA expansion, explains the optimization strategy, and indicates tests are updated. All required checklist items are marked complete.

No actionable comments were generated in the recent review. 🎉

🧹 Recent nitpick comments
csrc/trtllm_fmha_kernel_launcher.cu (1)

302-305: Consider extracting the repeated skip-softmax derivation into a small helper.

The same pattern—value_or(0.0f) followed by != 0.0f—appears in three places (trtllm_paged_attention_decode, trtllm_paged_attention_context, trtllm_ragged_attention). A tiny inline helper or a std::pair-returning function would reduce duplication.

♻️ Example helper
// At namespace scope or as a file-local helper:
inline std::pair<float, bool> resolve_skip_softmax(Optional<float> threshold) {
  float value = threshold.value_or(0.0f);
  return {value, value != 0.0f};
}

Then at each call site:

-  float const skip_softmax_threshold_scale_factor_value =
-      skip_softmax_threshold_scale_factor.value_or(0.0f);
-  bool const skips_softmax = skip_softmax_threshold_scale_factor_value != 0.0f;
+  auto [skip_softmax_threshold_scale_factor_value, skips_softmax] =
+      resolve_skip_softmax(skip_softmax_threshold_scale_factor);

Also applies to: 392-395, 570-574

tests/attention/test_trtllm_gen_attention.py (1)

1134-1134: Consider reducing the skips_softmax parametrization for the largest test matrices.

Adding skips_softmax=[False, True] doubles the already-large combinatorial space of test_trtllm_batch_decode (17 batch configs × 2 backends × 10 dtype combos × …). The long-sequence test wisely restricts to [False] only (line 1339). Consider a similar approach here — perhaps testing skips_softmax=True only on a representative subset of dtype/batch configs rather than the full cross product, to keep CI time manageable.
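One way such a restriction could look, as an illustrative pytest sketch with placeholder names rather than a proposal for the actual test file:

import pytest

# Placeholder parametrizations standing in for the real dtype/batch configs.
ALL_DTYPE_COMBOS = ["bf16", "fp8_q_bf16_kv", "fp8"]
REPRESENTATIVE_COMBO = "bf16"

@pytest.mark.parametrize(
    "dtype_combo,skips_softmax",
    [(combo, False) for combo in ALL_DTYPE_COMBOS]
    + [(REPRESENTATIVE_COMBO, True)],
)
def test_batch_decode_skip_softmax_subset(dtype_combo, skips_softmax):
    # Only one representative config runs with skips_softmax=True, so the
    # matrix grows by a single case instead of doubling.
    threshold = 1e-9 if skips_softmax else None
    assert threshold is None or threshold > 0.0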



DomBrown marked this pull request as ready for review on February 12, 2026, 17:36.
gemini-code-assist bot commented

Summary of Changes

Hello @DomBrown, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the performance capabilities of the attention mechanisms by integrating and expanding support for TRTLLM-Gen's skip-softmax feature. By introducing a configurable threshold, the system can now intelligently bypass certain softmax computations, leading to more efficient processing for both MLA and DeepSeek attention models. This change provides users with greater control over performance-accuracy trade-offs in their attention operations.

Highlights

  • Skip-Softmax Attention for MLA: Enabled TRTLLM-Gen Skip-Softmax attention for the MLA (Multi-head Latent Attention) backend, allowing for performance improvements by conditionally skipping softmax calculations.
  • Configurable Skip-Softmax Threshold: Introduced an optional skip_softmax_threshold_scale_factor parameter to control the sparsity of softmax operations in both batch decode with KV cache and ragged attention operations.
  • DeepSeek Attention Support: Extended the DeepSeek attention path to support configurable skip-softmax functionality.
  • Comprehensive Testing: Added and updated attention tests to validate the new softmax skipping functionality across both DeepSeek and MLA backends, ensuring correctness and compatibility.


Changelog
  • csrc/trtllm_fmha_kernel_launcher.cu
    • Added skip_softmax_threshold_scale_factor and skips_softmax parameters to the trtllm_ragged_attention_launcher function signature.
    • Passed the new skip-softmax parameters to the internal runner parameters for attention execution.
    • Modified the trtllm_ragged_attention function to accept an optional skip_softmax_threshold_scale_factor and derive the skips_softmax boolean.
  • flashinfer/mla.py
    • Added skip_softmax_threshold_scale_factor as an optional parameter to the trtllm_batch_decode_with_kv_cache_mla function.
    • Updated the docstring for trtllm_batch_decode_with_kv_cache_mla to describe the new skip-softmax parameter and its implications.
    • Implemented validation to raise ValueError if skip-softmax is attempted with the XQA backend or sparse MLA.
  • flashinfer/prefill.py
    • Added skip_softmax_threshold_scale_factor as an optional parameter to the trtllm_ragged_attention_deepseek function.
    • Updated the docstring for trtllm_ragged_attention_deepseek to explain the new skip-softmax parameter.
    • Passed the skip_softmax_threshold_scale_factor to the underlying trtllm_ragged_attention launcher call.
  • tests/attention/test_trtllm_gen_attention.py
    • Added a skips_softmax parameter to the test_trtllm_gen_prefill_deepseek and test_trtllm_gen_prefill_deepseek_bs1 test functions.
    • Conditionally set skip_softmax_threshold_scale_factor based on the skips_softmax parameter for testing purposes.
    • Passed the skip_softmax_threshold_scale_factor to the flashinfer.prefill.trtllm_ragged_attention_deepseek call within the tests.
  • tests/attention/test_trtllm_gen_mla.py
    • Added a skips_softmax parameter to the trtllm_batch_decode_mla helper function and its associated test parametrizations.
    • Introduced a pytest.skip condition if skips_softmax is enabled for backends other than trtllm-gen.
    • Conditionally set skip_softmax_threshold_scale_factor based on the skips_softmax parameter for testing.
    • Passed the skip_softmax_threshold_scale_factor to the flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla call within the tests.
Activity
  • The pull request was created to expand skip-softmax support to MLA, following up on a previous PR (feat: Add TRTLLM-Gen Skip-Softmax kernels for prefill and decode #2477).
  • The changes introduce new parameters and logic to enable and control skip-softmax functionality in attention operations.
  • New tests were added to validate the correct behavior of the skip-softmax feature across different attention backends.

coderabbitai bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/mla.py (1)

521-698: ⚠️ Potential issue | 🟡 Minor

Validate skip_softmax_threshold_scale_factor for non‑TRTLLM‑GEN backends.

When backend resolves to "xqa", the new parameter is silently ignored. Consider rejecting non‑None values unless backend == "trtllm-gen" to avoid misleading callers.

🛡️ Proposed guard
     if backend == "xqa":
+        if skip_softmax_threshold_scale_factor is not None:
+            raise ValueError(
+                "skip_softmax_threshold_scale_factor is only supported for trtllm-gen backend"
+            )
         if (
             get_compute_capability(query.device)[0] != 12
             or query.dtype != torch.float8_e4m3fn
             or kv_cache.dtype != torch.float8_e4m3fn
         ):

gemini-code-assist bot left a comment


Code Review

This pull request introduces support for skip-softmax attention for the MLA and DeepSeek paths by adding a new skip_softmax_threshold_scale_factor parameter. The changes are well-structured, propagating the new parameter through both the Python and C++ layers correctly. The tests have also been updated to validate the new functionality by checking that a near-zero threshold yields results consistent with the standard attention mechanism. My main feedback is to correct an invalid link in the docstrings that references a non-existent paper.

bkryu (Collaborator) commented Feb 12, 2026

/bot run

flashinfer-bot (Collaborator) commented

GitLab MR !314 has been created, and the CI pipeline #43909043 is currently running. I'll report back once the pipeline job completes.

flashinfer-bot (Collaborator) commented

[FAILED] Pipeline #43909043: 15/20 passed

yzh119 merged commit 11537c7 into flashinfer-ai:main on Feb 19, 2026
20 checks passed
