Skip to content

Ameyn/gdn decode cutedsl kernel#2498

Merged
yzh119 merged 12 commits intoflashinfer-ai:mainfrom
ameynaik-hub:ameyn/gdn-decode-cutedsl-kernel
Feb 17, 2026
Merged

Ameyn/gdn decode cutedsl kernel#2498
yzh119 merged 12 commits intoflashinfer-ai:mainfrom
ameynaik-hub:ameyn/gdn-decode-cutedsl-kernel

Conversation

@ameynaik-hub
Copy link
Contributor

@ameynaik-hub ameynaik-hub commented Feb 5, 2026

Implements high-performance Gated Delta Rule linear attention kernel supporting fixed sequence lengths T=1, T=2, T=3, T=4 using NVIDIA CuTe-DSL.

Key features:

  • H state layout: K-last [B, HV, V, K] where K is the contiguous (fastest) dimension
  • L2-normalized Q/K with configurable scale
  • Gated exponential decay via softplus
  • Delta rule updates: v_delta = beta * (v - pred)
  • Async H memory loading with aggressive pipelining
  • BF16 tensors with FP32 compute

Also includes:

  • benchmark_gated_delta_rule.py: Simple benchmark script for measuring kernel perf
  • Updated init.py exports

📌 Description

Implements high-performance Gated Delta Rule linear attention kernel supporting fixed sequence lengths T=1, T=2, T=3, T=4 using NVIDIA CuTe-DSL.

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Summary by CodeRabbit

  • New Features

    • Added CUDA-accelerated Gated Delta Rule kernel with optimized paths for sequence lengths 1–4 and new public API entries gated_delta_rule and GatedDeltaRuleKernel (exported when available).
  • Benchmarks

    • New benchmark for the gated_delta_rule kernel and integration of an improved CuTe-DSL variant into the benchmark suite and CLI, with per‑T summaries and performance reporting.
  • Tests

    • End-to-end tests exercising the improved kernel for T=1–4, availability guards, and updated reference tests to support explicit state storage dtype handling.

Implements high-performance Gated Delta Rule linear attention kernel supporting
fixed sequence lengths T=1, T=2, T=3, T=4 using NVIDIA CuTe-DSL.

Key features:
- H state layout: K-last [B, HV, V, K] where K is the contiguous (fastest) dimension
- Unified kernel architecture: T=2/3/4 share a single compile-time specialized
  kernel via Constexpr dispatch; T=1 uses separate kernel with persistent K optimization
- L2-normalized Q/K with configurable scale
- Gated exponential decay via softplus
- Delta rule updates: v_delta = beta * (v - pred)
- Bank-conflict-free cross-warp reductions
- Async H memory loading with aggressive pipelining
- BF16 tensors with FP32 compute for numerical stability
- GQA (grouped-query attention) support

Also includes:
- benchmark_gated_delta_rule.py: Simple benchmark script for measuring kernel perf
- Updated __init__.py exports

Signed-off-by: Amey Naik <[email protected]>
@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @ameynaik-hub, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a highly optimized CUDA kernel for the Gated Delta Rule linear attention mechanism, specifically tailored for decode-phase inference. The implementation, built with NVIDIA CuTe-DSL, provides high performance for fixed sequence lengths of 1, 2, 3, and 4. It incorporates advanced optimizations such as a K-last H state layout, L2-normalized Q/K, gated exponential decay, delta rule updates, and aggressive asynchronous memory pipelining, all while maintaining numerical stability with BF16 tensors and FP32 compute. A new benchmark script is also included to evaluate the kernel's performance.

Highlights

  • New Gated Delta Rule Kernel: Implements a high-performance Gated Delta Rule linear attention kernel for decode-phase inference.
  • Sequence Length Support: Supports fixed sequence lengths T=1, T=2, T=3, T=4 with specialized optimizations.
  • CuTe-DSL Implementation: Leverages NVIDIA CuTe-DSL for optimized CUDA kernel development, ensuring high performance.
  • Optimized H State Layout: Uses a K-last H state layout [B, HV, V, K] where K is the contiguous (fastest) dimension.
  • Q/K Normalization & Gating: Features L2-normalized Q/K with configurable scale and gated exponential decay via softplus.
  • Delta Rule Updates: Incorporates delta rule updates: v_delta = beta * (v - pred).
  • Memory & Pipelining: Utilizes asynchronous H memory loading with aggressive pipelining for improved efficiency.
  • Mixed Precision Compute: Employs BF16 tensors with FP32 compute for numerical stability.
  • Benchmarking Script: Includes a new benchmark script (benchmark_gated_delta_rule.py) for measuring kernel performance.
  • API Export: Exports the new gated_delta_rule function and GatedDeltaRuleKernel class in flashinfer/cute_dsl/__init__.py.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • flashinfer/cute_dsl/init.py
    • Added exports for gated_delta_rule and GatedDeltaRuleKernel to make the new functionality accessible.
  • flashinfer/cute_dsl/benchmark_gated_delta_rule.py
    • Introduced a new Python script for benchmarking the performance of the Gated Delta Rule CuTe-DSL kernel across various batch sizes and sequence lengths.
  • flashinfer/cute_dsl/gated_delta_rule.py
    • Implemented the core Gated Delta Rule linear attention kernel.
    • Includes shared helper functions for common operations.
    • Features a specialized kernel for seqlen=1 with persistent K optimization and aggressive pipelining.
    • Provides a unified kernel for seqlen=2, 3, 4 using compile-time Constexpr specialization.
