TurboQuant: KV cache quantization with Hadamard transform (TBQ3_0 / TBQ4_0)#115
Draft
jesusmb1995 wants to merge 20 commits into tetherto:temp-7248 from
Conversation
Implements the MSE-optimal stage of TurboQuant (Zandieh et al., ICLR 2026) for CPU. Compresses KV cache vectors to 3.25 or 4.25 bits per value using randomized rotation + Lloyd-Max scalar quantization on each coordinate. New GGML types: GGML_TYPE_TQ3_0 (8 centroids, 3-bit indices) and GGML_TYPE_TQ4_0 (16 centroids, 4-bit indices), block size 128 (= head_dim). Includes block structs, quantize/dequantize, vec_dot via dequantize path, type registration, CLI argument support, and head_dim != 128 validation.
Replace per-type if/else chains with a data-driven table mapping each GGML type to its expected quantization and dot product error thresholds. Makes adding new quantization types cleaner.
Verify that TQ4 error < TQ3 error (more bits = less error) and that TQ3 error < TQ1/TQ2 error (TurboQuant beats ternary quantization), catching codebook or implementation regressions.
Shell script that runs perplexity across cache types (tq3_0, tq4_0, q4_0, q8_0, f16) on a real model and checks for PPL regressions. Auto-downloads model and wikitext-2 dataset. Validates tq4_0 PPL <= tq3_0 PPL.
Replace the O(d²) Gram-Schmidt rotation matrix with a randomized Hadamard transform: (1/√d) · H · D, where H is the Walsh-Hadamard matrix applied via Fast Hadamard Transform and D is a random ±1 diagonal. ~18x fewer ops per block (896 vs 16,384 for d=128). Also fixes the Lloyd-Max codebooks to match the correct Beta(d=128) distribution computed by scripts/compute_tq_codebooks.py. Removes the block=64 variant API (types not yet fully wired up) and adapts the quantize/dequantize helpers to use ggml_half for norm storage.
Add block=64 variants (TQ3_0_64, TQ4_0_64) for models with head_dim=64 (Llama-3.2-1B/3B). The user specifies tq3_0/tq4_0 on the CLI and the KV cache init auto-selects the d=64 internal type when head_dim=64. Includes d=64 Lloyd-Max codebooks computed from the exact Beta(d=64) PDF, separate Hadamard sign arrays (seed 43), and full CPU wiring: type_traits, quantize/dequantize functions, vec_dot, ops.cpp dispatch lists, and the test-quantize-fns thresholds. Note: d=64 quality is significantly worse than d=128 (+31% TQ3 / +7% TQ4 PPL regression vs f16) due to the wider coordinate distribution. QJL Stage 2 would help but is not yet implemented.
Add Vulkan compute shaders for TQ3_0/TQ4_0 dequantization, quantization (copy-to-quant), and SET_ROWS. The dequant shaders use 128 threads per workgroup with shared-memory Fast Hadamard Transform (7 butterfly stages). The quantize shaders (copy_to_quant.comp) implement the forward path: norm computation, sign application, FHT, codebook nearest-centroid search, and bit-packing — all in a single thread per block. Registers dequant, CPY f32→TQ, and SET_ROWS pipelines in ggml-vulkan.cpp. Updates supports_op for MUL_MAT, CPY, and SET_ROWS. Block type definitions added to types.glsl and codebook/sign constants to tq_utils.comp. V cache auto-downgrades to f16 on GPU since Flash Attention shaders cannot perform inline FHT dequantization for TQ types.
Replace the two-pass dequant→fp16→matmul fallback with fused mul_mat_vec shaders that dequantize and compute the dot product in a single GPU dispatch. Each workgroup (128 threads) processes one output row, iterating over TQ blocks with shared-memory FHT + dot accumulation. Uses subgroup shuffle (GL_KHR_shader_subgroup_shuffle) for the first log2(subgroup_size) FHT stages (no barriers), falling back to shared memory only for the remaining 2 stages (subgroup_size=32). This reduces per-block barrier count from 8 to 3. GPU utilization is improved but still ~50% vs 100% for native types (q4_0, q8_0). The remaining gap is due to the FHT synchronization overhead and 128-thread workgroup size. Further optimization would require fused Flash Attention with pre-rotated queries (rotate Q once, dot with raw codebook values — no FHT in the hot loop).
Move the Hadamard transform from inside the quantize/dequantize kernels to the graph level as an explicit MUL_MAT on Q, K, V. Since Q and K are rotated by the same orthogonal matrix H, the rotations cancel in the attention dot product: (H·q)·(H·k) = q·k for every query/key pair. This allows the quantize/dequantize kernels to become pure codebook operations with no FHT, no sign flips, no shared memory barriers:
- llama-graph.cpp: add ggml_rotate_hadamard() applied to Q/K/V when the KV cache uses a quantized type
- copy_to_quant.comp (SET_ROWS): remove tq_fht_local(), sign flips, inv_sqrt_d. Add binary-search quantize (3 comparisons for TQ3, 4 for TQ4) and packed 3-bit writes (8 indices per uint32)
- dequant_tq3_0.comp, dequant_tq4_0.comp: remove inverse FHT, shared memory, sign flips. Now just codebook lookup + scale
- mul_mat_vec_tq3_0.comp, mul_mat_vec_tq4_0.comp: remove subgroup shuffle FHT and shared memory FHT stages
- ggml-quants.c: remove tq_forward_inplace() from quantize, tq_inverse_inplace() from dequantize

