Commit f69875a
Ameyn/gdn decode cutedsl kernel (#2498)
Implements high-performance Gated Delta Rule linear attention kernel
supporting fixed sequence lengths T=1, T=2, T=3, T=4 using NVIDIA
CuTe-DSL.
Key features:
- H state layout: K-last [B, HV, V, K] where K is the contiguous
(fastest) dimension
- L2-normalized Q/K with configurable scale
- Gated exponential decay via softplus
- Delta rule updates: v_delta = beta * (v - pred)
- Async H memory loading with aggressive pipelining
- BF16 tensors with FP32 compute
Also includes:
- benchmark_gated_delta_rule.py: Simple benchmark script for measuring
kernel perf
- Updated __init__.py exports
<!-- .github/pull_request_template.md -->
## 📌 Description
Implements high-performance Gated Delta Rule linear attention kernel
supporting fixed sequence lengths T=1, T=2, T=3, T=4 using NVIDIA
CuTe-DSL.
## 🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.
### ✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.
> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).
## 🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **New Features**
* Added CUDA-accelerated Gated Delta Rule kernel with optimized paths
for sequence lengths 1–4 and new public API entries gated_delta_rule and
GatedDeltaRuleKernel (exported when available).
* **Benchmarks**
* New benchmark for the gated_delta_rule kernel and integration of an
improved CuTe-DSL variant into the benchmark suite and CLI, with per‑T
summaries and performance reporting.
* **Tests**
* End-to-end tests exercising the improved kernel for T=1–4,
availability guards, and updated reference tests to support explicit
state storage dtype handling.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
---------
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
Co-authored-by: Zihao Ye <zihaoye.cs@gmail.com>1 parent a2d6d49 commit f69875a
9 files changed
Lines changed: 2925 additions & 85 deletions
File tree
- benchmarks
- flashinfer
- cute_dsl
- gdn_kernels
- tests/gdn
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
19 | 19 | | |
20 | 20 | | |
21 | 21 | | |
| 22 | + | |
22 | 23 | | |
23 | 24 | | |
24 | 25 | | |
| |||
0 commit comments