Activity
  • The author has completed all pre-commit checks, ensuring code quality and style.
  • Tests have been added or updated, and all are passing, indicating the new functionality is working as expected and existing functionality is not regressed.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Feb 5, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Adds a CuTe-DSL Gated Delta Rule CUDA kernel (seq_len 1–4) with a Python launcher/class, a standalone benchmark, benchmark integration, reference dtype/state handling updates, and tests exercising the new kernel and BF16 state paths.

Changes

Cohort / File(s) Summary
Public API Exports
flashinfer/cute_dsl/__init__.py
Export gated_delta_rule and GatedDeltaRuleKernel under the CuTe-DSL availability guard and update __all__.
Gated Delta Rule Core Implementation
flashinfer/cute_dsl/gated_delta_rule.py
New CuTe-DSL implementation: compile/select kernels for T∈{1,2,3,4}, BF16-storage/FP32-compute, GQA/per-token gating, async loads, bank-conflict-free reductions, kernel caching, Python API gated_delta_rule(...), GatedDeltaRuleKernel, and multiple launch entry points.
Kernel Benchmark
flashinfer/cute_dsl/benchmark_gated_delta_rule.py
New benchmarking script: L2 cache handling, CUDA-event timing, warmup, input synthesis, multi-config runs and formatted reporting.
Benchmark Integration
benchmarks/bench_gdn_decode.py
Integrates “Improved CuTe-DSL” path: wrapper and bench runner, CLI --version improved_cutedsl, T=1..4 support, availability guards, and updated result tables / speedup metrics.
Reference Implementations / dtype handling
tests/gdn/reference_delta_rule.py
Introduce state_dtype parameter; perform compute in FP32 and store states in state_dtype; update signatures, casts, and docstrings across delta/GDN helpers.
Tests for Improved CuTe-DSL
tests/gdn/test_decode_delta_rule.py
Add import/availability flag for improved CuTe-DSL, new test helper and parametrized tests for T=1..4 using BF16 state, conditional execution when kernel unavailable, and test harness updates.

Sequence Diagram(s)

sequenceDiagram
    participant PythonAPI as Python API
    participant KernelCache as Kernel Cache
    participant CuTe as CuTe Compiler
    participant GPU as GPU Kernel
    participant Memory as GPU Memory

    PythonAPI->>PythonAPI: validate inputs (q,k,v,A_log,a,dt_bias,seq_len)
    PythonAPI->>KernelCache: lookup compiled kernel (T, dtypes)
    alt cached
        KernelCache-->>PythonAPI: return cached kernel
    else not cached
        PythonAPI->>CuTe: compile kernel for (T, dtypes)
        CuTe-->>KernelCache: store compiled kernel
        KernelCache-->>PythonAPI: return compiled kernel
    end
    PythonAPI->>GPU: launch kernel on CUDA stream with args
    GPU->>Memory: async load Q,K,V,gates,state blocks
    GPU->>GPU: l2-normalize, gated decay, delta-rule updates, reductions
    GPU->>Memory: write output [B,T,HV,V] and updated state
    GPU-->>PythonAPI: signal completion / return tensor
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested labels

v0.6.2

Suggested reviewers

  • yzh119
  • cyx-6
  • nvmbreughe
  • bkryu
  • jimmyzho
  • jiahanc

Poem

🐇 I hop through kernels, compile with cheer,

Four seq-paths lined up, BF16 held dear.
Async loads whisper, reductions take flight,
Gates and deltas hum through CUDA night.
A tiny rabbit's clap — benchmarks gleam bright.

🚥 Pre-merge checks | ✅ 1 | ❌ 2
❌ Failed checks (1 warning, 1 inconclusive)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 73.08% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
Title check ❓ Inconclusive The title is vague and lacks specificity about the main change; it uses abbreviations (Ameyn, gdn, cutedsl) without context and does not clearly convey the primary objective. Clarify the title to describe the main change in plain language, e.g., 'Add high-performance Gated Delta Rule kernel for linear attention decoding' or similar.
✅ Passed checks (1 passed)
Check name Status Explanation
Description check ✅ Passed The description includes relevant technical details about the kernel implementation and marks pre-commit checks and tests as complete, meeting most template requirements.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a high-performance Gated Delta Rule linear attention kernel using CuTe-DSL, supporting sequence lengths from 1 to 4. The implementation includes separate, optimized kernels for T=1 and a unified kernel for T=2,3,4, along with a benchmark script. My review focuses on the new kernel implementation in gated_delta_rule.py. I've identified a few areas for improvement: the seqlen=1 kernel contains significant code duplication that could be refactored for better maintainability. The kernel caching strategy could lead to performance issues with dynamic batch sizes due to unnecessary recompilations. Finally, the exported GatedDeltaRuleKernel class appears to be incomplete or unused. Addressing these points will improve the robustness and performance of the new kernel.

