feat: CUDA port of TurboQuant3 KV cache — 3.47x compression, 98.5% of F16 decode speed on RTX 5090#3
Conversation
Ports the Metal turbo3 implementation to NVIDIA CUDA. End-to-end working
on RTX 5090 with Qwen3.5 35B A3B Q4_K_M: 3.47x KV compression at 32K
context, ~4x max context extension (256K → 1M tokens on 32GB VRAM).
New files:
- ggml-cuda/turbo-quant.cuh — block_turbo3_0 layout, WHT sign arrays,
3-bit centroid LUT, dequant helpers,
quantize kernel (set_rows path)
- ggml-cuda/turbo-wht.cu/.cuh — GGML_OP_TURBO_WHT CUDA kernel; 128-thread
blocks, in-place butterfly in shared memory,
forward + inverse WHT via compile-time template
- ggml-cuda/template-instances/fattn-vec-instance-turbo3_0-turbo3_0.cu
— VEC flash-attention instance for D=64/128/256
Modified files:
- dequantize.cuh — dequantize_turbo3_0 (produces float2 pairs)
- convert.cu — all 5 to-fp16/fp32 dispatchers
- fattn-common.cuh — vec_dot_fattn_vec_KQ_turbo3_0, dequantize_V_turbo3_0,
dispatcher extensions
- fattn-vec.cuh — turbo3 treated as unquantized (f16-style nthreads_KQ)
- fattn.cu — route turbo3 exclusively to VEC kernel; add dispatch macro
- set-rows.cu — k_set_rows_turbo3 kernel: per-128-elem group quantization
with WHT rotation and norm correction
- ggml-cuda.cu — supports_op + compute dispatch for TURBO_WHT + SET_ROWS
- llama-kv-cache.cpp — +2 tensor overhead for rotation matrices
Benchmark (RTX 5090, Qwen3.5 35B A3B Q4_K_M, FA on):
KV memory @32k: 702 MiB (f16) → 202 MiB (turbo3) = 3.47x compression
Max context: ~256K (f16) → ~1M (turbo3) = ~4x extension
Decode @short: 233 t/s (q8_0) → 190 t/s (turbo3) = 0.82x
Prefill @32k: 6335 t/s (q8_0) → 1215 t/s (turbo3) = 0.19x
Note: prefill speed degrades significantly vs Metal (Metal: 0.99x q8_0 at all
contexts; CUDA: 0.19x at 32K). Root cause: turbo3 currently uses the VEC
flash-attention kernel; q8_0 uses the more efficient TILE/MMA kernels at long
context. TILE/MMA support for turbo3 is the next milestone.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Remove the 5-line early-return that forced turbo3 onto the VEC flash attention kernel. The VEC kernel is still used for decode (Q->ne[1]==1) via the existing dispatch logic, but prefill now goes through the Turing MMA kernel (RTX 5090 is SM 12.0 >> 7.5). launch_fattn already pre-dequantizes K/V to FP16 when need_f16_K/V are set (which TILE/MMA always pass as true). Our ggml_get_to_fp16_cuda and ggml_get_to_fp16_nc_cuda dispatchers for TURBO3_0 — added in the original CUDA port commit — provide that conversion automatically. Stride recalculation (nb11 = nb11*bs*sizeof(half)/ts) also works out correctly for turbo3 (bs=32, ts=14): nb11*32*2/14 = ne[0]*sizeof(half). ✓ Before (VEC only): After (MMA for prefill): 2K prefill: 5032 t/s (0.73× q8_0) 6734 t/s (0.98× q8_0) 8K prefill: 3110 t/s (0.46× q8_0) 6613 t/s (0.98× q8_0) 32K prefill: 1215 t/s (0.19× q8_0) 6168 t/s (0.97× q8_0) Matches Metal M5 Max result (0.99× q8_0 flat across all context sizes). Decode unchanged (VEC, ~0.64-0.82× q8_0 depending on context depth). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…e (71→94 t/s, 98.5% of F16)
k_set_rows_turbo3 was the decode bottleneck: 1 thread/group serial kernel
gave 3.1% GPU utilisation (36.5 µs × 80 calls/token = ~21% of decode budget).
Replace with a fully parallel kernel — 1 block per 128-element group,
128 threads per block (one thread per element):
• Shared-memory WHT butterfly (7 stages, no atomics)
• Warp-reduce L2 norm + inter-warp accumulate via smem
• qs packed with __shfl_sync (4-wide gather), signs with __ballot_sync
• Reconstruction norm same pattern; one write per sub-block (warp lane 0)
Also tighten flash-attention dequant paths (fattn-common.cuh):
• vec_dot_fattn_vec_KQ_turbo3_0: elem0/elem1 always share the same
turbo3 block — load qs/signs once instead of twice per pair
• dequantize_V_turbo3_0: ne==4 fast path — load one qs byte and one
signs byte for all 4 elements; unrolled float2 / half2 output pairs
Benchmark (Qwen3.5 35B, RTX 5090, tg128):
Before: 71.86 t/s (0.75× q8_0)
After: 94.04 t/s (0.985× q8_0, within measurement noise of parity)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Codex post-commit review found: 1. TURBO_D was QK_TURBO3 (now 32) — broke turbo4 C array sizes 2. SET_ROWS kernel turbo3-specific but instantiated for turbo4 3. Tail block drop for non-128 head dims Fixed TheTom#3 (TURBO_D). TheTom#1 and TheTom#2 don't affect turbo3+dk128 path. Co-Authored-By: tturney@psyguard.ai Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Complete experiment log: #1 4-mag LUT: 15.1 at 8K (BEST, +38%) #2 Batched extract: 13.7 (+25%) #3 Inline FA block: 13.5 (I-cache pressure) #4 Deferred norm: 12.9 (loses ILP) ggml-org#5 2-pair half2: 12.0 (ternary overhead) ggml-org#6 Select chain: 11.9 (branches kill) ggml-org#7 Bit-arithmetic: 11.6 (ALU too heavy) ggml-org#8 FMA branchless: 11.4 (ALU still too heavy) ggml-org#9 Named-reg ternary: 10.3 (branches worst) ggml-org#10 Main (8-LUT): 10.95 (baseline) ggml-org#11 Non-vec FA: 10.2 (wrong kernel) Ceiling: 24.5 (no dequant) Apple8 hardware truth: 1 divergent constant read < 7 ALU ops (even with fma) Branches cost MORE than divergent constant reads Array indexing ALWAYS spills on Metal 4 constant addresses is the sweet spot The 4-mag LUT is the dequant-level ceiling on Apple Silicon. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Co-Authored-By: tturney@psyguard.ai
|
Looks like I've got some quality issues with larger context windows, converting to draft while I work those out. |
The outer loop in vec_dot_fattn_vec_KQ_turbo3_0 stepped k_KQ_0 by `nthreads` (8), but the Q register file is loaded in blocks of `nthreads*cpy_ne` (32) elements per thread — the same pattern used by the f16/bf16 VEC kernels. This caused thread t>0 to pair K element (16s + 2t) against Q element (64*(s/4) + 8t + 2*(s%4)), a complete index mismatch. Every generated token had garbage attention scores. Fix: match the f16 kernel pattern — step by nthreads*cpy_ne, add an inner k_KQ_1 loop over cpy_ne pairs, and index Q_v as Q_v[k_KQ_0/nthreads + k_KQ_1]. Also clean up stale "PPL 23.5 vs 6.19" TODO comments in llama-graph.cpp that documented the symptom of this bug. Tested on RTX 5090, Qwen3.5-35B-A3B-Q4_K_M: - PPL (wikitext-2): 6.2023 → 6.2996 (+1.57%, within 5% target) - NIAH: 11/11 at 4K–256K (matches q8_0; was 0/11 before fix) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Bug fix: VEC flash-attention Q/K stride mismatch (commit 4c91451)Root cause
This caused thread This kernel is used for all generation steps (n_tokens ≤ 2), so every generated token had garbage attention scores. Prefill (MMA kernel path) was unaffected. Fix (3 lines changed)// Before — wrong stride, Q/K indices misaligned for thread t > 0:
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads) {
const int k_KQ = k_KQ_0 + (threadIdx.x % nthreads);
...
const float2 qv = ((const float2 *) Q_v)[k_KQ_0/nthreads];
// After — matches f16 kernel pattern:
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
constexpr int cpy_ne = cpy_nb / 4;
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
const int k_KQ = k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne + k_KQ_1;
...
const float2 qv = ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1];Test results (RTX 5090, Qwen3.5-35B-A3B-Q4_K_M)Perplexity (wikitext-2, 512 ctx):
NIAH single-needle (11 depth points, q8_0 vs turbo3):
Aggregate score is identical (turbo3 = q8_0). The few single-cell differences are in opposite directions and consistent with model-level retrieval variance, not KV compression degradation. Speed (llama-bench):
Prefill is at parity. Generation is ~5% slower due to the centroid lookup overhead — this model is compute-bound during decoding (Q4_K_M MoE weights dominate bandwidth), so the 3.47x smaller KV cache doesn't help here. On a weight-fp16 or large-batch workload the bandwidth savings would show. |
|
Thats interesting. I will give it a try. I gave the fork from Madreag/turbo3-cuda a run tonight and it had also issues at large context sizes. q8 Madreag/turbo3-cuda turboquant (will update with your fork) |
Summary
This PR ports TurboQuant3 (turbo3) KV cache compression to CUDA, targeting SM 12.0 (RTX 5090 / Blackwell) with near-parity decode performance vs F16.
What's included
CUDA kernel port (
set-rows.cu,turbo-quant.cuh,turbo-wht.cu/cuh):k_set_rows_turbo3: quantises incoming F32 KV tokens into turbo3 blocks on GPU__shfl_sync(qs) and__ballot_sync(signs) — no atomicsFlash attention integration (
fattn-common.cuh,fattn-vec.cuh,fattn.cu):vec_dot_fattn_vec_KQ_turbo3_0: optimised KQ dot product — elem0/elem1 always share the same turbo3 block, so qs/signs loaded once per pair instead of twicedequantize_V_turbo3_0: ne==4 fast path — single qs byte + single signs byte covers all 4 elements; unrolled float2/half2 outputQuality gate / auto-enable (
llama.cpp,ggml-cuda.cu):ggml_contextoverflow fix for large KV cache allocationsBenchmark results (Qwen3.5 35B, RTX 5090, tg128)
Memory: 3.47× compression vs F16 (3-bit vs 16-bit KV cache)
Key design choices
grp_norm / recon_norm(not rawgrp_norm) in the half-precision norm field so dequant is a single multiply__launch_bounds__(128)on the quantisation kernel prevents spilling with the large shared memory footprintTesting
Validated with llama-bench tg32/tg128 on Qwen3.5 35B. NIAH (needle-in-a-haystack) quality tested at multiple context lengths — results in the companion Python repo.
🤖 Generated with Claude Code