Skip to content

Commit 9d01f07

Browse files
committed
feat(kda): add KDA decode CuTe DSL kernel with per-K gating
Add KDA (Key-Driven Attention) decode support as a CuTe DSL kernel, extending the GDN decode kernel from PR flashinfer-ai#2498 to support per-key-dimension gating. KDA generalizes GDN's scalar gate (g in R^1) to per-K gating (g in R^K), with the gate mapping naturally to the warp structure. Changes: - Extract shared gate-independent helpers from GDN kernel into flashinfer/gdn_kernels/_common.py (~290 lines), slimming gdn_decode_bf16_state.py. No GDN behavior change. - Add HEAD_DIM=64 support to GDN dispatch (previously 128 only) - Preserve lowBS_1chunk kernel variants for B<=4 (both GDN and KDA) - New flashinfer/kda_kernels/ module with T=1-4 kernels for HEAD_DIM={64,128}, plus chunk_kda-compatible wrapper - 80 KDA tests covering correctness, state updates, GDN reduction - KDA decode benchmark Tested on B200 (SM100) with CUDA 12.9. BF16 storage, FP32 compute. GDN: 138/138 tests pass, no performance regression. KDA: 80/80 tests pass. AI-assisted (Claude Code)
1 parent 55ba155 commit 9d01f07

File tree

11 files changed

+3628
-598
lines changed

11 files changed

+3628
-598
lines changed

benchmarks/bench_gdn_decode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2695,7 +2695,7 @@ def main():
26952695
parser.add_argument("--num-q-heads", type=int, default=16)
26962696
parser.add_argument("--num-k-heads", type=int, default=16)
26972697
parser.add_argument("--num-v-heads", type=int, default=32)
2698-
parser.add_argument("--head-size", type=int, default=128)
2698+
parser.add_argument("--head-size", type=int, default=128, choices=[64, 128])
26992699
parser.add_argument(
27002700
"--dtype", type=str, choices=["float16", "bfloat16"], default="bfloat16"
27012701
)

benchmarks/bench_kda_decode.py