SET_ROWS: 88 us -> 19 us (4.6x faster). Graph-level rotation adds <2% overhead.
Add missing Vulkan pipeline support for TQ3_0/TQ4_0 batch operations:
- dequant_funcs.glsl: add TQ3_0/TQ4_0 dequant() and dequant4() for the mul_mm batch matmul path (3-bit unpack + codebook lookup for TQ3, nibble unpack + codebook lookup for TQ4)
- vulkan-shaders-gen.cpp: remove the tq3_0/tq4_0 skip in matmul_shaders() and copy_from_quant generation
- ggml-vulkan.cpp: add pipeline_dequant_mul_mat_mat registrations, pipeline_cpy_quant_f32 registrations (TQ3/TQ4 -> f32 GPU dequant), and dispatch routing in the ggml_vk_get_cpy_pipeline switch statements

Prevents CPU fallback for copy/dequant operations when reading from the TQ3/TQ4 KV cache on GPU.
Add inline codebook dequantization to the Vulkan Flash Attention shader, enabling FA for TurboQuant KV cache types. Since optRot moved the Hadamard rotation to the graph level, the FA shader only needs a simple codebook lookup (same complexity as Q4_0's linear dequant).
- flash_attn_base.glsl: add dequantize4() for TQ3_0 (3-bit unpack + TQ3_CB lookup) and TQ4_0 (nibble unpack + TQ4_CB lookup) with separate K/V buffer bindings
- vulkan-shaders-gen.cpp: remove the tq3_0/tq4_0 skip in FA shader generation, add to the scalar FA path
- ggml-vulkan.cpp: add CREATE_FA for TQ3_0/TQ4_0, add to the supports_op FLASH_ATTN_EXT switch
- llama-context.cpp: remove the stale V-cache f16 downgrade workaround (was: "TurboQuant V cache not supported with GPU Flash Attention")

Result: graph splits drop from 66 to 2, prompt eval goes from 29 tok/s to 127 tok/s (4.4x speedup). TQ3 is now within 5% of f16 performance.
Useful for models with head dimensions != 128.
Extend test-quantize-fns to run quantization round-trip (RMSE) and dot product (mul_mat) checks on any ggml backend, not just CPU. A new `-b <backend>` flag (e.g. `-b vulkan`) selects the backend; without it the test runs the original CPU-only path. The backend path builds small ggml compute graphs (ggml_cpy for round-trip, ggml_mul_mat for dot product) and executes them through ggml_backend_sched, with a CPU fallback backend for ops the primary backend does not support. Gracefully skips types whose cpy or mul_mat ops are unsupported on the chosen backend, and adds verbose debug output (`-v`) showing tensor shapes, backend support decisions, and raw GPU vs reference results to aid debugging backend-specific numerical issues.
The TQ3_0 and TQ4_0 mul_mat_vec pipelines require BLOCK_SIZE=128 (matching QUANT_K), but were compiled with SHADER_REDUCTION_MODE_SUBGROUP which only reduces within a single subgroup via subgroupAdd — with no cross-subgroup shared-memory reduction step. On GPUs where subgroup_size < 128 (e.g. NVIDIA with warp size 32), this caused only 1/N of the partial sums to be accumulated, where N = BLOCK_SIZE / subgroup_size. For a typical NVIDIA GPU the dot product result was exactly 1/4 of the correct value, producing ~0.82 normalised error in test-quantize-fns. Fix by selecting SHADER_REDUCTION_MODE_HYBRID for TQ3/TQ4 pipelines in both f32 and f16 variants, which performs subgroupAdd within each subgroup then sums across subgroups through shared memory.
Add a -b BACKEND flag to test-quantize-perf.cpp for running quantization benchmarks on GPU backends (e.g. vulkan, cuda) via the ggml scheduler. When -b is specified, the tool builds ggml compute graphs and benchmarks three operations through the backend scheduler:
- quantize: cpy f32 -> quant (SET_ROWS equivalent)
- dequantize: cpy f32 -> quant -> f32 (roundtrip)
- mul_mat: quantized matmul (vec_dot equivalent)

Each reports min/avg time and throughput. Ops not supported by the backend are gracefully skipped. Without -b, behavior is unchanged (direct CPU function pointer calls with cycle counting).
- Split the cache quantization taxonomy into PQ (stage-1 only) and TBQ (stage-1 + QJL stage-2 correction), including pq4_0/tbq4_0 coverage.
- Implemented the TBQ4 stage-2 QJL encode/correction in the CPU and Vulkan paths, fixed scaling/sign-quantization behavior, and aligned struct/shader packing for block data.
- Fixed Vulkan/Flash Attention regressions (symbol lookup, BLOCK_BYTE_SIZE alignment, shader generation, and flash-attention QJL correction wiring) so TBQ variants run end-to-end.
- Updated validation and PR draft docs with correction-phase TBQ vs PQ results and notes.
Enable Flash Attention to use different quantization types for the K and V caches (e.g. TBQ3_0 keys with PQ4_0 values), allowing finer memory/quality trade-offs per cache.
- Refactored the FA shader (flash_attn_base.glsl) to declare separate K and V buffer bindings with independent block sizes and dequantize functions (dequantize4_k / dequantize4_v), replacing the single shared dequantize4 that dispatched on a binding index.
- Added backward-compat defines that map legacy DATA_A_* to DATA_K_*/DATA_V_* so existing same-type pipelines keep working unchanged.
- Extended vk_fa_pipeline_state to carry v_type, enabling the pipeline cache to distinguish same-K-type entries that differ only in V type.
- Added a CREATE_FA_MIXED macro and ~30 mixed-type pipeline registrations covering all TBQ/PQ/Q8_0/F16 cross-pairs (scalar path).
- Updated supports_op to accept mixed K/V pairs when at least one side is TBQ/PQ, falling back to FA_SCALAR automatically.
- Added mixed-type shader generation in vulkan-shaders-gen.cpp for all relevant K/V combinations.
…erage
- Extended test_flash_attn_ext to track distinct type_K and type_V (backward-compatible: the single-type constructor still works).
- Added mixed-type cases for TBQ/PQ plus Q8_0/F16 pairs in both MODE_TEST and MODE_PERF.
- Updated vars() so test filtering can target type_K and type_V independently.

Commands:
./build/bin/test-backend-ops -o FLASH_ATTN_EXT -p "tbq|pq" test
./build/bin/test-backend-ops -o FLASH_ATTN_EXT -p "tbq|pq" perf
Extend test-kv-cache-quantization.sh to run perplexity benchmarks with different quantization types for the K and V caches, exercising the mixed-type Flash Attention paths added in the previous commit.
- Added 6 mixed K/V pairs (e.g. K=tbq3_0/V=pq3_0, K=tbq4_0/V=q8_0) that run after the existing same-type tests.
- Print a separate "Mixed K/V Results Summary" table with PPL, vs-f16 regression, and timing columns.
- Check each mixed pair against the MAX_PPL_REGRESSION_PCT threshold.
Are you planning to merge this before the rebase to the latest version of llama.cpp?
Summary
Implements TurboQuant KV cache quantization (Zandieh et al., ICLR 2026) for CPU and Vulkan backends with full Flash Attention support. Compresses KV cache to 3.25-4.25 bits per value, enabling ~4-5x larger context windows on the same hardware.
Paper: https://arxiv.org/pdf/2504.19874
Community discussion: ggml-org#20969
Related upstream PR: ggml-org#21038 (graph-level rotation for existing quant types)
How it works
The rotation decorrelates coordinates, making each follow a known Beta distribution. A fixed codebook (no per-block calibration) achieves near-optimal MSE. The rotation cancels in the attention dot product:
Since H (with the sign diagonal folded in) is orthogonal, HᵀH = I, so for every query/key pair (H·q)·(H·k) = qᵀHᵀHk = q·k — no inverse rotation is needed at attention time.

What's new vs community state
The upstream discussion and community implementations so far have:
This PR adds:
- New GGML types (`pq3_0`, `pq4_0`, `TBQ3_0_64`, `TBQ4_0_64`), block structs, type traits, CPU backend registration, KV cache support via `--cache-type-k`/`--cache-type-v`
- Lloyd-Max codebooks generated by `scripts/compute_tq_codebooks.py`
- The user specifies `pq3_0`/`pq4_0` on the CLI; KV cache init automatically selects the d=64 internal type when head_dim=64

Usage
Works transparently with both head_dim=128 (Llama-3.1, Qwen, Mistral) and head_dim=64 (Llama-3.2-1B/3B) — the right block size is auto-selected.
Perplexity results
Vulkan GPU: Mistral-7B (Q4_K_S weights, head_dim=128), wikitext-2, 4 chunks × 128 ctx
With Flash Attention enabled (GPU fully saturated; caveat: these numbers are a bit noisy and would need longer runs to be conclusive):
q4_0, pq3_0, and pq4_0 remain in the same throughput tier as f16 under Flash Attention. Here, PQ = stage-1 quantization only (Lloyd-Max codebook + random rotation + dequant); TBQ = stage-1 + stage-2 QJL correction (the full TurboQuant path), so `tbq3_0`/`tbq4_0` are the QJL-corrected variants.

In this snapshot, the TBQ variants have higher wall time than PQ/q4_0, with `tbq4_0` close in quality to `pq4_0` after QJL correction. `tbq3_0` typically gains more over its PQ counterpart than `tbq4_0` does, because when the residual is meaningful, 3-bit leaves more coarse structure for QJL to recover; at 4-bit, the codebook already preserves most of the structure, so `tbq4_0` has much less correction headroom and measured gains are often near noise.

Mixed type support (tested on AMD integrated GPU)
Vulkan per-kernel microbenchmarks (GTX 1080 Ti, 524K values, 1000 iterations)
Isolated kernel benchmarks via `test-quantize-perf -b vulkan`:

GPU performance comparison (GTX 1080 Ti, 512 tokens prompt eval)
CPU: Mistral-7B (Q4_K_S weights, head_dim=128), wikitext-2, 4 chunks × 128 ctx
CPU: Llama-3.2-1B-Instruct (Q4_0 weights, head_dim=64), wikitext-2, 4 chunks × 128 ctx
Note: d=64 support is preliminary. QJL Stage 2 (not yet implemented) is critical for d=64 quality.
Vulkan performance optimization details
SET_ROWS (K-cache write quantize kernel)
Flash Attention (the critical optimization)
Without FA, TBQ3 was 4.6x slower than f16 due to 66 graph splits (each a GPU sync point) and explicit dequant-then-matmul fallback. Adding inline codebook dequant to the FA shader brought this to 5% overhead:
flash_attn_base.glsl: inline dequantize4() for TBQ3 (3-bit unpack + codebook lookup) and TBQ4 (nibble unpack + codebook lookup)

Limitations
TODOs
- `mul_mat_vec_*.comp` (currently Stage 1 only; FA path is complete)
- MSE-only Top 1%
- Bias and Variance

QJL Stage 2 implementation status
- `cpy_f32_quant`
- `mul_mat_vec`

Test plan
- `test-quantize-fns` — round-trip RMSE, dot product error, cross-type invariants (CI)
- `test-kv-cache-quantization.sh` — perplexity regression check on a real model (`N_CHUNKS=20 N_CTX=2048`)
- Flash Attention correctness and perf (`test-backend-ops -o FLASH_ATTN_EXT`)