Comment on lines +1647 to +1679
class GatedDeltaRuleKernel:
"""
Gated Delta Rule Kernel for linear attention decode.

This kernel implements the Gated Delta Rule mechanism supporting sequence
lengths T=1, T=2, T=3, T=4 with optimized CUDA implementations.

Key features:
- T=1: Persistent K in registers with aggressive pipelining
- T=2/3/4: Unified kernel with compile-time Constexpr specialization
- L2-normalized Q/K with configurable scale
- Gated exponential decay via softplus
- Bank-conflict-free cross-warp reductions
- Async H memory loading

Args:
seq_len: Sequence length (1, 2, 3, or 4)
"""

def __init__(self, seq_len: int):
assert seq_len in [1, 2, 3, 4], f"Supported seq_len: 1,2,3,4, got {seq_len}"
self.seq_len = seq_len
self._compiled_kernel = None

def _get_launch_fn(self):
if self.seq_len == 1:
return gated_delta_rule_launch_seqlen1
elif self.seq_len == 2:
return gated_delta_rule_launch_seqlen2
elif self.seq_len == 3:
return gated_delta_rule_launch_seqlen3
else:
return gated_delta_rule_launch_seqlen4
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The GatedDeltaRuleKernel class is defined and exported as part of the public API, but it appears to be incomplete and is not used by the main gated_delta_rule function. The _compiled_kernel member is initialized to None and is never assigned a compiled kernel, and there is no method to execute it. The functional entry point gated_delta_rule implements its own caching and launch logic.

If this class is intended for future use, it should be fully implemented. If it is obsolete or a work-in-progress, it should either be completed or removed to avoid confusing users of the library.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

🤖 Fix all issues with AI agents
In `@flashinfer/cute_dsl/gated_delta_rule.py`:
- Around line 1773-1774: The cache_key currently set to (T, B) is too narrow and
can cause incorrect kernel reuse; update the cache key creation (variable
cache_key) to include the relevant tensor dimensions such as H, HV, K, and V
(the head count and per-head dims) so it uniquely identifies tensor shapes used
by the kernel—derive these from the involved tensors (the q/k/v/proj shapes or
whatever variables represent heads and head sizes) and include them alongside T
and B when building cache_key to prevent shape-mismatch cache hits.
- Around line 1666-1680: GatedDeltaRuleKernel currently only stores seq_len and
_compiled_kernel and provides _get_launch_fn but is never used by
gated_delta_rule or _compiled_kernels; either finish the class by adding a
public execution API (e.g., an __call__ or execute method that compiles/looks up
the kernel into _compiled_kernel, uses _get_launch_fn to obtain the launch
function and runs it with the same signature as the module-level implementation)
or remove the class from the public surface and migrate any kernel-caching logic
into the existing module-level _compiled_kernels paths; reference the class name
GatedDeltaRuleKernel, its attribute _compiled_kernel, method _get_launch_fn, and
the public gated_delta_rule/_compiled_kernels cache when making the change.
- Around line 1746-1749: The code accesses q, k, v, b and initial_state_source
without null checks even though their annotations allow None; add explicit
validation at the start of the function (e.g., in the function containing the
lines that reference q.shape and v.shape) to raise a clear ValueError if any of
q, k, v, b, or initial_state_source is None (or alternatively update the
function signature to remove Optional typing for these parameters), and ensure
any downstream use of from_dlpack(initial_state_source, ...) only occurs after
confirming initial_state_source is not None; reference the variables q, k, v, b
and initial_state_source and the shapes accessed (q.shape, v.shape[2],
v.shape[3]) so the checks are placed before those accesses.
- Around line 1153-1165: Unconditional shared-memory allocations q_sh2, k_sh2,
q_sh3, k_sh3 and the extra v_sh* buffers are always created even when NUM_TOKENS
is smaller; wrap those smem.allocate_tensor(...) calls in compile-time guards so
unused buffers are eliminated. Specifically, change the allocations for
q_sh2/k_sh2 to be inside an if constexpr (NUM_TOKENS >= 3) block and the
allocations for q_sh3/k_sh3 to be inside an if constexpr (NUM_TOKENS == 4) block
(and similarly guard v_sh2/v_sh3 as appropriate), keeping the same
smem.allocate_tensor(cutlass.Float32, 128) calls and names so later code (gate
computations) still references the same identifiers when compiled in. Ensure the
guards use the compile-time NUM_TOKENS parameter so the compiler can drop unused
allocations.
🧹 Nitpick comments (5)
flashinfer/cute_dsl/gated_delta_rule.py (3)

118-136: Potential numerical stability concern in compute_single_gate.

