
feat: QuantizedRotatingKVCache + KVSplit (K/V different bits)#1074

Open
deceptech-packet-ninja wants to merge 3 commits into ml-explore:main from deceptech-packet-ninja:feat/kv-cache-improvements

Conversation

@deceptech-packet-ninja

Summary

Two related features for quantized KV cache:

  1. QuantizedRotatingKVCache: combines bounded-size rotation (RotatingKVCache) with quantized storage (QuantizedKVCache). Fixes the NotImplementedError in RotatingKVCache.to_quantized(). Enables --max-kv-size + --kv-bits together.

  2. KVSplit: allows different quantization bits for keys vs values. Keys need higher precision (they determine attention routing) while values tolerate lower precision. bits and group_size accept (key, value) tuples.
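The accepted input shapes can be illustrated with a small normalization helper (a hypothetical sketch — `normalize_kv_param` is not part of the PR; it only mirrors the int-or-tuple API described above):

```python
def normalize_kv_param(p):
    """Accept an int or a (key, value) pair; return (key_param, value_param).

    Hypothetical helper mirroring the tuple-accepting bits/group_size API
    described above; the PR's internal handling may differ.
    """
    if isinstance(p, (tuple, list)):
        key_p, value_p = p
        return int(key_p), int(value_p)
    return int(p), int(p)

# kv_bits=(8, 4): 8-bit keys, 4-bit values
key_bits, value_bits = normalize_kv_param((8, 4))
# kv_bits=4: both sides use 4 bits, exactly as before
legacy_bits = normalize_kv_param(4)
```

A plain int therefore stays fully backward compatible, while a tuple opts into asymmetric precision.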

Motivation

QuantizedRotatingKVCache: Users cannot combine --max-kv-size (bounded context) with --kv-bits (quantized cache). This crashes with NotImplementedError("RotatingKVCache Quantization NYI"). This means bounded-context inference cannot benefit from KV cache quantization, leaving memory savings on the table. Related to #883 (OOM from unbounded KV cache growth).

KVSplit: Research shows keys need higher precision than values for attention quality. The maintainer noted in #191 that this would be "pretty straightforward." For example, kv_bits=(8, 4) gives 8-bit keys (accurate routing) with 4-bit values (memory savings), a practical sweet spot.

Changes

mlx_lm/models/cache.py:

  • QuantizedKVCache: added key_bits, value_bits, key_group_size, value_group_size with backward-compatible bits/group_size properties. Updated update_and_fetch to quantize K/V with their respective params. Updated meta_state with backward compat for the old 3-field format.
  • RotatingKVCache.to_quantized(): now returns QuantizedRotatingKVCache instead of raising NotImplementedError.
  • QuantizedRotatingKVCache (new): inherits from RotatingKVCache, reuses update_and_fetch, size, make_mask, is_trimmable, trim from parent. Overrides _trim, _temporal_order, _update_concat, _update_in_place, state, meta_state, nbytes for quantized tuple storage via tree_map.
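The backward-compatible property layer can be sketched roughly like this (a hypothetical class for illustration, not the actual mlx_lm `QuantizedKVCache` code):

```python
class KVQuantParams:
    """Hypothetical sketch of the backward-compatible bits/group_size
    properties described above; the real attribute handling may differ."""

    def __init__(self, bits=8, group_size=64):
        self.key_bits, self.value_bits = (
            tuple(bits) if isinstance(bits, (tuple, list)) else (bits, bits)
        )
        self.key_group_size, self.value_group_size = (
            tuple(group_size)
            if isinstance(group_size, (tuple, list))
            else (group_size, group_size)
        )

    @property
    def bits(self):
        # Old callers that read .bits keep working; a tuple is only
        # surfaced when key and value precision actually differ.
        if self.key_bits == self.value_bits:
            return self.key_bits
        return (self.key_bits, self.value_bits)
```

Code written against the old scalar `bits` attribute sees no change unless a split configuration is in use.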

mlx_lm/models/base.py:

  • quantized_scaled_dot_product_attention: added value_group_size and value_bits optional params (default to key params for backward compat).
  • scaled_dot_product_attention: passes separate K/V params via getattr with fallbacks.
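The fallback pattern can be sketched as follows (the stand-in class and attribute names are assumptions mirroring the description above, not the actual mlx_lm code):

```python
class OldQuantizedCache:
    """Stand-in for a cache object created before this change:
    it exposes only the shared bits/group_size attributes."""

    def __init__(self):
        self.bits = 4
        self.group_size = 64

cache = OldQuantizedCache()

# Prefer per-value params when the cache exposes them; otherwise reuse
# the key-side params, so old caches keep working unchanged.
value_bits = getattr(cache, "value_bits", cache.bits)
value_group_size = getattr(cache, "value_group_size", cache.group_size)
```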

mlx_lm/generate.py: added QuantizedRotatingKVCache to imports.

Test plan

All configurations tested on Qwen2.5-0.5B-Instruct-4bit and Llama-3-8B-Instruct-4bit:

  • Standard BF16 — unchanged behavior
  • kv_bits=4 — quantized cache works
  • kv_bits=8 — quantized cache works
  • kv_bits=(8, 4) — KVSplit: keys 8-bit, values 4-bit
  • kv_bits=(4, 2) — KVSplit: keys 4-bit, values 2-bit
  • max_kv_size=128 — rotating cache works
  • max_kv_size=128, kv_bits=4 — previously crashed, now works
  • max_kv_size=128, kv_bits=(8, 4) — rotating + KVSplit works
  • meta_state round-trip with old 3-field format (backward compat)
  • meta_state round-trip with new 5-field format (KVSplit)
  • 100 decode steps past max_size with keep=4 — rotation logic correct
  • to_quantized() after cache has wrapped — preserves offset/idx
  • Repeated to_quantized() calls (idempotent — returns self)
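The two meta_state formats exercised above can be sketched with a small parser (field names and ordering are illustrative assumptions, not the actual mlx_lm serialization):

```python
def load_meta_state(fields):
    """Parse either an old 3-field or a new 5-field meta_state tuple.

    Hypothetical parser: the field names/order here are assumptions for
    illustration; the real mlx_lm serialization may differ.
    """
    if len(fields) == 3:  # old format: one bits/group_size shared by K and V
        offset, group_size, bits = (int(f) for f in fields)
        return dict(offset=offset,
                    key_bits=bits, value_bits=bits,
                    key_group_size=group_size, value_group_size=group_size)
    # new format: separate key and value params
    offset, kg, kb, vg, vb = (int(f) for f in fields)
    return dict(offset=offset,
                key_bits=kb, value_bits=vb,
                key_group_size=kg, value_group_size=vg)
```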

Closes #191. Addresses #968 (maintainer feedback: inherits from RotatingKVCache instead of duplicating logic).

🤖 Generated with Claude Code

deceptech-packet-ninja and others added 2 commits March 30, 2026 14:01
Enables KV cache quantization in mlx_lm.server, closing ml-explore#1043.
Batching disabled when kv_bits is set (BatchQuantizedKVCache NYI).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

- QuantizedRotatingKVCache: inherits from RotatingKVCache, stores
  quantized (data, scales, biases) tuples in a rotating buffer.
  Fixes RotatingKVCache.to_quantized() NotImplementedError.
  Enables --max-kv-size + --kv-bits together.

- KVSplit: QuantizedKVCache and QuantizedRotatingKVCache accept
  bits=(key_bits, value_bits) tuples for asymmetric quantization.
  quantized_scaled_dot_product_attention uses separate params for
  key and value matmuls. Full backward compatibility.

Closes ml-explore#191. Addresses ml-explore#883 (OOM from unbounded KV cache).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

@Thump604 Thump604 left a comment


I run Qwen3.5-122B-A10B (5-bit, ~82GB weights) on M2 Ultra 128GB in production with continuous batching and MTP. KV cache quantization is on my critical path for 1M context support.

I independently built a similar asymmetric K8/V4 cache class and tested it extensively. A few findings that may be useful:

Architecture: My implementation dequantized the full cache to FP16 before passing to SDPA, which caused a 54% latency regression at 100K context. This PR's approach of keeping K/V in quantized tuple form and routing through mx.quantized_matmul is the right architecture since dequantization is kernel-fused. Nice work.
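For readers unfamiliar with the storage format: the cache tuples hold group-wise affine-quantized data plus per-group scales and biases, which is what the fused quantized matmul consumes directly. A numpy sketch of that scheme (illustration only — MLX's real `mx.quantize` additionally bit-packs elements and its layout differs):

```python
import numpy as np

def quantize_groups(x, group_size=64, bits=4):
    # Group-wise affine quantization: each group of `group_size` elements
    # gets its own scale and bias, limiting the blast radius of outliers.
    levels = 2 ** bits - 1
    g = x.reshape(-1, group_size)
    lo = g.min(axis=1, keepdims=True)
    hi = g.max(axis=1, keepdims=True)
    scale = np.where(hi > lo, (hi - lo) / levels, 1.0)
    q = np.round((g - lo) / scale).astype(np.uint8)
    return q, scale, lo  # (data, scales, biases), as in the cache tuples

def dequantize_groups(q, scale, lo, shape):
    return (q.astype(np.float32) * scale + lo).reshape(shape)

x = np.random.default_rng(0).standard_normal(256).astype(np.float32)
q, scales, biases = quantize_groups(x, group_size=64, bits=4)
x_hat = dequantize_groups(q, scales, biases, x.shape)
max_err = float(np.abs(x_hat - x).max())  # bounded by scale/2 per group
```

The slow path materializes `x_hat` for the whole cache before SDPA; the fused path never does, which is why it avoids the latency regression described above.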

Hadamard rotation: I tested applying a Walsh-Hadamard transform before quantization to distribute outlier energy. At K8, it provides negligible quality improvement (1.01x MSE ratio vs no rotation). At K4, it helps but K4 fails on Qwen3.5 models due to high K tensor kurtosis. So no rotation is needed for the K8/V4 path this PR implements.

Request: Latency benchmarks at 100K+ context comparing FP16 KVCache vs QuantizedKVCache(bits=(8,4)) would strengthen this PR. I can run these on my M2 Ultra if helpful.

Looks good to me.

@deceptech-packet-ninja

Benchmark Results

Ran latency benchmarks per @Thump604's request. Machine: Apple Silicon, 32GB unified memory.

Model: mlx-community/Meta-Llama-3-8B-Instruct-4bit (32 layers, 8 KV heads, D=128)

Decode Latency vs Context Length

| Context | FP16 (ms) | FP16 (t/s) | K8/V4 (ms) | K8/V4 (t/s) | Q4 (ms) | Q4 (t/s) | K8/V4 vs FP16 |
|--------:|----------:|-----------:|-----------:|------------:|--------:|--------:|--------------:|
| 1,024   | 49.0      | 20.4       | 51.9       | 19.3        | 53.8    | 18.6    | 0.94x |
| 4,096   | 55.9      | 17.9       | 59.8       | 16.7        | 54.1    | 18.5    | 0.93x |
| 16,384  | 80.0      | 12.5       | 75.4       | 13.3        | 74.6    | 13.4    | 1.06x |
| 32,768  | 125.4     | 8.0        | 123.6      | 8.1         | 92.4    | 10.8    | 1.01x |

Key finding: K8/V4 crosses over to faster-than-FP16 at ~16K context, where attention becomes memory-bandwidth-bound. At 32K it's essentially even. Q4 is consistently faster at long contexts (1.36x at 32K) since it reads the least data.

Estimated KV Cache Memory

| Context | FP16    | K8/V4   | Q4     | K8/V4 savings |
|--------:|--------:|--------:|-------:|--------------:|
| 1K      | 128 MB  | 80 MB   | 48 MB  | 38% |
| 4K      | 512 MB  | 320 MB  | 192 MB | 38% |
| 16K     | 2.0 GB  | 1.3 GB  | 768 MB | 38% |
| 32K     | 4.0 GB  | 2.5 GB  | 1.5 GB | 38% |
| 64K     | 8.0 GB  | 5.0 GB  | 3.0 GB | 38% |
| 128K    | 16.0 GB | 10.0 GB | 6.0 GB | 38% |

At 128K context on a 32GB machine with a ~4.5GB model, FP16 KV would use 16GB (total 20.5GB, tight). K8/V4 uses 10GB (14.5GB total, comfortable). Q4 uses 6GB (10.5GB, plenty of headroom).

Output Quality

All three configs produce identical output for short generations:

FP16:  Gravity is a force that pulls objects towards each other. It is what keeps us on the ground and what...
K8/V4: Gravity is a force that pulls objects towards each other. It is what keeps us on the ground and what...
Q4:    Gravity is a force that pulls objects towards each other. It is what keeps us on the ground and what...

@Thump604 — if you're able to run these on your M2 Ultra with Qwen3.5-122B at 100K+ context, that would complement these results nicely. The longer the context, the more the bandwidth savings should show.

@Thump604

Great data. The crossover at 16K matches what I would expect — that is where attention shifts from compute-bound to memory-bandwidth-bound, and reading less data (quantized) starts winning.

The 38% memory savings at K8/V4 on Llama-3 (8 KV heads, D=128) is consistent. For Qwen3.5-122B (2 KV heads, D=256, only 12 of 48 layers use KV cache), the per-token savings are smaller in absolute terms but the headroom matters at 128K+ where every GB counts.

I am running context sweep benchmarks on the 4B right now (different work). Once that finishes I will run 122B K8/V4 vs FP16 at 32K/64K/100K and post the results here.

@Thump604

Thump604 commented Apr 1, 2026

I built a similar K8/V4 asymmetric cache for Qwen3.5-122B on M2 Ultra 128GB. A few findings that may be useful:

  • Hadamard rotation before quantization provides negligible quality improvement at K8. Not worth the overhead.
  • The K8/V4 split is the right granularity. K needs higher precision for attention routing; V tolerates aggressive quantization.
  • With mx.compile fusion for the dequant+SDPA path, the latency overhead is manageable (~1.8x baseline vs ~2x without fusion).

The KVSplit tuple API (bits=(8,4)) is cleaner than what I had. Closing out my local approach in favor of this.

One note for hybrid models (Qwen3.5 with GatedDeltaNet + Attention): only the attention layers use KV cache. The 36 GatedDeltaNet layers use fixed-size ArraysCache (not per-token). So the actual memory savings from KV quantization are limited to the 12 attention layers. On the 122B at 128K context, that is still ~1.8GB savings (from ~3GB to ~1.2GB for pure KV), which matters when the model weights already use 82GB of 128GB.
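The ~3GB to ~1.2GB figure checks out against an idealized estimate. A sketch (it assumes one fp16 scale plus one fp16 bias per group of 64 elements, which may not match MLX's exact overhead):

```python
def kv_bytes_per_token(n_layers, n_kv_heads, head_dim,
                       key_bits=16, value_bits=16, group_size=64):
    # Idealized KV cache bytes per token: fp16 tensors cost 2 bytes/element;
    # quantized tensors cost bits/8 bytes/element plus an assumed 4 bytes of
    # fp16 scale+bias per group_size elements.
    elems = n_kv_heads * head_dim

    def tensor_bytes(bits):
        if bits == 16:
            return elems * 2
        return elems * (bits / 8 + 4 / group_size)

    return n_layers * (tensor_bytes(key_bits) + tensor_bytes(value_bits))

# Qwen3.5-122B attention path per the comment above: 12 layers with KV
# cache, 2 KV heads, head_dim 256, 128K (131072) tokens of context.
fp16_gb = kv_bytes_per_token(12, 2, 256) * 131072 / 2**30        # ~3.0 GiB
k8v4_gb = kv_bytes_per_token(12, 2, 256, 8, 4) * 131072 / 2**30  # ~1.22 GiB
```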

Adopt upstream's simplified _is_batchable one-liner while keeping
the kv_bits guard that disables batching when KV cache quantization
is active.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>


Successfully merging this pull request may close these issues.

Possibility for KVSplit support in mlx-lm?