feat: QuantizedRotatingKVCache + KVSplit (K/V different bits)#1074
deceptech-packet-ninja wants to merge 3 commits into ml-explore:main from
Conversation
Enables KV cache quantization in mlx_lm.server, closing ml-explore#1043. Batching disabled when kv_bits is set (BatchQuantizedKVCache NYI). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- QuantizedRotatingKVCache: inherits from RotatingKVCache, stores quantized (data, scales, biases) tuples in a rotating buffer. Fixes RotatingKVCache.to_quantized() NotImplementedError. Enables --max-kv-size + --kv-bits together.
- KVSplit: QuantizedKVCache and QuantizedRotatingKVCache accept bits=(key_bits, value_bits) tuples for asymmetric quantization. quantized_scaled_dot_product_attention uses separate params for key and value matmuls. Full backward compatibility.

Closes ml-explore#191. Addresses ml-explore#883 (OOM from unbounded KV cache).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
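A minimal sketch of how the `bits=(key_bits, value_bits)` tuple API described in the commit can stay backward compatible with a plain `int`. This is illustrative only, not the PR's actual code; the class and helper names are hypothetical.

```python
# Hypothetical sketch of the KVSplit bits normalization; names mirror the
# PR description but this is not the PR's actual implementation.

def _split_param(value):
    """Accept either a single int or a (key, value) tuple."""
    if isinstance(value, (tuple, list)):
        key_v, value_v = value
        return int(key_v), int(value_v)
    return int(value), int(value)


class QuantizedCacheParams:
    """Holds per-tensor quantization params for keys and values."""

    def __init__(self, bits=8, group_size=64):
        self.key_bits, self.value_bits = _split_param(bits)
        self.key_group_size, self.value_group_size = _split_param(group_size)

    @property
    def bits(self):
        # Backward-compatible view: collapse to a single int when symmetric.
        if self.key_bits == self.value_bits:
            return self.key_bits
        return (self.key_bits, self.value_bits)


# kv_bits=(8, 4): 8-bit keys, 4-bit values
p = QuantizedCacheParams(bits=(8, 4))
print(p.key_bits, p.value_bits, p.bits)   # 8 4 (8, 4)
p2 = QuantizedCacheParams(bits=4)
print(p2.bits)                            # 4
```

Existing callers that pass `bits=4` see no behavior change, while tuple callers get asymmetric precision.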
Thump604 left a comment
I run Qwen3.5-122B-A10B (5-bit, ~82GB weights) on M2 Ultra 128GB in production with continuous batching and MTP. KV cache quantization is on my critical path for 1M context support.
I independently built a similar asymmetric K8/V4 cache class and tested it extensively. A few findings that may be useful:
Architecture: My implementation dequantized the full cache to FP16 before passing to SDPA, which caused a 54% latency regression at 100K context. This PR's approach of keeping K/V in quantized tuple form and routing through mx.quantized_matmul is the right architecture since dequantization is kernel-fused. Nice work.
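To make the storage format concrete, here is a numpy sketch of group-wise affine quantization producing the same kind of (data, scales, biases) tuple the cache stores, and of how a matmul against the dequantized keys stays close to the FP matmul. This is an illustration of the math only, not MLX's fused kernels; the `group_size=64` and uint8 packing are assumptions.

```python
import numpy as np

# Illustrative group-wise affine quantization, mimicking the (data, scales,
# biases) tuple layout described above. Not MLX's actual kernels.

def quantize(x, bits=8, group_size=64):
    levels = (1 << bits) - 1
    g = x.reshape(-1, group_size)
    lo = g.min(axis=1, keepdims=True)
    hi = g.max(axis=1, keepdims=True)
    scales = (hi - lo) / levels
    scales[scales == 0] = 1.0
    q = np.round((g - lo) / scales).astype(np.uint8)
    return q, scales, lo  # (data, scales, biases)

def dequantize(q, scales, biases, shape):
    return (q.astype(np.float32) * scales + biases).reshape(shape)

rng = np.random.default_rng(0)
k = rng.normal(size=(128, 128)).astype(np.float32)
q_data, scales, biases = quantize(k, bits=8)
k_hat = dequantize(q_data, scales, biases, k.shape)

x = rng.normal(size=(1, 128)).astype(np.float32)
err = np.abs(x @ k_hat.T - x @ k.T).max()
print(f"max matmul error at 8 bits: {err:.4f}")
```

In MLX the dequantize-and-multiply happens inside `mx.quantized_matmul`, which is why keeping the tuples and routing through it avoids the latency cliff of materializing a full FP16 cache first.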
Hadamard rotation: I tested applying a Walsh-Hadamard transform before quantization to distribute outlier energy. At K8, it provides negligible quality improvement (1.01x MSE ratio vs no rotation). At K4, it helps but K4 fails on Qwen3.5 models due to high K tensor kurtosis. So no rotation is needed for the K8/V4 path this PR implements.
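For readers unfamiliar with the rotation mentioned above, here is a small numpy sketch of a fast Walsh-Hadamard transform and how it spreads a single outlier's energy evenly across dimensions before quantization. A toy illustration, not the tested implementation.

```python
import numpy as np

# Fast Walsh-Hadamard transform (orthonormal, O(n log n) butterfly).
# Applying it twice recovers the input, so it can be undone after dequant.

def fwht(x):
    x = x.astype(np.float64).copy()
    n = x.shape[-1]
    h = 1
    while h < n:
        for i in range(0, n, h * 2):
            a = x[..., i:i + h].copy()
            b = x[..., i + h:i + 2 * h].copy()
            x[..., i:i + h] = a + b
            x[..., i + h:i + 2 * h] = a - b
        h *= 2
    return x / np.sqrt(n)

v = np.zeros(64)
v[0] = 8.0                       # single large outlier
r = fwht(v)
print(v.max(), r.max())          # outlier energy spread: 8.0 -> 1.0
print(np.allclose(fwht(r), v))   # involution: applying twice recovers v
```

The spread-out distribution has lower kurtosis, which is exactly what helps low-bit (K4) quantizers; at K8 the grid is already fine enough that the rotation buys little, matching the finding above.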
Request: Latency benchmarks at 100K+ context comparing FP16 KVCache vs QuantizedKVCache(bits=(8,4)) would strengthen this PR. I can run these on my M2 Ultra if helpful.
Looks good to me.
Benchmark Results

Ran latency benchmarks per @Thump604's request. Machine: Apple Silicon, 32GB unified memory. Model:

Decode Latency vs Context Length
Key finding: K8/V4 crosses over to faster-than-FP16 at ~16K context, where attention becomes memory-bandwidth-bound. At 32K it's essentially even. Q4 is consistently faster at long contexts (1.36x at 32K) since it reads the least data.

Estimated KV Cache Memory
At 128K context on a 32GB machine with a ~4.5GB model, FP16 KV would use 16GB (total 20.5GB, tight). K8/V4 uses 10GB (14.5GB total, comfortable). Q4 uses 6GB (10.5GB, plenty of headroom).

Output Quality

All three configs produce identical output for short generations:

@Thump604 — if you're able to run these on your M2 Ultra with Qwen3.5-122B at 100K+ context, that would complement these results nicely. The longer the context, the more the bandwidth savings should show.
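As a sanity check on the FP16 figure, the 16GB at 128K follows directly from per-token KV size. The sketch below assumes a Llama-3-8B-style config (8 KV heads and head_dim=128 per the thread; the 32-layer count is an assumption).

```python
# Back-of-envelope check of the FP16 KV memory figure, assuming a
# Llama-3-8B-style config: 32 layers (assumption), 8 KV heads, head_dim=128.

def kv_bytes_per_token(layers, kv_heads, head_dim, bytes_per_elem=2):
    # Factor of 2 for storing both K and V per layer.
    return 2 * layers * kv_heads * head_dim * bytes_per_elem

per_tok = kv_bytes_per_token(32, 8, 128)
print(per_tok // 1024, "KiB/token")            # 128 KiB/token
total = per_tok * 128 * 1024                   # 128K-token context
print(total / 2**30, "GiB at 128K context")    # 16.0 GiB
```

At 128 KiB per token, every 8K tokens of context costs 1 GiB, which is why the quantized variants matter long before 128K on a 32GB machine.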
Great data. The crossover at 16K matches what I would expect — that is where attention shifts from compute-bound to memory-bandwidth-bound, and reading less data (quantized) starts winning. The 38% memory savings at K8/V4 on Llama-3 (8 KV heads, D=128) is consistent. For Qwen3.5-122B (2 KV heads, D=256, only 12 of 48 layers use KV cache), the per-token savings are smaller in absolute terms but the headroom matters at 128K+ where every GB counts. I am running context sweep benchmarks on the 4B right now (different work). Once that finishes I will run 122B K8/V4 vs FP16 at 32K/64K/100K and post the results here.
I built a similar K8/V4 asymmetric cache for Qwen3.5-122B on M2 Ultra 128GB. A few findings that may be useful:
The KVSplit tuple API (

One note for hybrid models (Qwen3.5 with GatedDeltaNet + Attention): only the attention layers use KV cache. The 36 GatedDeltaNet layers use a fixed-size ArraysCache (not per-token). So the actual memory savings from KV quantization are limited to the 12 attention layers. On the 122B at 128K context, that is still ~1.8GB savings (from ~3GB to ~1.2GB for pure KV), which matters when the model weights already use 82GB of 128GB.
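The ~3GB and ~1.2GB figures above can be reproduced arithmetically from the stated config (12 attention layers, 2 KV heads, head_dim=256, 128K tokens). The K8/V4 byte counts below assume group_size=64 with an fp16 scale and bias per group, which is an assumption about the quantization overhead.

```python
# Rough check of the ~3GB -> ~1.2GB figure for the 12 attention layers
# (2 KV heads, head_dim=256, 128K tokens, per the comment above).

def kv_gib(tokens, layers, kv_heads, head_dim, k_bytes_per_elem, v_bytes_per_elem):
    elems = tokens * layers * kv_heads * head_dim   # element count of K (or V)
    return elems * (k_bytes_per_elem + v_bytes_per_elem) / 2**30

tokens = 128 * 1024
fp16 = kv_gib(tokens, 12, 2, 256, 2.0, 2.0)
# Assumed overhead: fp16 scale + bias per 64-element group.
# 8-bit keys:   1.0   + 4/64 = 1.0625 bytes/elem
# 4-bit values: 0.5   + 4/64 = 0.5625 bytes/elem
k8v4 = kv_gib(tokens, 12, 2, 256, 1.0625, 0.5625)
print(f"fp16 KV: {fp16:.2f} GiB, K8/V4: {k8v4:.2f} GiB")  # ~3.00 and ~1.22
```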
Adopt upstream's simplified _is_batchable one-liner while keeping the kv_bits guard that disables batching when KV cache quantization is active. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Summary
Two related features for quantized KV cache:
- `QuantizedRotatingKVCache`: combines bounded-size rotation (`RotatingKVCache`) with quantized storage (`QuantizedKVCache`). Fixes the `NotImplementedError` in `RotatingKVCache.to_quantized()`. Enables `--max-kv-size` + `--kv-bits` together.
- `KVSplit`: allows different quantization bits for keys vs values. Keys need higher precision (they determine attention routing) while values tolerate lower precision. `bits` and `group_size` accept `(key, value)` tuples.

Motivation
- `QuantizedRotatingKVCache`: Users cannot combine `--max-kv-size` (bounded context) with `--kv-bits` (quantized cache). This crashes with `NotImplementedError("RotatingKVCache Quantization NYI")`. This means bounded-context inference cannot benefit from KV cache quantization, leaving memory savings on the table. Related to #883 (OOM from unbounded KV cache growth).
- `KVSplit`: Research shows keys need higher precision than values for attention quality. The maintainer noted in #191 that this would be "pretty straightforward." For example, `kv_bits=(8, 4)` gives 8-bit keys (accurate routing) with 4-bit values (memory savings), a practical sweet spot.

Changes
- `mlx_lm/models/cache.py`:
  - `QuantizedKVCache`: added `key_bits`, `value_bits`, `key_group_size`, `value_group_size` with backward-compatible `bits`/`group_size` properties. Updated `update_and_fetch` to quantize K/V with their respective params. Updated `meta_state` with backward compat for the old 3-field format.
  - `RotatingKVCache.to_quantized()`: now returns `QuantizedRotatingKVCache` instead of raising `NotImplementedError`.
  - `QuantizedRotatingKVCache` (new): inherits from `RotatingKVCache`; reuses `update_and_fetch`, `size`, `make_mask`, `is_trimmable`, `trim` from the parent. Overrides `_trim`, `_temporal_order`, `_update_concat`, `_update_in_place`, `state`, `meta_state`, `nbytes` for quantized tuple storage via `tree_map`.
- `mlx_lm/models/base.py`:
  - `quantized_scaled_dot_product_attention`: added `value_group_size` and `value_bits` optional params (default to key params for backward compat).
  - `scaled_dot_product_attention`: passes separate K/V params via `getattr` with fallbacks.
- `mlx_lm/generate.py`: added `QuantizedRotatingKVCache` to imports.

Test plan
All configurations tested on Qwen2.5-0.5B-Instruct-4bit and Llama-3-8B-Instruct-4bit:
- `kv_bits=4` — quantized cache works
- `kv_bits=8` — quantized cache works
- `kv_bits=(8, 4)` — KVSplit: keys 8-bit, values 4-bit
- `kv_bits=(4, 2)` — KVSplit: keys 4-bit, values 2-bit
- `max_kv_size=128` — rotating cache works
- `max_kv_size=128, kv_bits=4` — previously crashed, now works
- `max_kv_size=128, kv_bits=(8, 4)` — rotating + KVSplit works
- `meta_state` round-trip with old 3-field format (backward compat)
- `meta_state` round-trip with new 5-field format (KVSplit)
- `max_size` with `keep=4` — rotation logic correct
- `to_quantized()` after cache has wrapped — preserves offset/idx
- `to_quantized()` calls (idempotent — returns self)

Closes #191. Addresses #968 (maintainer feedback: inherits from `RotatingKVCache` instead of duplicating logic).
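The `meta_state` round-trip items in the test plan can be illustrated with a small pack/unpack sketch. The field names and ordering below are assumptions for illustration; mlx-lm serializes `meta_state` fields as strings, which the sketch mimics.

```python
# Hypothetical illustration of the 3-field vs 5-field meta_state
# backward-compat tests above. Field order is an assumption.

def pack_meta_state(offset, key_group_size, value_group_size, key_bits, value_bits):
    if key_group_size == value_group_size and key_bits == value_bits:
        # Old symmetric 3-field format for unchanged caches.
        return (str(offset), str(key_group_size), str(key_bits))
    return tuple(map(str, (offset, key_group_size, value_group_size,
                           key_bits, value_bits)))

def unpack_meta_state(meta):
    if len(meta) == 3:                    # old format: duplicate the fields
        offset, gs, bits = map(int, meta)
        return (offset, gs, gs, bits, bits)
    return tuple(map(int, meta))          # new 5-field KVSplit format

old = ("100", "64", "4")
print(unpack_meta_state(old))             # (100, 64, 64, 4, 4)
new = pack_meta_state(100, 64, 64, 8, 4)
print(unpack_meta_state(new))             # (100, 64, 64, 8, 4)
```

Emitting the short form whenever key and value params match keeps caches saved by older versions loadable, and caches saved with symmetric bits loadable by older versions.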
🤖 Generated with Claude Code