The softplus implementation handles large positive values but not large negative values of beta_x. When beta_x is very negative, cute.math.exp(beta_x) approaches zero, which is fine numerically. However, the sigmoid computation at line 134 could have issues with large positive beta_raw values where cute.math.exp(-beta_raw) underflows to 0, resulting in beta = 1.0 (acceptable), but large negative beta_raw causes exp(-beta_raw) to overflow.

Consider adding a symmetric threshold check for the sigmoid similar to the softplus handling.

💡 Optional: Add numerical safeguard for sigmoid
     g = -cute.math.exp(A_log_val) * softplus_x
     g_exp = cute.math.exp(g)
-    beta = cutlass.Float32(1.0) / (cutlass.Float32(1.0) + cute.math.exp(-beta_raw))
+    # Numerically stable sigmoid
+    if beta_raw >= cutlass.Float32(0.0):
+        beta = cutlass.Float32(1.0) / (cutlass.Float32(1.0) + cute.math.exp(-beta_raw))
+    else:
+        exp_beta = cute.math.exp(beta_raw)
+        beta = exp_beta / (cutlass.Float32(1.0) + exp_beta)
     return g_exp, beta

306-354: Consider removing unused o_head parameter.

The o_head parameter is passed but never used in process_first_token. The first token's output is returned and stored later during subsequent token processing. Removing it would clarify the function's contract.

♻️ Remove unused parameter
 `@cute.jit`
 def process_first_token(
     h_sh_chunk_curr,
     h_chunk,
     kq_chunk,
     k_sh,
     q_sh,
     v_sh,
     reduce_sh,
-    o_head,
     g_exp,
     beta,
     v_offset,
     pred_slot,
     warp_idx,
     lane_idx,
     k_base,
 ):

Note: This would require updating all call sites in process_vchunk_unified_234.


1746-1746: Unpacked variable H is unused.

The variable H is extracted from q.shape but never used in the function. Consider prefixing with underscore to indicate intentional non-use.

♻️ Prefix unused variable
-    B, T, H, K = q.shape
+    B, T, _H, K = q.shape
flashinfer/cute_dsl/benchmark_gated_delta_rule.py (2)

82-82: Import path may fail depending on execution context.

The relative import from gated_delta_rule import gated_delta_rule assumes the script is run from within the flashinfer/cute_dsl/ directory or that directory is in sys.path. This will fail if run from the repository root or via python -m.

Consider using absolute import for robustness.

♻️ Use absolute import
 def main():
-    from gated_delta_rule import gated_delta_rule
+    from flashinfer.cute_dsl.gated_delta_rule import gated_delta_rule

116-127: Loop variable capture in closure - latent bug risk.

The run_kernel function captures inputs and state by reference, not by value. While this works correctly here because the closure is consumed immediately via benchmark(), it's a latent bug that would manifest if the code were refactored (e.g., collecting closures to run later).

♻️ Bind loop variables explicitly
-            def run_kernel():
+            def run_kernel(inputs=inputs, state=state):
                 return gated_delta_rule(
                     A_log=inputs["A_log"],
                     a=inputs["a"],
                     dt_bias=inputs["dt_bias"],
                     q=inputs["q"],
                     k=inputs["k"],
                     v=inputs["v"],
                     b=inputs["b"],
                     initial_state_source=state,
                     scale=inputs["scale"],
                 )

Comment on lines +1153 to +1165
# Q/K buffers for token 2 (only for NUM_TOKENS >= 3)
q_sh2 = smem.allocate_tensor(cutlass.Float32, 128)
k_sh2 = smem.allocate_tensor(cutlass.Float32, 128)

# Q/K buffers for token 3 (only for NUM_TOKENS=4)
q_sh3 = smem.allocate_tensor(cutlass.Float32, 128)
k_sh3 = smem.allocate_tensor(cutlass.Float32, 128)

# V buffers
v_sh0 = smem.allocate_tensor(cutlass.Float32, 128)
v_sh1 = smem.allocate_tensor(cutlass.Float32, 128)
v_sh2 = smem.allocate_tensor(cutlass.Float32, 128)
v_sh3 = smem.allocate_tensor(cutlass.Float32, 128)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

cd /tmp && git ls-files flashinfer/cute_dsl/gated_delta_rule.py

Repository: flashinfer-ai/flashinfer

Length of output: 138


🏁 Script executed:

wc -l flashinfer/cute_dsl/gated_delta_rule.py

Repository: flashinfer-ai/flashinfer

Length of output: 112


🏁 Script executed:

sed -n '1140,1180p' flashinfer/cute_dsl/gated_delta_rule.py

Repository: flashinfer-ai/flashinfer

Length of output: 1718


🏁 Script executed:

# Search for NUM_TOKENS usage and conditional patterns around this section
sed -n '1100,1200p' flashinfer/cute_dsl/gated_delta_rule.py | head -100

Repository: flashinfer-ai/flashinfer

Length of output: 3977


🏁 Script executed:

