Commit f69875a

ameynaik-hub, claude, and yzh119 authored
Ameyn/gdn decode cutedsl kernel (#2498)
Implements high-performance Gated Delta Rule linear attention kernel supporting fixed sequence lengths T=1, T=2, T=3, T=4 using NVIDIA CuTe-DSL.

Key features:
- H state layout: K-last [B, HV, V, K] where K is the contiguous (fastest) dimension
- L2-normalized Q/K with configurable scale
- Gated exponential decay via softplus
- Delta rule updates: v_delta = beta * (v - pred)
- Async H memory loading with aggressive pipelining
- BF16 tensors with FP32 compute

Also includes:
- benchmark_gated_delta_rule.py: Simple benchmark script for measuring kernel perf
- Updated __init__.py exports

## 📌 Description

Implements high-performance Gated Delta Rule linear attention kernel supporting fixed sequence lengths T=1, T=2, T=3, T=4 using NVIDIA CuTe-DSL.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **New Features**
  * Added CUDA-accelerated Gated Delta Rule kernel with optimized paths for sequence lengths 1–4 and new public API entries gated_delta_rule and GatedDeltaRuleKernel (exported when available).
* **Benchmarks**
  * New benchmark for the gated_delta_rule kernel and integration of an improved CuTe-DSL variant into the benchmark suite and CLI, with per‑T summaries and performance reporting.
* **Tests**
  * End-to-end tests exercising the improved kernel for T=1–4, availability guards, and updated reference tests to support explicit state storage dtype handling.

---------

Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
Co-authored-by: Zihao Ye <zihaoye.cs@gmail.com>
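
For orientation, the per-token update named in the feature list (L2-normalized Q/K, softplus-gated decay, `v_delta = beta * (v - pred)`, K-last state) can be sketched in plain Python for a single head and a single token. This is a reference sketch, not the CuTe-DSL kernel. The function name `gdn_decode_step`, the gate form `decay = exp(-softplus(a))`, the `eps` value, and computing the output from the updated state are all assumptions made for illustration; the kernel's actual gate parameterization may include further learned terms.

```python
import math


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def l2_normalize(x, eps=1e-6):
    # eps is an assumed stabilizer, not taken from the commit.
    inv = 1.0 / math.sqrt(sum(e * e for e in x) + eps)
    return [e * inv for e in x]


def gdn_decode_step(h, q, k, v, a, beta, scale):
    """One decode step for one head (hypothetical reference helper).

    h is a V x K list of lists (K-last, matching the [B, HV, V, K] layout
    for a fixed batch and head). It is updated in place; the output
    vector of length V is returned.
    """
    q_n = [e * scale for e in l2_normalize(q)]
    k_n = l2_normalize(k)
    decay = math.exp(-softplus(a))  # assumed gate form, always in (0, 1]
    out = []
    for i, row in enumerate(h):
        # Gated exponential decay of the running state.
        for j in range(len(row)):
            row[j] *= decay
        # Delta rule: correct the state by the prediction error for this key.
        pred = sum(hj * kj for hj, kj in zip(row, k_n))
        v_delta = beta * (v[i] - pred)
        for j in range(len(row)):
            row[j] += v_delta * k_n[j]
        # Read out with the scaled, normalized query (assumed to use the
        # updated state).
        out.append(sum(hj * qj for hj, qj in zip(row, q_n)))
    return out


# Usage: start from an empty state and write one (k, v) pair into it.
state = [[0.0, 0.0], [0.0, 0.0]]
result = gdn_decode_step(state, q=[1.0, 0.0], k=[1.0, 0.0], v=[2.0, 3.0],
                         a=0.0, beta=1.0, scale=1.0)
```

With `beta = 1` and a query equal to the key, the step reads back approximately the value it just wrote, which is the behavior the delta rule is designed to give.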
1 parent a2d6d49 commit f69875a

9 files changed

Lines changed: 2925 additions & 85 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ csrc/aot_default_additional_params.h
 
 # Microbenchmark files
 microbenchmark/
+flashinfer/cute_dsl/benchmark_gated_delta_rule.py
 
 # vscode
 .vscode/
