Ameyn/gdn decode cutedsl kernel #2498
Conversation
Implements a high-performance Gated Delta Rule linear attention kernel supporting fixed sequence lengths T=1, T=2, T=3, T=4 using NVIDIA CuTe-DSL.

Key features:
- H state layout: K-last [B, HV, V, K], where K is the contiguous (fastest) dimension
- Unified kernel architecture: T=2/3/4 share a single compile-time-specialized kernel via Constexpr dispatch; T=1 uses a separate kernel with persistent-K optimization
- L2-normalized Q/K with configurable scale
- Gated exponential decay via softplus
- Delta rule updates: v_delta = beta * (v - pred)
- Bank-conflict-free cross-warp reductions
- Async H memory loading with aggressive pipelining
- BF16 tensors with FP32 compute for numerical stability
- GQA (grouped-query attention) support

Also includes:
- benchmark_gated_delta_rule.py: simple benchmark script for measuring kernel perf
- Updated __init__.py exports

Signed-off-by: Amey Naik <[email protected]>
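For reviewers unfamiliar with the algorithm, here is a minimal single-head sketch of one decode step, assuming my reading of the description above (L2-normalized Q/K, gated decay, delta-rule update v_delta = beta * (v - pred), K-last state h of shape [V, K]). The function name and shapes are illustrative, not the kernel's API:

```python
import numpy as np

def gated_delta_rule_reference(q, k, v, g, beta, h, scale=1.0):
    """Hypothetical single-head reference; h is [V, K] (K-last),
    g is the (already exponentiated) decay gate, beta the update gate."""
    # L2-normalize q and k, as the kernel description states
    q = q / max(np.linalg.norm(q), 1e-6)
    k = k / max(np.linalg.norm(k), 1e-6)
    h = h * g                      # gated exponential decay of the state
    pred = h @ k                   # state's prediction for this key, [V]
    v_delta = beta * (v - pred)    # delta-rule correction
    h = h + np.outer(v_delta, k)   # rank-1 state update
    o = (h @ q) * scale            # output, [V]
    return o, h
```

With g = 1 and beta = 1, the updated state reproduces v exactly along the normalized key direction, which makes a handy sanity check against the kernel's numerics.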
Summary of Changes

Hello @ameynaik-hub, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a highly optimized CUDA kernel for the Gated Delta Rule linear attention mechanism, specifically tailored for decode-phase inference. The implementation, built with NVIDIA CuTe-DSL, provides high performance for fixed sequence lengths of 1, 2, 3, and 4. It incorporates advanced optimizations such as a K-last H state layout, L2-normalized Q/K, gated exponential decay, delta rule updates, and aggressive asynchronous memory pipelining, all while maintaining numerical stability with BF16 tensors and FP32 compute. A new benchmark script is also included to evaluate the kernel's performance.
Activity
Note: Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the corresponding setting. Use the following commands to manage reviews:
📝 Walkthrough

Adds a CuTe-DSL Gated Delta Rule CUDA kernel (seq_len 1–4) with a Python launcher/class, a standalone benchmark, benchmark integration, reference dtype/state handling updates, and tests exercising the new kernel and BF16 state paths.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant PythonAPI as Python API
    participant KernelCache as Kernel Cache
    participant CuTe as CuTe Compiler
    participant GPU as GPU Kernel
    participant Memory as GPU Memory
    PythonAPI->>PythonAPI: validate inputs (q,k,v,A_log,a,dt_bias,seq_len)
    PythonAPI->>KernelCache: lookup compiled kernel (T, dtypes)
    alt cached
        KernelCache-->>PythonAPI: return cached kernel
    else not cached
        PythonAPI->>CuTe: compile kernel for (T, dtypes)
        CuTe-->>KernelCache: store compiled kernel
        KernelCache-->>PythonAPI: return compiled kernel
    end
    PythonAPI->>GPU: launch kernel on CUDA stream with args
    GPU->>Memory: async load Q,K,V,gates,state blocks
    GPU->>GPU: l2-normalize, gated decay, delta-rule updates, reductions
    GPU->>Memory: write output [B,T,HV,V] and updated state
    GPU-->>PythonAPI: signal completion / return tensor
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 1 passed | ❌ 2 failed (1 warning, 1 inconclusive)
Code Review
This pull request introduces a high-performance Gated Delta Rule linear attention kernel using CuTe-DSL, supporting sequence lengths from 1 to 4. The implementation includes separate, optimized kernels for T=1 and a unified kernel for T=2,3,4, along with a benchmark script. My review focuses on the new kernel implementation in gated_delta_rule.py. I've identified a few areas for improvement: the seqlen=1 kernel contains significant code duplication that could be refactored for better maintainability. The kernel caching strategy could lead to performance issues with dynamic batch sizes due to unnecessary recompilations. Finally, the exported GatedDeltaRuleKernel class appears to be incomplete or unused. Addressing these points will improve the robustness and performance of the new kernel.
```python
class GatedDeltaRuleKernel:
    """
    Gated Delta Rule Kernel for linear attention decode.

    This kernel implements the Gated Delta Rule mechanism supporting sequence
    lengths T=1, T=2, T=3, T=4 with optimized CUDA implementations.

    Key features:
    - T=1: Persistent K in registers with aggressive pipelining
    - T=2/3/4: Unified kernel with compile-time Constexpr specialization
    - L2-normalized Q/K with configurable scale
    - Gated exponential decay via softplus
    - Bank-conflict-free cross-warp reductions
    - Async H memory loading

    Args:
        seq_len: Sequence length (1, 2, 3, or 4)
    """

    def __init__(self, seq_len: int):
        assert seq_len in [1, 2, 3, 4], f"Supported seq_len: 1,2,3,4, got {seq_len}"
        self.seq_len = seq_len
        self._compiled_kernel = None

    def _get_launch_fn(self):
        if self.seq_len == 1:
            return gated_delta_rule_launch_seqlen1
        elif self.seq_len == 2:
            return gated_delta_rule_launch_seqlen2
        elif self.seq_len == 3:
            return gated_delta_rule_launch_seqlen3
        else:
            return gated_delta_rule_launch_seqlen4
```
The GatedDeltaRuleKernel class is defined and exported as part of the public API, but it appears to be incomplete and is not used by the main gated_delta_rule function. The _compiled_kernel member is initialized to None and is never assigned a compiled kernel, and there is no method to execute it. The functional entry point gated_delta_rule implements its own caching and launch logic.
If this class is intended for future use, it should be fully implemented. If it is obsolete or a work-in-progress, it should either be completed or removed to avoid confusing users of the library.
Actionable comments posted: 4
🤖 Fix all issues with AI agents
In `@flashinfer/cute_dsl/gated_delta_rule.py`:
- Around line 1773-1774: The cache_key currently set to (T, B) is too narrow and
can cause incorrect kernel reuse; update the cache key creation (variable
cache_key) to include the relevant tensor dimensions such as H, HV, K, and V
(the head count and per-head dims) so it uniquely identifies tensor shapes used
by the kernel—derive these from the involved tensors (the q/k/v/proj shapes or
whatever variables represent heads and head sizes) and include them alongside T
and B when building cache_key to prevent shape-mismatch cache hits.
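A sketch of what the suggested cache key might look like (the function name and cache dict are hypothetical; the real code derives shapes from its own tensors):

```python
_compiled_kernels = {}  # hypothetical module-level kernel cache

def make_cache_key(q, v, T):
    # q: [B, T, H, K]; v: [B, T, HV, V] (shapes per the PR description)
    B, _, H, K = q.shape
    HV, V = v.shape[2], v.shape[3]
    # Include every dimension the compiled kernel specializes on,
    # so a shape change can never hit a stale kernel.
    return (T, B, H, HV, K, V)
```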
- Around line 1666-1680: GatedDeltaRuleKernel currently only stores seq_len and
_compiled_kernel and provides _get_launch_fn but is never used by
gated_delta_rule or _compiled_kernels; either finish the class by adding a
public execution API (e.g., an __call__ or execute method that compiles/looks up
the kernel into _compiled_kernel, uses _get_launch_fn to obtain the launch
function and runs it with the same signature as the module-level implementation)
or remove the class from the public surface and migrate any kernel-caching logic
into the existing module-level _compiled_kernels paths; reference the class name
GatedDeltaRuleKernel, its attribute _compiled_kernel, method _get_launch_fn, and
the public gated_delta_rule/_compiled_kernels cache when making the change.
- Around line 1746-1749: The code accesses q, k, v, b and initial_state_source
without null checks even though their annotations allow None; add explicit
validation at the start of the function (e.g., in the function containing the
lines that reference q.shape and v.shape) to raise a clear ValueError if any of
q, k, v, b, or initial_state_source is None (or alternatively update the
function signature to remove Optional typing for these parameters), and ensure
any downstream use of from_dlpack(initial_state_source, ...) only occurs after
confirming initial_state_source is not None; reference the variables q, k, v, b
and initial_state_source and the shapes accessed (q.shape, v.shape[2],
v.shape[3]) so the checks are placed before those accesses.
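A minimal sketch of the early validation being asked for (the helper name is hypothetical):

```python
def validate_required(q, k, v, b, initial_state_source):
    """Raise a clear ValueError instead of failing later on `.shape` access."""
    required = {
        "q": q,
        "k": k,
        "v": v,
        "b": b,
        "initial_state_source": initial_state_source,
    }
    missing = [name for name, t in required.items() if t is None]
    if missing:
        raise ValueError(f"required tensors must not be None: {missing}")
```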
- Around line 1153-1165: Unconditional shared-memory allocations q_sh2, k_sh2,
q_sh3, k_sh3 and the extra v_sh* buffers are always created even when NUM_TOKENS
is smaller; wrap those smem.allocate_tensor(...) calls in compile-time guards so
unused buffers are eliminated. Specifically, change the allocations for
q_sh2/k_sh2 to be inside an if constexpr (NUM_TOKENS >= 3) block and the
allocations for q_sh3/k_sh3 to be inside an if constexpr (NUM_TOKENS == 4) block
(and similarly guard v_sh2/v_sh3 as appropriate), keeping the same
smem.allocate_tensor(cutlass.Float32, 128) calls and names so later code (gate
computations) still references the same identifiers when compiled in. Ensure the
guards use the compile-time NUM_TOKENS parameter so the compiler can drop unused
allocations.
🧹 Nitpick comments (5)
flashinfer/cute_dsl/gated_delta_rule.py (3)
118-136: Potential numerical stability concern in `compute_single_gate`.

The softplus implementation handles large positive values but not large negative values of `beta_x`. When `beta_x` is very negative, `cute.math.exp(beta_x)` approaches zero, which is fine numerically. However, the sigmoid computation at line 134 could have issues: with large positive `beta_raw` values, `cute.math.exp(-beta_raw)` underflows to 0, resulting in `beta = 1.0` (acceptable), but large negative `beta_raw` causes `exp(-beta_raw)` to overflow.

Consider adding a symmetric threshold check for the sigmoid similar to the softplus handling.
💡 Optional: Add numerical safeguard for sigmoid
```diff
     g = -cute.math.exp(A_log_val) * softplus_x
     g_exp = cute.math.exp(g)
-    beta = cutlass.Float32(1.0) / (cutlass.Float32(1.0) + cute.math.exp(-beta_raw))
+    # Numerically stable sigmoid
+    if beta_raw >= cutlass.Float32(0.0):
+        beta = cutlass.Float32(1.0) / (cutlass.Float32(1.0) + cute.math.exp(-beta_raw))
+    else:
+        exp_beta = cute.math.exp(beta_raw)
+        beta = exp_beta / (cutlass.Float32(1.0) + exp_beta)
     return g_exp, beta
```
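The branch-on-sign trick behind this suggestion is easy to verify in plain Python (illustrative, not CuTe-DSL): only ever call exp() on a non-positive argument, so it can underflow to 0.0 but never overflow.

```python
import math

def stable_sigmoid(x: float) -> float:
    # For x >= 0, exp(-x) <= 1; for x < 0, exp(x) < 1. Neither overflows.
    if x >= 0.0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def stable_softplus(x: float, threshold: float = 20.0) -> float:
    # For large x, log(1 + exp(x)) is x to machine precision; skip exp().
    if x > threshold:
        return x
    return math.log1p(math.exp(x))
```

A naive `1.0 / (1.0 + math.exp(-x))` raises OverflowError for x = -1000; the stable version returns 0.0.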
306-354: Consider removing the unused `o_head` parameter.

The `o_head` parameter is passed but never used in `process_first_token`. The first token's output is returned and stored later during subsequent token processing. Removing it would clarify the function's contract.

♻️ Remove unused parameter
```diff
 @cute.jit
 def process_first_token(
     h_sh_chunk_curr,
     h_chunk,
     kq_chunk,
     k_sh,
     q_sh,
     v_sh,
     reduce_sh,
-    o_head,
     g_exp,
     beta,
     v_offset,
     pred_slot,
     warp_idx,
     lane_idx,
     k_base,
 ):
```

Note: This would require updating all call sites in `process_vchunk_unified_234`.
1746: Unpacked variable `H` is unused.

The variable `H` is extracted from `q.shape` but never used in the function. Consider prefixing it with an underscore to indicate intentional non-use.

♻️ Prefix unused variable

```diff
-    B, T, H, K = q.shape
+    B, T, _H, K = q.shape
```

flashinfer/cute_dsl/benchmark_gated_delta_rule.py (2)
82-82: Import path may fail depending on execution context.

The relative import `from gated_delta_rule import gated_delta_rule` assumes the script is run from within the `flashinfer/cute_dsl/` directory or that the directory is in `sys.path`. This will fail if run from the repository root or via `python -m`.

Consider using an absolute import for robustness.

♻️ Use absolute import

```diff
 def main():
-    from gated_delta_rule import gated_delta_rule
+    from flashinfer.cute_dsl.gated_delta_rule import gated_delta_rule
```
116-127: Loop variable capture in closure: latent bug risk.

The `run_kernel` function captures `inputs` and `state` by reference, not by value. While this works correctly here because the closure is consumed immediately via `benchmark()`, it's a latent bug that would manifest if the code were refactored (e.g., collecting closures to run later).

♻️ Bind loop variables explicitly

```diff
-    def run_kernel():
+    def run_kernel(inputs=inputs, state=state):
         return gated_delta_rule(
             A_log=inputs["A_log"],
             a=inputs["a"],
             dt_bias=inputs["dt_bias"],
             q=inputs["q"],
             k=inputs["k"],
             v=inputs["v"],
             b=inputs["b"],
             initial_state_source=state,
             scale=inputs["scale"],
         )
```
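The late-binding pitfall described here is easy to demonstrate in plain Python:

```python
# Closures capture variables, not values: every lambda sees the loop
# variable's final value once iteration finishes.
late_bound = [lambda: i for i in range(3)]

# A default argument is evaluated at definition time, freezing the
# current value of i into each function.
value_bound = [lambda i=i: i for i in range(3)]
```

Calling everything in `late_bound` returns [2, 2, 2], while `value_bound` returns [0, 1, 2]; the default-argument idiom is what the suggested diff applies to `run_kernel`.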
```python
# Q/K buffers for token 2 (only for NUM_TOKENS >= 3)
q_sh2 = smem.allocate_tensor(cutlass.Float32, 128)
k_sh2 = smem.allocate_tensor(cutlass.Float32, 128)

# Q/K buffers for token 3 (only for NUM_TOKENS=4)
q_sh3 = smem.allocate_tensor(cutlass.Float32, 128)
k_sh3 = smem.allocate_tensor(cutlass.Float32, 128)

# V buffers
v_sh0 = smem.allocate_tensor(cutlass.Float32, 128)
v_sh1 = smem.allocate_tensor(cutlass.Float32, 128)
v_sh2 = smem.allocate_tensor(cutlass.Float32, 128)
v_sh3 = smem.allocate_tensor(cutlass.Float32, 128)
```
🌐 Web query:
CuTe DSL Constexpr shared memory allocation dead code elimination
💡 Result:
In CuTe / CuTeDSL (CUTLASS), whether “shared-memory allocation” and the associated copies/loops disappear is mostly a question of whether the compiler can prove things at compile time (i.e., via types and constexpr / if constexpr) and whether the backend can safely apply dead-store / dead-code elimination across the generated code.
1) “constexpr shared memory allocation” in practice (CuTe/CUTLASS)
- CuTe is designed so that tile shapes, layouts, and smem footprints are carried in the type system; when those are compile-time constants, the compiler can fully specialize code (unroll, constant-fold addresses, remove unused objects/paths).
- Practically, that means: prefer CuTe's compile-time control flow (`if constexpr`, `cute::static_for`, type-based shapes/layouts) rather than runtime loops/conditions if you expect code (and sometimes even smem objects) to vanish.
- CUTLASS' own guidelines explicitly lean on `if constexpr`-based exhaustiveness and "unreachable" patterns to ensure unwanted branches don't survive compilation (and to avoid "missing return" / unreachable-path issues) [1].
2) Why you can see “dead-looking” stores (local or shared) even at -O3
What you’re describing matches a real pattern people have observed with CuTeDSL-generated kernels: stores to .local (stack) that appear unused even though the values are consumed directly from registers in nearby FFMA instructions.
A recent CuTeDSL example shows exactly this: STL.128 stores to local memory immediately followed by FFMA using the same registers, with no obvious reloads in between [2]. A plausible cause suggested in-thread is conservative aliasing / overlap concerns (compiler can’t prove the store is irrelevant because it can’t fully prove what may alias what) [2].
Also, NVIDIA engineers have long emphasized an important split:
- PTX is intermediate; optimizations (including dead-code elimination) may happen later in `ptxas`, and you should judge final behavior by SASS [3].
- However, in your case you're already looking at SASS and still seeing the stores, which usually means one of:
- the backend is being conservative (aliasing, lifetime issues),
- the store participates in some internal lowering pattern (e.g., preserving values across a region/edge the optimizer doesn’t fully reason about),
- or it’s simply a backend missed-optimization / bug (these have been reported historically for redundant local loads/stores) [4].
3) Concrete knobs that usually decide whether it optimizes away
- Make indices compile-time: if a loop induction variable controls fragment selection/addressing, switching to `cute::static_for` / unrolling often removes unexpected local traffic (your own observation aligns with this).
- Prefer `if constexpr` for optional smem paths so unused code paths don't exist in the instantiated kernel (CUTLASS recommends these patterns) [1].
- Avoid debug/per-function compilation modes when judging DCE (e.g., `-G`, device debug) because they commonly inhibit optimization (this is a common explanation for "dead" locals in PTX/SASS discussions) [3].
- If aliasing is plausible, try enforcing non-aliasing where applicable (CUDA C++: `__restrict__` on raw pointers; with CuTe views it's trickier, but the same principle applies). The forum reply explicitly raises aliasing as a reason the compiler may keep the stores [2].
If you want, paste a minimal CuTeDSL repro (kernel + compile flags + GPU arch + CUDA version) and I’ll point out exactly which parts prevent compile-time proving (and what to change so smem/lmem traffic is actually eliminated).
Sources: CUTLASS programming guidelines [1]; NVIDIA forum thread showing CuTeDSL dead-looking STL.128 in SASS and aliasing hypothesis [2]; NVIDIA forum explanation that PTX is intermediate and DCE happens in ptxas / check SASS and debug builds inhibit this [3]; historical report of redundant lmem loads/stores as compiler/backend artifact/bug [4].
Unconditional SMEM allocation for unused token buffers wastes memory for smaller batch sizes.
The Q/K/V shared memory buffers are allocated unconditionally at compile time for all four tokens, even when NUM_TOKENS=2 or 3. Comments indicate buffers are "only for NUM_TOKENS >= 3" and "only for NUM_TOKENS=4", but no compile-time conditionals guard these allocations. The subsequent gate computations at lines 1195+ use if NUM_TOKENS runtime checks, which cannot eliminate the earlier allocations. For NUM_TOKENS=2, this means allocating q_sh2, k_sh2, q_sh3, k_sh3, and partial v_sh buffers unnecessarily.
While this trade-off is acceptable for a unified kernel (avoiding code duplication and supporting multiple sequence lengths), the SMEM footprint could be reduced by wrapping allocations in if constexpr guards tied to the NUM_TOKENS parameter, allowing true compile-time elimination rather than runtime conditionals.
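The shape of that fix, reduced to a Python analogy: treat `NUM_TOKENS` as the compile-time constant and only "allocate" the buffers the specialization needs. The planner function below is hypothetical (the buffer names for tokens 0/1 are assumed to mirror the q_sh2/k_sh2 naming shown above); in the actual kernel the guards would be `if constexpr` on the Constexpr parameter.

```python
def plan_smem_buffers(num_tokens: int):
    """Return the SMEM buffers a NUM_TOKENS-specialized kernel would allocate."""
    assert num_tokens in (2, 3, 4)
    # Buffers needed for every supported NUM_TOKENS (tokens 0 and 1)
    bufs = ["q_sh0", "k_sh0", "q_sh1", "k_sh1", "v_sh0", "v_sh1"]
    # Guarded the way the compile-time `if constexpr` version would be
    if num_tokens >= 3:
        bufs += ["q_sh2", "k_sh2", "v_sh2"]
    if num_tokens == 4:
        bufs += ["q_sh3", "k_sh3", "v_sh3"]
    return bufs
```

Because the guard depends only on the compile-time constant, the NUM_TOKENS=2 specialization simply never contains the token-2/3 buffers, rather than carrying them and skipping their use at runtime.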
@ameynaik-hub would you mind comparing with https://github.com/flashinfer-ai/flashinfer/blob/main/flashinfer/gdn_decode.py? It's better to unify the file location and interface.
@yzh119 compared benchmark here: #2493 (comment). The https://github.com/flashinfer-ai/flashinfer/blob/main/flashinfer/gdn_decode.py kernel supports …; on the other hand, this kernel only supports V-major (K contiguous, K-last dim). Any suggestions on how to merge?
and test as well
I think it's worth copying the results here.
The new kernel is always better than the old one with pretranspose, except batch=1, which we can excuse. Does it make sense to keep the old pretranspose variant?
@ameynaik-hub I might have overlooked details during reading but does anyone have a source for this:
@vadiklyutiy since it has support for T>4 (seqlen) and also support for v-last dim of |
As I remember, there were several different interfaces/functions. I meant the one that processes the case seq_len=1, K-last.
@ameynaik-hub can you reuse the gdn decode interface and set your implementation as the default when we use the K-last (V-major) layout and T <= 4?
@ameynaik-hub this LGTM as the initial PR to just get the kernel into the Flashinfer codebase, approved. (still need a codeowner approval for this directory to merge ... maybe @yzh119 or @bkryu ). We can follow up with another PR for some work for the API integration + Flashinfer testing/benchmarking. Let me know if you have the bandwidth to work on this integration; if not, I can probably work on it next week.
Hi @kahyunnam, I encourage not pushing more code to this directory. The required work to unify the existing gdn decode and this PR should be minimal; I can help if you need any assistance with it.
```python
def benchmark(
    func, num_iterations=100, n_warmup=10, flush_l2=True, use_dummy_matmul=True
```
We already have benchmarking APIs https://docs.flashinfer.ai/generated/flashinfer.testing.bench_gpu_time.html#flashinfer.testing.bench_gpu_time, please refer to https://github.com/flashinfer-ai/flashinfer/blob/main/benchmarks/bench_gdn_decode.py#L1088-L1095 on how to use these APIs.
I'm just curious why that wasn't addressed and pointed out in #2370 but was here?
```python
    else:  # T == 4
        launch_fn = gated_delta_rule_launch_seqlen4

    _compiled_kernels[cache_key] = cute.compile(
```
Please enable tvm-ffi with `options="--enable-tvm-ffi"`, otherwise the host-side overhead of converting torch.tensor to dlpack would be non-negligible.
Reference can be found at #2279
Thanks, will do.
@yzh119 yeah I can do that. I noticed that my kernel utilizes bf16 inputs for the state.
cc @guangyunh-nv (gdn prefill's author) do you have any insights on whether it's okay to use f16 state?
FP16 state might not be a good choice due to the dynamic range limitation. The range is not guaranteed and purely depends on the model activation dynamics, so this depends on the model behavior. Datapoint: BF16 may be safer, but the accuracy would be a problem. If the algorithm side decides that this is tolerable, then I think we should support it. I think what we can do at the moment is to tolerate all kinds of state input, and learn from the production environment feedback :)
Have you tested the end-to-end accuracy with FP16 SSM state? Does the performance remain unchanged on common benchmarks such as MMLU, GSM8K, etc.? |
I don't think this is the right approach. Imagine a car manufacturer introduces a new self-driving car and says "let's learn from the road environment." I don't think a lot of people would be happy.
cute dsl bf16
/bot run
[FAILED] Pipeline #43822304: 11/20 passed
- Bank-conflict-free cross-warp reductions
- Async H memory loading with aggressive pipelining
- BF16 tensors with FP32 compute for numerical stability
- GQA (grouped-query attention) support with configurable H (query) and HV (value) heads
GQA: HQ = <IntegerRatio> * HV
GVA: <IntegerRatio> * HQ = HV. I confirmed this naming in a group chat with the GDN paper author :)
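Under that naming convention, the distinction is just which head count is the integer multiple of the other; a small illustrative helper (not part of the PR):

```python
def head_mode(HQ: int, HV: int):
    # GQA: HQ = r * HV (query heads grouped over fewer value heads)
    # GVA: HV = r * HQ (value heads grouped over fewer query heads)
    if HQ >= HV and HQ % HV == 0:
        return ("GQA", HQ // HV)
    if HV % HQ == 0:
        return ("GVA", HV // HQ)
    raise ValueError("head counts must divide evenly")
```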
@guangyunh-nv sorry do you mean it is called GVA and not GQA?
Thanks to everyone for clarifying, this helped :) @vadiklyutiy @guangyunh-nv @ameynaik-hub |
Out of curiosity, why does the perf degrade when batch goes from 1 to 4? I'd assume you have not reached the parallelism limit or the bandwidth limit, so the time should remain constant.
@guangyunh-nv for BS <= 4 I use 4 CTAs per head, so for BS=1 it has 32*4 = 128 CTAs and for BS=4 it has 32*4*4 = 512 CTAs.
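So the grid size scales linearly with batch (assuming 32 value heads, as the 128/512 figures imply; the helper below is just arithmetic, not the launcher's code):

```python
def num_ctas(batch_size: int, num_heads: int = 32, ctas_per_head: int = 4) -> int:
    # One group of 4 CTAs per (head, batch element) pair for BS <= 4
    return num_heads * ctas_per_head * batch_size
```

num_ctas(1) gives 128 and num_ctas(4) gives 512, matching the figures above; 4x the CTAs means 4x the state traffic per step, consistent with per-batch latency growing rather than staying flat.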
tests/gdn/reference_delta_rule.py
Outdated
```diff
@@ -137,6 +137,7 @@ def blockwise_linear_attention(
     torch.Tensor = 1.0,  # float or tensor with num_elems == num_qo_heads
     decay_exponent_offset=0,
+    kv_dtype: torch.dtype = torch.float32,
```
I think you should remove kv_dtype; it is named kv_dtype because some earlier papers call the state directly "KV". It is not the dtype for K and V.
done, can you please confirm if the change is okay? Thanks!
- Consolidate to single state_dtype parameter across all reference functions
- Remove duplicate kv_dtype parameter from blockwise_linear_attention(),
delta_rule(), and blockwise_delta_rule()
- Update test_prefill_delta_rule.py to use state_dtype consistently
- Remove benchmark_gated_delta_rule.py from git tracking (keep locally)
- Add to .gitignore for local development use only
Co-Authored-By: Claude Sonnet 4.5 <[email protected]>
yzh119
left a comment
Overall LGTM, some minor nits.
Co-authored-by: Zihao Ye <[email protected]>
I would like to speed up this review a bit. The Qwen3.5 headline model has been released today.
Add validation for required tensor parameters to fail early with clear error messages. Expand cache key to include all shape dimensions (H, HV, K, V) to prevent incorrect kernel reuse when shapes change. Co-Authored-By: Claude Sonnet 4.5 <[email protected]> Signed-off-by: Amey Naik <[email protected]>
Relocate gated_delta_rule.py from flashinfer/cute_dsl/ to a new flashinfer/gdn_kernels/ module to improve code organization and clarify the kernel's domain-specific purpose. Changes: - Create flashinfer/gdn_kernels/ module for GDN-specific CuTe DSL kernels - Rename gated_delta_rule.py to gdn_decode_bf16_state.py for clarity (indicates BF16 hidden state variant) - Update all 3 import sites to use new path: - flashinfer/gdn_decode.py - benchmarks/bench_gdn_decode.py - tests/gdn/test_decode_delta_rule.py - Add module __init__.py with proper re-exports - Avoid namespace conflict with existing gdn_decode.py file The flashinfer/cute_dsl/ directory remains for cross-cutting CuTe DSL utilities (RMSNorm, FP4, GEMM+AllReduce, etc.). All 32 GDN decode CuTe DSL tests pass successfully. Co-Authored-By: Claude Sonnet 4.5 <[email protected]> Signed-off-by: Amey Naik <[email protected]>
I have made some improvements to the fp32 H-state version of gdn_decode.

BS, FI-PreTr (us), BS32opt-V2 (us), Speedup

Zihao, I am thinking of creating a separate PR for it. What do you think? @yzh119
yzh119
left a comment
Hi @ameynaik-hub thanks, I think we can create another PR for the f32 acceleration. This PR itself is in good shape.
Add KDA (Key-Driven Attention) decode support as a CuTe DSL kernel, extending the GDN decode kernel from PR flashinfer-ai#2498 to support per-key-dimension gating. KDA generalizes GDN's scalar gate (g in R^1) to per-K gating (g in R^K), with the gate mapping naturally to the warp structure. Changes: - Extract shared gate-independent helpers from GDN kernel into flashinfer/gdn_kernels/_common.py (~290 lines), slimming gdn_decode_bf16_state.py. No GDN behavior change. - Add HEAD_DIM=64 support to GDN dispatch (previously 128 only) - Preserve lowBS_1chunk kernel variants for B<=4 (both GDN and KDA) - New flashinfer/kda_kernels/ module with T=1-4 kernels for HEAD_DIM={64,128}, plus chunk_kda-compatible wrapper - 80 KDA tests covering correctness, state updates, GDN reduction - KDA decode benchmark Tested on B200 (SM100) with CUDA 12.9. BF16 storage, FP32 compute. GDN: 138/138 tests pass, no performance regression. KDA: 80/80 tests pass. AI-assisted (Claude Code)
📌 Description
Implements high-performance Gated Delta Rule linear attention kernel supporting fixed sequence lengths T=1, T=2, T=3, T=4 using NVIDIA CuTe-DSL.
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed (unit tests with `unittest`, etc.).

Summary by CodeRabbit
New Features
Benchmarks
Tests