# Look for if statements guarding buffer allocation
rg -n "if.*NUM_TOKENS" flashinfer/cute_dsl/gated_delta_rule.py | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 357


🌐 Web query:

CuTe DSL Constexpr shared memory allocation dead code elimination

💡 Result:

In CuTe / CuTeDSL (CUTLASS), whether “shared-memory allocation” and the associated copies/loops disappear is mostly a question of whether the compiler can prove things at compile time (i.e., via types and constexpr / if constexpr) and whether the backend can safely apply dead-store / dead-code elimination across the generated code.

1) “constexpr shared memory allocation” in practice (CuTe/CUTLASS)

  • CuTe is designed so that tile shapes, layouts, and smem footprints are carried in the type system; when those are compile-time constants, the compiler can fully specialize code (unroll, constant-fold addresses, remove unused objects/paths).
  • Practically, that means: prefer CuTe’s compile-time control-flow (if constexpr, cute::static_for, type-based shapes/layouts) rather than runtime loops/conditions if you expect code (and sometimes even smem objects) to vanish.
  • CUTLASS’ own guidelines explicitly lean on if constexpr-based exhaustiveness and “unreachable” patterns to ensure unwanted branches don’t survive compilation (and to avoid “missing return” / unreachable-path issues) [1].

2) Why you can see “dead-looking” stores (local or shared) even at -O3

What you’re describing matches a real pattern people have observed with CuTeDSL-generated kernels: stores to .local (stack) that appear unused even though the values are consumed directly from registers in nearby FFMA instructions.

A recent CuTeDSL example shows exactly this: STL.128 stores to local memory immediately followed by FFMA using the same registers, with no obvious reloads in between [2]. A plausible cause suggested in-thread is conservative aliasing / overlap concerns (compiler can’t prove the store is irrelevant because it can’t fully prove what may alias what) [2].

Also, NVIDIA engineers have long emphasized an important split:

  • PTX is intermediate; optimizations (including dead-code elimination) may happen later in ptxas, and you should judge final behavior by SASS [3].
  • However, in your case you’re already looking at SASS and still seeing the stores, which usually means one of:
    • the backend is being conservative (aliasing, lifetime issues),
    • the store participates in some internal lowering pattern (e.g., preserving values across a region/edge the optimizer doesn’t fully reason about),
    • or it’s simply a backend missed-optimization / bug (these have been reported historically for redundant local loads/stores) [4].

3) Concrete knobs that usually decide whether it optimizes away

  • Make indices compile-time: if a loop induction variable controls fragment selection/addressing, switching to cute::static_for / unrolling often removes unexpected local traffic (your own observation aligns with this).
  • Prefer if constexpr for optional smem paths so unused code paths don’t exist in the instantiated kernel (CUTLASS recommends these patterns) [1].
  • Avoid debug/per-function compilation modes when judging DCE (e.g., -G, device debug) because they commonly inhibit optimization (this is a common explanation for “dead” locals in PTX/SASS discussions) [3].
  • If aliasing is plausible, try enforcing non-aliasing where applicable (CUDA C++: __restrict__ on raw pointers; with CuTe views it’s trickier, but the same principle applies). The forum reply explicitly raises aliasing as a reason the compiler may keep the stores [2].

If you want, paste a minimal CuTeDSL repro (kernel + compile flags + GPU arch + CUDA version) and I’ll point out exactly which parts prevent compile-time proving (and what to change so smem/lmem traffic is actually eliminated).

Sources: CUTLASS programming guidelines [1]; NVIDIA forum thread showing CuTeDSL dead-looking STL.128 in SASS and aliasing hypothesis [2]; NVIDIA forum explanation that PTX is intermediate and DCE happens in ptxas / check SASS and debug builds inhibit this [3]; historical report of redundant lmem loads/stores as compiler/backend artifact/bug [4].


Unconditional SMEM allocation for unused token buffers wastes memory for smaller batch sizes.

The Q/K/V shared memory buffers are allocated unconditionally at compile time for all four tokens, even when NUM_TOKENS=2 or 3. Comments indicate buffers are "only for NUM_TOKENS >= 3" and "only for NUM_TOKENS=4", but no compile-time conditionals guard these allocations. The subsequent gate computations at lines 1195+ use if NUM_TOKENS runtime checks, which cannot eliminate the earlier allocations. For NUM_TOKENS=2, this means allocating q_sh2, k_sh2, q_sh3, k_sh3, and partial v_sh buffers unnecessarily.

While this trade-off is acceptable for a unified kernel (avoiding code duplication and supporting multiple sequence lengths), the SMEM footprint could be reduced by wrapping allocations in if constexpr guards tied to the NUM_TOKENS parameter, allowing true compile-time elimination rather than runtime conditionals.

🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/gated_delta_rule.py` around lines 1153 - 1165,
Unconditional shared-memory allocations q_sh2, k_sh2, q_sh3, k_sh3 and the extra
v_sh* buffers are always created even when NUM_TOKENS is smaller; wrap those
smem.allocate_tensor(...) calls in compile-time guards so unused buffers are
eliminated. Specifically, change the allocations for q_sh2/k_sh2 to be inside an
if constexpr (NUM_TOKENS >= 3) block and the allocations for q_sh3/k_sh3 to be
inside an if constexpr (NUM_TOKENS == 4) block (and similarly guard v_sh2/v_sh3
as appropriate), keeping the same smem.allocate_tensor(cutlass.Float32, 128)
calls and names so later code (gate computations) still references the same
identifiers when compiled in. Ensure the guards use the compile-time NUM_TOKENS
parameter so the compiler can drop unused allocations.

@yzh119
Copy link
Collaborator

yzh119 commented Feb 5, 2026

@ameynaik-hub would you mind comparing with https://github.com/flashinfer-ai/flashinfer/blob/main/flashinfer/gdn_decode.py ? It's better to unify the file location and interface.

@ameynaik-hub
Copy link
Contributor Author

@yzh119 compared benchmark here #2493 (comment)

https://github.com/flashinfer-ai/flashinfer/blob/main/flashinfer/gdn_decode.py kernel supports
K-major and V-major h state layout. Also I believe it supports seqlen > 4.

on other hand, this kernel only support V-major (K contiguous , k -last dim) h state layout which is used for qwen models and fixed seqlen <=4 only.

Any suggestions on how to merge?

@vadiklyutiy
Copy link

@ameynaik-hub would you mind comparing with https://github.com/flashinfer-ai/flashinfer/blob/main/flashinfer/gdn_decode.py ? It's better to unify the file location and interface.

and test as well

@vadiklyutiy
Copy link

@yzh119 compared benchmark here #2493 (comment)

I think it's worth to copy results here

@vadiklyutiy
Copy link

New kernel is always better than old with pretranspose except batch=1 that we can excuse. Does it make sense to keep old pretranspose varian?

@aditya-narayan5
Copy link

aditya-narayan5 commented Feb 5, 2026

@ameynaik-hub I might have overlooked details during reading but does anyone have a source for this:

V-major (K contiguous , k -last dim) h state layout which is used for qwen models

@ameynaik-hub
Copy link
Contributor Author

ameynaik-hub commented Feb 5, 2026

New kernel is always better than old with pretranspose except batch=1 that we can excuse. Does it make sense to keep old pretranspose varian?

@vadiklyutiy since it has support for T>4 (seqlen) and also support for v-last dim of h I thought it can be kept as is?

@ameynaik-hub
Copy link
Contributor Author

@vadiklyutiy
Copy link

New kernel is always better than old with pretranspose except batch=1 that we can excuse. Does it make sense to keep old pretranspose varian?

@vadiklyutiy since it has support for T>4 (seqlen) and also support for v-last dim of h I thought it can be kept as is?

As I remember there were several different interfaces/functions. I meant that processes case seq_len=1, k-last.
Of course I have not tried to propose unique interfaces, just want to avoid duplication

@yzh119
Copy link
Collaborator

yzh119 commented Feb 5, 2026

@ameynaik-hub can you reuse the gdn decode interface and set your implementation as default when we use k-layout (V-major) layout and T <=4?

Copy link
Collaborator

@kahyunnam kahyunnam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ameynaik-hub this LGTM as the initial PR to just get the kernel into the Flashinfer codebase, approved. (still need a codeowner approval for this directory to merge ... maybe @yzh119 or @bkryu ). We can follow up with another PR for some work for the API integration + Flashinfer testing/benchmarking. Let me know if you have the bandwidth to work on this integration; if not, I can probably work on it next week.

@yzh119
Copy link
Collaborator

yzh119 commented Feb 5, 2026

Hi @kahyunnam , I encourage not to push more code to flashinfer/cute_dsl anymore as we plan to categorize modules by functionalities, not sources. Other PRs are starting to remove codes out of flashinfer.cute_dsl.

The required work to unify the existing gdn decode and this PR should be minimal, I can help if you need any assitance with it.



def benchmark(
func, num_iterations=100, n_warmup=10, flush_l2=True, use_dummy_matmul=True
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just curious why that wasn't addressed and pointed in #2370 but was here?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The benchmarking script in #2370 should use bench_gpu_time as well. I didn't notice that when reviewing that PR (it's my bad) and it was fixed in #2405.

Let's get it right in one shot in this PR.

else: # T == 4
launch_fn = gated_delta_rule_launch_seqlen4

_compiled_kernels[cache_key] = cute.compile(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please enable tvm-ffi with options="--enable-tvm-ffi",, otherwise the host side overhead of converting torch.tensor to dlpack would be non-neglibile.

Reference can be found at #2279

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks will do.

@ameynaik-hub
Copy link
Contributor Author

@ameynaik-hub can you reuse the gdn decode interface and set your implementation as default when we use k-layout (V-major) layout and T <=4?

@yzh119 yeah I can do that.

I noticed that my kernel utilizes bf16 inputs for the state h, which appears to be compatible with sglang and other models. However, it seems that the default state for h is fp32. I am investigating the possibility of optimizing this configuration for this specific case.

cc: @vadiklyutiy @kahyunnam

@yzh119
Copy link
Collaborator

yzh119 commented Feb 6, 2026

cc @guangyunh-nv (gdn prefill's author) do you have any insights on whether it's okay to use f16 state?

@guangyunh-nv
Copy link
Collaborator

FP16 state might not be a good choice due to the dynamic range limitation. The range is not guaranteed and purely depends on the model activation dynamics. So this purely depends on the model behavior.

Datapoint:
Previously with MiniMax M1 (lighting attention, similar to GDN but gating is not data dependent), the model only runs with BF16 due to its dynamical range grows exponentially somehow and then saturate.

BF16 maybe safer, but the accuracy would be a problem. If algorithm side decide that this is tolerable, then I think we should support it.

I think what we can do at the moment is to tolerate all kinds of state input, and learn from the production environment feedback :)

@xutizhou
Copy link
Contributor

xutizhou commented Feb 6, 2026

Have you tested the end-to-end accuracy with FP16 SSM state? Does the performance remain unchanged on common benchmarks such as MMLU, GSM8K, etc.?

@vadiklyutiy
Copy link

I think what we can do at the moment is to tolerate all kinds of state input, and learn from the production environment feedback :)

I don't think this is right approach. Imagine car manufacturer introduce new self driving car and say let learn from road environment. Don't think a lot of people would be happy.

@ameynaik-hub
Copy link
Contributor Author

ameynaik-hub commented Feb 11, 2026

cute dsl bf16 h perf on B200

batch T time(us)
1 1 3.62
1 2 5.86
1 3 6.94
1 4 7.79
4 1 4.9
4 2 6.62
4 3 7.65
4 4 8.64
8 1 7.04
8 2 8.67
8 3 9.95
8 4 11.39
16 1 9.74
16 2 12.9
16 3 15.15
16 4 17.54
32 1 16.06
32 2 21.57
32 3 25.79
32 4 28.61
64 1 27.01
64 2 34.88
64 3 40.67
64 4 49.76
128 1 48.9
128 2 60.7
128 3 72.26
128 4 89.04
256 1 91.66
256 2 112.54
256 3 134.88
256 4 166.32
512 1 177.06
512 2 214.53
512 3 258.75
512 4 320.8

@kahyunnam
Copy link
Collaborator

/bot run

@flashinfer-bot
Copy link
Collaborator

GitLab MR !309 has been created, and the CI pipeline #43822304 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Copy link
Collaborator

[FAILED] Pipeline #43822304: 11/20 passed

- Bank-conflict-free cross-warp reductions
- Async H memory loading with aggressive pipelining
- BF16 tensors with FP32 compute for numerical stability
- GQA (grouped-query attention) support with configurable H (query) and HV (value) heads
Copy link
Collaborator

@guangyunh-nv guangyunh-nv Feb 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GQA: HQ = <IntegerRatio> * HV
GVA: <IntegerRatio> * HQ = HV, I confirmed this naming in a group chat with GDN paper author :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@guangyunh-nv sorry do you mean it is called GVA and not GQA?

@aditya-narayan5
Copy link

Thanks to everyone for clarifying, this helped :) @vadiklyutiy @guangyunh-nv @ameynaik-hub

@guangyunh-nv
Copy link
Collaborator

Out of curisity, why the perf degenerate when batch goes from 1 to 4.

batch T time(us)
1 1 3.62
4 1 4.9

I'd assume you have not reached the parallelism limit and bandwidth limit, then the time should remain a constant.

@ameynaik-hub
Copy link
Contributor Author

Out of curisity, why the perf degenerate when batch goes from 1 to 4.

@guangyunh-nv for BS <=4 I use 4 ctas per head. so for BS=1 it has 324 = 128 CTAs and for BS=4 it has 324*4 = 512 CTAs
at this regime it is latency bound, MIO throttle appears to be the reason based on ncu reports.

@@ -137,6 +137,7 @@ def blockwise_linear_attention(
| torch.Tensor = 1.0, # float or tensor with num_elems == num_qo_heads
decay_exponent_offset=0,
kv_dtype: torch.dtype = torch.float32,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you should remove kv_dtype, it is named as kv_dtype because in some earlier papers call the state directly as KV. It is not dtype for K and V

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, can you please confirm if the change is okay? Thanks!

- Consolidate to single state_dtype parameter across all reference functions
  - Remove duplicate kv_dtype parameter from blockwise_linear_attention(),
    delta_rule(), and blockwise_delta_rule()
  - Update test_prefill_delta_rule.py to use state_dtype consistently

- Remove benchmark_gated_delta_rule.py from git tracking (keep locally)
  - Add to .gitignore for local development use only

Co-Authored-By: Claude Sonnet 4.5 <[email protected]>
Copy link
Collaborator

@yzh119 yzh119 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall LGTM, some minor nits.

@vadiklyutiy
Copy link

I would like to speed up this review a bit. The Qwen3.5 headline model has been released today.

ameynaik-hub and others added 2 commits February 16, 2026 11:28
Add validation for required tensor parameters to fail early with clear error
messages. Expand cache key to include all shape dimensions (H, HV, K, V) to
prevent incorrect kernel reuse when shapes change.

Co-Authored-By: Claude Sonnet 4.5 <[email protected]>
Signed-off-by: Amey Naik <[email protected]>
Relocate gated_delta_rule.py from flashinfer/cute_dsl/ to a new
flashinfer/gdn_kernels/ module to improve code organization and clarify
the kernel's domain-specific purpose.

Changes:
- Create flashinfer/gdn_kernels/ module for GDN-specific CuTe DSL kernels
- Rename gated_delta_rule.py to gdn_decode_bf16_state.py for clarity
  (indicates BF16 hidden state variant)
- Update all 3 import sites to use new path:
  - flashinfer/gdn_decode.py
  - benchmarks/bench_gdn_decode.py
  - tests/gdn/test_decode_delta_rule.py
- Add module __init__.py with proper re-exports
- Avoid namespace conflict with existing gdn_decode.py file

The flashinfer/cute_dsl/ directory remains for cross-cutting CuTe DSL
utilities (RMSNorm, FP4, GEMM+AllReduce, etc.).

All 32 GDN decode CuTe DSL tests pass successfully.

Co-Authored-By: Claude Sonnet 4.5 <[email protected]>
Signed-off-by: Amey Naik <[email protected]>
@ameynaik-hub
Copy link
Contributor Author

ameynaik-hub commented Feb 16, 2026

I have made some improvement for fp32 h state version of gdn_decode

BS,FI-PreTr (us),BS32opt-V2 (us),Speedup
1,3.74,3.52,1.06x
4,5.31,5.31,1.00x
8,7.68,7.58,1.01x
16,13.39,12.74,1.05x
32,22.99,22.72,1.01x
64,54.85,42.38,1.29x
128,92.21,81.39,1.13x
256,172.64,158.62,1.09x
512,336.56,313.69,1.07x

Zihao, I am thinking of creating a seperate PR for it. what do you think? @yzh119

Copy link
Collaborator

@yzh119 yzh119 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @ameynaik-hub thanks, I think we can create another PR for the f32 acceleration. This PR itself is in good shape.

@yzh119 yzh119 merged commit f69875a into flashinfer-ai:main Feb 17, 2026
19 checks passed
djmmoss added a commit to djmmoss/flashinfer that referenced this pull request Feb 17, 2026
Add KDA (Key-Driven Attention) decode support as a CuTe DSL kernel,
extending the GDN decode kernel from PR flashinfer-ai#2498 to support per-key-dimension
gating. KDA generalizes GDN's scalar gate (g in R^1) to per-K gating
(g in R^K), with the gate mapping naturally to the warp structure.

Changes:
- Extract shared gate-independent helpers from GDN kernel into
  flashinfer/gdn_kernels/_common.py (~290 lines), slimming
  gdn_decode_bf16_state.py. No GDN behavior change.
- Add HEAD_DIM=64 support to GDN dispatch (previously 128 only)
- Preserve lowBS_1chunk kernel variants for B<=4 (both GDN and KDA)
- New flashinfer/kda_kernels/ module with T=1-4 kernels for
  HEAD_DIM={64,128}, plus chunk_kda-compatible wrapper
- 80 KDA tests covering correctness, state updates, GDN reduction
- KDA decode benchmark

Tested on B200 (SM100) with CUDA 12.9. BF16 storage, FP32 compute.
GDN: 138/138 tests pass, no performance regression.
KDA: 80/80 tests pass.

AI-assisted (Claude Code)
@yzh119 yzh119 mentioned this pull request Feb 17, 2026
10 tasks
djmmoss added a commit to djmmoss/flashinfer that referenced this pull request Feb 26, 2026
Add KDA (Key-Driven Attention) decode support as a CuTe DSL kernel,
extending the GDN decode kernel from PR flashinfer-ai#2498 to support per-key-dimension
gating. KDA generalizes GDN's scalar gate (g in R^1) to per-K gating
(g in R^K), with the gate mapping naturally to the warp structure.

Changes:
- Extract shared gate-independent helpers from GDN kernel into
  flashinfer/gdn_kernels/_common.py (~290 lines), slimming
  gdn_decode_bf16_state.py. No GDN behavior change.
- Add HEAD_DIM=64 support to GDN dispatch (previously 128 only)
- Preserve lowBS_1chunk kernel variants for B<=4 (both GDN and KDA)
- New flashinfer/kda_kernels/ module with T=1-4 kernels for
  HEAD_DIM={64,128}, plus chunk_kda-compatible wrapper
- 80 KDA tests covering correctness, state updates, GDN reduction
- KDA decode benchmark

Tested on B200 (SM100) with CUDA 12.9. BF16 storage, FP32 compute.
GDN: 138/138 tests pass, no performance regression.
KDA: 80/80 tests pass.

AI-assisted (Claude Code)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

8 participants