Lines changed: 343 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,343 @@
1+
"""
2+
Copyright (c) 2025 by FlashInfer team.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
"""
18+
KDA (Key-Driven Attention) Decode Benchmark
19+
20+
Benchmarks the KDA CuTe DSL decode kernel with per-K-dimension gating.
21+
KDA differs from GDN by having gate g[B, T, HV, K] instead of a scalar gate.
22+
23+
Usage:
24+
python benchmarks/bench_kda_decode.py --batch-size 1 4 16 64 128 256
25+
python benchmarks/bench_kda_decode.py --head-size 64 --batch-size 1 32 128
26+
python benchmarks/bench_kda_decode.py --seq-len 1 2 3 4 --batch-size 1 32
27+
"""
28+
29+
import argparse
30+
import numpy as np
31+
import torch
32+
33+
from flashinfer.testing import bench_gpu_time
34+
35+
# Import the KDA decode kernel; degrade gracefully when the optional
# flashinfer KDA module is not installed.
KDA_DECODE_AVAILABLE = False
try:
    from flashinfer.kda_kernels.kda_decode_bf16_state import (
        kda_gated_delta_rule as kda_decode,
    )
except ImportError:
    pass
else:
    KDA_DECODE_AVAILABLE = True
44+
45+
46+
# ============================================================================
47+
# FLOPs and Bytes Calculation
48+
# ============================================================================
49+
50+
51+
def kda_decode_flops(
    batch_size: int,
    num_q_heads: int,
    num_k_heads: int,
    num_v_heads: int,
    head_size: int,
    seq_len: int = 1,
) -> int:
    """
    Calculate FLOPs for KDA (Key-Driven Attention) decode.

    Each token/head pair costs 8 * K * V FLOPs:
      1. k @ state (prediction):       2 * K * V
      2. k^T @ v_new (state update):   2 * K * V
      3. q @ state (output):           2 * K * V
      4. per-K gate application:       2 * K * V
    with K = V = head_size for KDA.
    """
    # Output heads follow the wider of the Q/V head counts.
    output_heads = max(num_q_heads, num_v_heads)
    flops_per_token_head = 8 * head_size * head_size
    return seq_len * batch_size * output_heads * flops_per_token_head
73+
74+
75+
def kda_decode_bytes(
76+
batch_size: int,
77+
num_q_heads: int,
78+
num_k_heads: int,
79+
num_v_heads: int,
80+
head_size: int,
81+
dtype: torch.dtype,
82+
seq_len: int = 1,
83+
) -> int:
84+
"""
85+
Calculate memory bytes for KDA decode.
86+
87+
Includes:
88+
- Q, K, V tensors: [B, T, H, K] - dtype
89+
- G tensor (per-K gate): [B, T, HV, K] - dtype (extra vs GDN)
90+
- Beta: [B, T, HV] - dtype
91+
- State (read + write): [B, HV, V, K] - bf16 (2 bytes)
92+
- Output: [B, T, HV, V] - dtype
93+
"""
94+
num_o_heads = max(num_q_heads, num_v_heads)
95+
elem_size = dtype.itemsize
96+
state_dtype_bytes = 2 # BF16 state
97+
98+
# Input tensors
99+
q_bytes = batch_size * seq_len * num_q_heads * head_size * elem_size
100+
k_bytes = batch_size * seq_len * num_k_heads * head_size * elem_size
101+
v_bytes = batch_size * seq_len * num_v_heads * head_size * elem_size
102+
103+
# Per-K gate: [B, T, HV, K] - the extra input vs GDN
104+
g_bytes = batch_size * seq_len * num_o_heads * head_size * elem_size
105+
106+
# Beta: [B, T, HV]
107+
beta_bytes = batch_size * seq_len * num_o_heads * elem_size
108+
109+
# Output: [B, T, HV, V]
110+
o_bytes = batch_size * seq_len * num_o_heads * head_size * elem_size
111+
112+
# State: [B, HV, V, K] read + write
113+
state_bytes = (
114+
2 * batch_size * num_o_heads * head_size * head_size * state_dtype_bytes
115+
)
116+
117+
total_bytes = (
118+
q_bytes + k_bytes + v_bytes + g_bytes + beta_bytes + o_bytes + state_bytes
119+
)
120+
return total_bytes
121+
122+
123+
# ============================================================================
124+
# Benchmark Function
125+
# ============================================================================
126+
127+
128+
def bench_kda_decode(
    batch_size: int,
    seq_len: int,
    num_q_heads: int,
    num_k_heads: int,
    num_v_heads: int,
    head_size: int,
    dtype: torch.dtype,
    warmup_iters: int = 10,
    bench_iters: int = 100,
):
    """Benchmark the KDA decode kernel for one (batch, T) configuration.

    Returns a dict with the median kernel time (microseconds) and the
    derived TFLOPS / TB-per-second figures. Raises when the kernel is
    unavailable; asserts that T is one of the supported values 1..4.
    """
    if not KDA_DECODE_AVAILABLE:
        raise RuntimeError("KDA decode kernel is not available")

    assert seq_len in [1, 2, 3, 4], f"KDA decode supports T=1,2,3,4, got T={seq_len}"

    T = seq_len
    num_o_heads = max(num_q_heads, num_v_heads)

    def _rand(*shape, dt=dtype):
        # All benchmark inputs live on the GPU.
        return torch.randn(*shape, dtype=dt, device="cuda")

    q = _rand(batch_size, T, num_q_heads, head_size)
    k = _rand(batch_size, T, num_k_heads, head_size)
    v = _rand(batch_size, T, num_v_heads, head_size)
    # KDA-specific: per-K log-space gate [B, T, HV, K].
    g = _rand(batch_size, T, num_o_heads, head_size)
    # Beta [B, T, HV] (pre-sigmoided).
    beta = _rand(batch_size, T, num_o_heads)
    # Initial state [B, HV, V, K] (K-last layout, always BF16).
    state = _rand(batch_size, num_o_heads, head_size, head_size, dt=torch.bfloat16)

    scale = 1.0 / (head_size**0.5)

    # CUPTI timing isolates the kernel from host-side launch overhead.
    kernel_times_ms = bench_gpu_time(
        lambda: kda_decode(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            initial_state_source=state,
            scale=scale,
            use_qk_l2norm_in_kernel=True,
        ),
        enable_cupti=True,
        dry_run_iters=warmup_iters,
        repeat_iters=bench_iters,
    )

    median_ms = np.median(kernel_times_ms)
    head_args = (batch_size, num_q_heads, num_k_heads, num_v_heads, head_size)
    flops = kda_decode_flops(*head_args, seq_len)
    nbytes = kda_decode_bytes(*head_args, dtype, seq_len)

    # Guard against a zero median to avoid division errors.
    tflops = flops / median_ms / 1e9 if median_ms > 0 else 0
    tb_per_sec = nbytes / median_ms / 1e9 if median_ms > 0 else 0

    return {
        "batch_size": batch_size,
        "seq_len": seq_len,
        "kernel_median_us": median_ms * 1000,
        "kernel_tflops": tflops,
        "kernel_tb_per_sec": tb_per_sec,
    }
210+
211+
212+
# ============================================================================
213+
# Runner
214+
# ============================================================================
215+
216+
217+
def run_kda_decode_benchmark(args, dtype):
    """Sweep batch sizes x sequence lengths and print a KDA decode report.

    Prints an error and returns early when the kernel is unavailable or no
    valid --seq-len values were given; tolerates per-configuration failures
    so one bad shape does not abort the whole sweep.
    """
    if not KDA_DECODE_AVAILABLE:
        print("Error: KDA decode kernel is not available.")
        print("Make sure flashinfer.kda_kernels.kda_decode_bf16_state is importable.")
        return

    # Decode kernel only supports T in {1, 2, 3, 4}.
    valid_seq_lens = [t for t in args.seq_len if t in [1, 2, 3, 4]]
    if not valid_seq_lens:
        print("Error: --seq-len must include values from [1, 2, 3, 4]")
        return

    print("\n" + "=" * 100)
    print(f"KDA Decode Benchmark (T={valid_seq_lens})")
    print(
        f"Config: q_heads={args.num_q_heads}, k_heads={args.num_k_heads}, "
        f"v_heads={args.num_v_heads}, head_size={args.head_size}, "
        f"dtype={args.dtype}"
    )
    print("=" * 100)
    print()
    print(f"{'batch':>6} {'T':>4} {'time(us)':>10} {'TFLOPS':>10} {'TB/s':>10}")
    print("-" * 100)

    all_results = []
    for batch_size in args.batch_size:
        for seq_len in valid_seq_lens:
            try:
                row = bench_kda_decode(
                    batch_size=batch_size,
                    seq_len=seq_len,
                    num_q_heads=args.num_q_heads,
                    num_k_heads=args.num_k_heads,
                    num_v_heads=args.num_v_heads,
                    head_size=args.head_size,
                    dtype=dtype,
                    warmup_iters=args.warmup,
                    bench_iters=args.iters,
                )
            except Exception as e:
                # Best effort: report the failure and keep sweeping.
                print(
                    f"{batch_size:>6} {seq_len:>4} {'ERROR':>10} - {type(e).__name__}: {e}"
                )
                continue

            all_results.append(row)
            print(
                f"{row['batch_size']:>6} {row['seq_len']:>4} "
                f"{row['kernel_median_us']:>10.2f} "
                f"{row['kernel_tflops']:>10.2f} "
                f"{row['kernel_tb_per_sec']:>10.2f}"
            )

    print("-" * 100)
    print()

    # Per-T averages across all benchmarked batch sizes.
    for t in valid_seq_lens:
        rows_for_t = [r for r in all_results if r["seq_len"] == t]
        if not rows_for_t:
            continue
        avg_time = np.mean([r["kernel_median_us"] for r in rows_for_t])
        avg_tflops = np.mean([r["kernel_tflops"] for r in rows_for_t])
        print(
            f"T={t}: Average time={avg_time:.2f}us, Average TFLOPS={avg_tflops:.2f}"
        )
282+
283+
284+
# ============================================================================
285+
# Main
286+
# ============================================================================
287+
288+
289+
def main():
    """Parse CLI arguments, resolve the dtype, and launch the benchmark sweep."""
    parser = argparse.ArgumentParser(
        description="KDA Decode Benchmark",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    python benchmarks/bench_kda_decode.py --batch-size 1 4 16 64 128 256
    python benchmarks/bench_kda_decode.py --head-size 64 --batch-size 1 32 128
    python benchmarks/bench_kda_decode.py --seq-len 1 2 3 4 --batch-size 1 32
""",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        nargs="+",
        default=[1, 4, 16, 64, 128, 256],
        help="Batch sizes to benchmark",
    )
    # Head-count defaults mirror the GDN decode benchmark configuration.
    parser.add_argument("--num-q-heads", type=int, default=16)
    parser.add_argument("--num-k-heads", type=int, default=16)
    parser.add_argument("--num-v-heads", type=int, default=32)
    parser.add_argument("--head-size", type=int, default=128, choices=[64, 128])
    parser.add_argument(
        "--dtype", type=str, choices=["float16", "bfloat16"], default="bfloat16"
    )
    parser.add_argument(
        "--seq-len",
        type=int,
        nargs="+",
        default=[1, 2, 3, 4],
        help="Sequence lengths (T=1,2,3,4)",
    )
    parser.add_argument(
        "--warmup",
        type=int,
        default=10,
        help="Number of warmup iterations",
    )
    parser.add_argument(
        "--iters",
        type=int,
        default=100,
        help="Number of benchmark iterations",
    )
    args = parser.parse_args()

    # Map the dtype flag onto the corresponding torch dtype object.
    dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[args.dtype]

    run_kda_decode_benchmark(args, dtype)


if __name__ == "__main__":
    main()

flashinfer/gdn_decode.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,13 +1004,13 @@ def gated_delta_rule_decode_pretranspose(
10041004
f"Expected state shape [B={B}, HV={HV}, V={V}, K={K}], got {state.shape}"
10051005
)
10061006

1007-
# Backend: gdn_decode_klast_bf16_state when bf16 state, T<=4, K-last layout, K=V=128
1007+
# Backend: gdn_decode_klast_bf16_state when bf16 state, T<=4, K-last layout, K=V in {64,128}
10081008
use_gdn_decode_klast_bf16_state = (
10091009
_GDN_DECODE_KLAST_BF16_STATE_AVAILABLE
10101010
and state.dtype == torch.bfloat16
10111011
and T in (1, 2, 3, 4)
1012-
and K == 128
1013-
and V == 128
1012+
and K == V
1013+
and K in (64, 128)
10141014
)
10151015
if use_gdn_decode_klast_bf16_state:
10161016
assert q.dtype in (torch.float16, torch.bfloat16), (

0 commit comments

Comments
 (0)