|
| 1 | +""" |
| 2 | +Copyright (c) 2025 by FlashInfer team. |
| 3 | +
|
| 4 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +you may not use this file except in compliance with the License. |
| 6 | +You may obtain a copy of the License at |
| 7 | +
|
| 8 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +
|
| 10 | +Unless required by applicable law or agreed to in writing, software |
| 11 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +See the License for the specific language governing permissions and |
| 14 | +limitations under the License. |
| 15 | +""" |
| 16 | + |
| 17 | +""" |
| 18 | +KDA (Key-Driven Attention) Decode Benchmark |
| 19 | +
|
| 20 | +Benchmarks the KDA CuTe DSL decode kernel with per-K-dimension gating. |
| 21 | +KDA differs from GDN by having gate g[B, T, HV, K] instead of a scalar gate. |
| 22 | +
|
| 23 | +Usage: |
| 24 | + python benchmarks/bench_kda_decode.py --batch-size 1 4 16 64 128 256 |
| 25 | + python benchmarks/bench_kda_decode.py --head-size 64 --batch-size 1 32 128 |
| 26 | + python benchmarks/bench_kda_decode.py --seq-len 1 2 3 4 --batch-size 1 32 |
| 27 | +""" |
| 28 | + |
| 29 | +import argparse |
| 30 | +import numpy as np |
| 31 | +import torch |
| 32 | + |
| 33 | +from flashinfer.testing import bench_gpu_time |
| 34 | + |
| 35 | +# Import the KDA decode kernel |
| 36 | +try: |
| 37 | + from flashinfer.kda_kernels.kda_decode_bf16_state import ( |
| 38 | + kda_gated_delta_rule as kda_decode, |
| 39 | + ) |
| 40 | + |
| 41 | + KDA_DECODE_AVAILABLE = True |
| 42 | +except ImportError: |
| 43 | + KDA_DECODE_AVAILABLE = False |
| 44 | + |
| 45 | + |
| 46 | +# ============================================================================ |
| 47 | +# FLOPs and Bytes Calculation |
| 48 | +# ============================================================================ |
| 49 | + |
| 50 | + |
| 51 | +def kda_decode_flops( |
| 52 | + batch_size: int, |
| 53 | + num_q_heads: int, |
| 54 | + num_k_heads: int, |
| 55 | + num_v_heads: int, |
| 56 | + head_size: int, |
| 57 | + seq_len: int = 1, |
| 58 | +) -> int: |
| 59 | + """ |
| 60 | + Calculate FLOPs for KDA (Key-Driven Attention) decode. |
| 61 | +
|
| 62 | + 8 * K * V FLOPs per token per head: |
| 63 | + 1. k @ state (prediction): 2 * K * V |
| 64 | + 2. k^T @ v_new (update): 2 * K * V |
| 65 | + 3. q @ state (output): 2 * K * V |
| 66 | + 4. Per-K gate application: 2 * K * V (K*V element-wise multiply + K exp() calls) |
| 67 | +
|
| 68 | + Note: K = V = head_size for KDA. |
| 69 | + """ |
| 70 | + num_o_heads = max(num_q_heads, num_v_heads) |
| 71 | + total_flops = 8 * seq_len * batch_size * num_o_heads * head_size * head_size |
| 72 | + return total_flops |
| 73 | + |
| 74 | + |
| 75 | +def kda_decode_bytes( |
| 76 | + batch_size: int, |
| 77 | + num_q_heads: int, |
| 78 | + num_k_heads: int, |
| 79 | + num_v_heads: int, |
| 80 | + head_size: int, |
| 81 | + dtype: torch.dtype, |
| 82 | + seq_len: int = 1, |
| 83 | +) -> int: |
| 84 | + """ |
| 85 | + Calculate memory bytes for KDA decode. |
| 86 | +
|
| 87 | + Includes: |
| 88 | + - Q, K, V tensors: [B, T, H, K] - dtype |
| 89 | + - G tensor (per-K gate): [B, T, HV, K] - dtype (extra vs GDN) |
| 90 | + - Beta: [B, T, HV] - dtype |
| 91 | + - State (read + write): [B, HV, V, K] - bf16 (2 bytes) |
| 92 | + - Output: [B, T, HV, V] - dtype |
| 93 | + """ |
| 94 | + num_o_heads = max(num_q_heads, num_v_heads) |
| 95 | + elem_size = dtype.itemsize |
| 96 | + state_dtype_bytes = 2 # BF16 state |
| 97 | + |
| 98 | + # Input tensors |
| 99 | + q_bytes = batch_size * seq_len * num_q_heads * head_size * elem_size |
| 100 | + k_bytes = batch_size * seq_len * num_k_heads * head_size * elem_size |
| 101 | + v_bytes = batch_size * seq_len * num_v_heads * head_size * elem_size |
| 102 | + |
| 103 | + # Per-K gate: [B, T, HV, K] - the extra input vs GDN |
| 104 | + g_bytes = batch_size * seq_len * num_o_heads * head_size * elem_size |
| 105 | + |
| 106 | + # Beta: [B, T, HV] |
| 107 | + beta_bytes = batch_size * seq_len * num_o_heads * elem_size |
| 108 | + |
| 109 | + # Output: [B, T, HV, V] |
| 110 | + o_bytes = batch_size * seq_len * num_o_heads * head_size * elem_size |
| 111 | + |
| 112 | + # State: [B, HV, V, K] read + write |
| 113 | + state_bytes = ( |
| 114 | + 2 * batch_size * num_o_heads * head_size * head_size * state_dtype_bytes |
| 115 | + ) |
| 116 | + |
| 117 | + total_bytes = ( |
| 118 | + q_bytes + k_bytes + v_bytes + g_bytes + beta_bytes + o_bytes + state_bytes |
| 119 | + ) |
| 120 | + return total_bytes |
| 121 | + |
| 122 | + |
| 123 | +# ============================================================================ |
| 124 | +# Benchmark Function |
| 125 | +# ============================================================================ |
| 126 | + |
| 127 | + |
| 128 | +def bench_kda_decode( |
| 129 | + batch_size: int, |
| 130 | + seq_len: int, |
| 131 | + num_q_heads: int, |
| 132 | + num_k_heads: int, |
| 133 | + num_v_heads: int, |
| 134 | + head_size: int, |
| 135 | + dtype: torch.dtype, |
| 136 | + warmup_iters: int = 10, |
| 137 | + bench_iters: int = 100, |
| 138 | +): |
| 139 | + """Benchmark KDA decode kernel for T=1,2,3,4.""" |
| 140 | + if not KDA_DECODE_AVAILABLE: |
| 141 | + raise RuntimeError("KDA decode kernel is not available") |
| 142 | + |
| 143 | + assert seq_len in [1, 2, 3, 4], f"KDA decode supports T=1,2,3,4, got T={seq_len}" |
| 144 | + |
| 145 | + num_o_heads = max(num_q_heads, num_v_heads) |
| 146 | + |
| 147 | + # Create inputs |
| 148 | + T = seq_len |
| 149 | + q = torch.randn(batch_size, T, num_q_heads, head_size, dtype=dtype, device="cuda") |
| 150 | + k = torch.randn(batch_size, T, num_k_heads, head_size, dtype=dtype, device="cuda") |
| 151 | + v = torch.randn(batch_size, T, num_v_heads, head_size, dtype=dtype, device="cuda") |
| 152 | + |
| 153 | + # KDA-specific: per-K log-space gate [B, T, HV, K] |
| 154 | + g = torch.randn(batch_size, T, num_o_heads, head_size, dtype=dtype, device="cuda") |
| 155 | + |
| 156 | + # Beta: [B, T, HV] (pre-sigmoided) |
| 157 | + beta = torch.randn(batch_size, T, num_o_heads, dtype=dtype, device="cuda") |
| 158 | + |
| 159 | + # Initial state: [B, HV, V, K] (K-last layout, BF16) |
| 160 | + state = torch.randn( |
| 161 | + batch_size, |
| 162 | + num_o_heads, |
| 163 | + head_size, |
| 164 | + head_size, |
| 165 | + dtype=torch.bfloat16, |
| 166 | + device="cuda", |
| 167 | + ) |
| 168 | + |
| 169 | + # Scale factor |
| 170 | + scale = 1.0 / (head_size**0.5) |
| 171 | + |
| 172 | + # Benchmark with bench_gpu_time (CUPTI for accurate kernel timing) |
| 173 | + kernel_times_ms = bench_gpu_time( |
| 174 | + lambda: kda_decode( |
| 175 | + q=q, |
| 176 | + k=k, |
| 177 | + v=v, |
| 178 | + g=g, |
| 179 | + beta=beta, |
| 180 | + initial_state_source=state, |
| 181 | + scale=scale, |
| 182 | + use_qk_l2norm_in_kernel=True, |
| 183 | + ), |
| 184 | + enable_cupti=True, |
| 185 | + dry_run_iters=warmup_iters, |
| 186 | + repeat_iters=bench_iters, |
| 187 | + ) |
| 188 | + |
| 189 | + # Calculate metrics |
| 190 | + kernel_median_ms = np.median(kernel_times_ms) |
| 191 | + flops = kda_decode_flops( |
| 192 | + batch_size, num_q_heads, num_k_heads, num_v_heads, head_size, seq_len |
| 193 | + ) |
| 194 | + bytes_accessed = kda_decode_bytes( |
| 195 | + batch_size, num_q_heads, num_k_heads, num_v_heads, head_size, dtype, seq_len |
| 196 | + ) |
| 197 | + |
| 198 | + kernel_tflops = flops / kernel_median_ms / 1e9 if kernel_median_ms > 0 else 0 |
| 199 | + kernel_tb_per_sec = ( |
| 200 | + bytes_accessed / kernel_median_ms / 1e9 if kernel_median_ms > 0 else 0 |
| 201 | + ) |
| 202 | + |
| 203 | + return { |
| 204 | + "batch_size": batch_size, |
| 205 | + "seq_len": seq_len, |
| 206 | + "kernel_median_us": kernel_median_ms * 1000, |
| 207 | + "kernel_tflops": kernel_tflops, |
| 208 | + "kernel_tb_per_sec": kernel_tb_per_sec, |
| 209 | + } |
| 210 | + |
| 211 | + |
| 212 | +# ============================================================================ |
| 213 | +# Runner |
| 214 | +# ============================================================================ |
| 215 | + |
| 216 | + |
| 217 | +def run_kda_decode_benchmark(args, dtype): |
| 218 | + """Run KDA decode benchmark for T=1,2,3,4.""" |
| 219 | + if not KDA_DECODE_AVAILABLE: |
| 220 | + print("Error: KDA decode kernel is not available.") |
| 221 | + print("Make sure flashinfer.kda_kernels.kda_decode_bf16_state is importable.") |
| 222 | + return |
| 223 | + |
| 224 | + # Filter seq_len to only valid values (1,2,3,4) |
| 225 | + valid_seq_lens = [t for t in args.seq_len if t in [1, 2, 3, 4]] |
| 226 | + if not valid_seq_lens: |
| 227 | + print("Error: --seq-len must include values from [1, 2, 3, 4]") |
| 228 | + return |
| 229 | + |
| 230 | + print("\n" + "=" * 100) |
| 231 | + print(f"KDA Decode Benchmark (T={valid_seq_lens})") |
| 232 | + print( |
| 233 | + f"Config: q_heads={args.num_q_heads}, k_heads={args.num_k_heads}, " |
| 234 | + f"v_heads={args.num_v_heads}, head_size={args.head_size}, " |
| 235 | + f"dtype={args.dtype}" |
| 236 | + ) |
| 237 | + print("=" * 100) |
| 238 | + print() |
| 239 | + print(f"{'batch':>6} {'T':>4} {'time(us)':>10} {'TFLOPS':>10} {'TB/s':>10}") |
| 240 | + print("-" * 100) |
| 241 | + |
| 242 | + all_results = [] |
| 243 | + for batch_size in args.batch_size: |
| 244 | + for seq_len in valid_seq_lens: |
| 245 | + try: |
| 246 | + result = bench_kda_decode( |
| 247 | + batch_size=batch_size, |
| 248 | + seq_len=seq_len, |
| 249 | + num_q_heads=args.num_q_heads, |
| 250 | + num_k_heads=args.num_k_heads, |
| 251 | + num_v_heads=args.num_v_heads, |
| 252 | + head_size=args.head_size, |
| 253 | + dtype=dtype, |
| 254 | + warmup_iters=args.warmup, |
| 255 | + bench_iters=args.iters, |
| 256 | + ) |
| 257 | + all_results.append(result) |
| 258 | + |
| 259 | + print( |
| 260 | + f"{result['batch_size']:>6} {result['seq_len']:>4} " |
| 261 | + f"{result['kernel_median_us']:>10.2f} " |
| 262 | + f"{result['kernel_tflops']:>10.2f} " |
| 263 | + f"{result['kernel_tb_per_sec']:>10.2f}" |
| 264 | + ) |
| 265 | + except Exception as e: |
| 266 | + print( |
| 267 | + f"{batch_size:>6} {seq_len:>4} {'ERROR':>10} - {type(e).__name__}: {e}" |
| 268 | + ) |
| 269 | + |
| 270 | + print("-" * 100) |
| 271 | + print() |
| 272 | + |
| 273 | + # Summary by T value |
| 274 | + for t in valid_seq_lens: |
| 275 | + t_results = [r for r in all_results if r["seq_len"] == t] |
| 276 | + if t_results: |
| 277 | + avg_time = np.mean([r["kernel_median_us"] for r in t_results]) |
| 278 | + avg_tflops = np.mean([r["kernel_tflops"] for r in t_results]) |
| 279 | + print( |
| 280 | + f"T={t}: Average time={avg_time:.2f}us, Average TFLOPS={avg_tflops:.2f}" |
| 281 | + ) |
| 282 | + |
| 283 | + |
| 284 | +# ============================================================================ |
| 285 | +# Main |
| 286 | +# ============================================================================ |
| 287 | + |
| 288 | + |
| 289 | +def main(): |
| 290 | + parser = argparse.ArgumentParser( |
| 291 | + description="KDA Decode Benchmark", |
| 292 | + formatter_class=argparse.RawDescriptionHelpFormatter, |
| 293 | + epilog=""" |
| 294 | +Examples: |
| 295 | + python benchmarks/bench_kda_decode.py --batch-size 1 4 16 64 128 256 |
| 296 | + python benchmarks/bench_kda_decode.py --head-size 64 --batch-size 1 32 128 |
| 297 | + python benchmarks/bench_kda_decode.py --seq-len 1 2 3 4 --batch-size 1 32 |
| 298 | +""", |
| 299 | + ) |
| 300 | + parser.add_argument( |
| 301 | + "--batch-size", |
| 302 | + type=int, |
| 303 | + nargs="+", |
| 304 | + default=[1, 4, 16, 64, 128, 256], |
| 305 | + help="Batch sizes to benchmark", |
| 306 | + ) |
| 307 | + parser.add_argument("--num-q-heads", type=int, default=16) |
| 308 | + parser.add_argument("--num-k-heads", type=int, default=16) |
| 309 | + parser.add_argument("--num-v-heads", type=int, default=32) |
| 310 | + parser.add_argument("--head-size", type=int, default=128, choices=[64, 128]) |
| 311 | + parser.add_argument( |
| 312 | + "--dtype", type=str, choices=["float16", "bfloat16"], default="bfloat16" |
| 313 | + ) |
| 314 | + parser.add_argument( |
| 315 | + "--seq-len", |
| 316 | + type=int, |
| 317 | + nargs="+", |
| 318 | + default=[1, 2, 3, 4], |
| 319 | + help="Sequence lengths (T=1,2,3,4)", |
| 320 | + ) |
| 321 | + parser.add_argument( |
| 322 | + "--warmup", |
| 323 | + type=int, |
| 324 | + default=10, |
| 325 | + help="Number of warmup iterations", |
| 326 | + ) |
| 327 | + parser.add_argument( |
| 328 | + "--iters", |
| 329 | + type=int, |
| 330 | + default=100, |
| 331 | + help="Number of benchmark iterations", |
| 332 | + ) |
| 333 | + args = parser.parse_args() |
| 334 | + |
| 335 | + # Resolve dtype |
| 336 | + dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16} |
| 337 | + dtype = dtype_map[args.dtype] |
| 338 | + |
| 339 | + run_kda_decode_benchmark(args, dtype) |
| 340 | + |
| 341 | + |
| 342 | +if __name__ == "__main__": |
| 343 | + main() |
0 commit comments