fix: handle 3D KV tensors in prefix cache for Qwen3.5 models#144
jsirish wants to merge 2 commits into waybarrios:main
Conversation
Qwen3.5 models (both MoE and dense variants) produce 3D KV cache tensors with shape `(n_kv_heads, seq_len, head_dim)` instead of the expected 4D shape `(batch, n_kv_heads, seq_len, head_dim)`. This caused `_extract_block_tensor_slice()` to fail with:

"Too many indices for array with 3 dimensions"

The fix dynamically detects tensor dimensionality (`ndim`) and computes the sequence axis as `ndim - 2` (always the second-to-last dimension), then builds slicing tuples accordingly. This works for both 3D and 4D tensors without breaking existing model support.

Also fixes `reconstruct_cache()`, which had the same hardcoded `axis=2` assumption for concatenation and shape indexing.

Fixes waybarrios#142
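The ndim-based slicing described above can be illustrated with NumPy arrays standing in for MLX tensors (a minimal sketch of the technique, not the project's actual code; the helper name is hypothetical):

```python
import numpy as np

def slice_seq_range(tensor, start, end):
    """Slice [start:end) along the sequence axis, which is always
    the second-to-last dimension for both 3D and 4D KV tensors."""
    ndim = tensor.ndim
    seq_axis = ndim - 2
    slices = tuple(
        slice(start, end) if i == seq_axis else slice(None)
        for i in range(ndim)
    )
    return tensor[slices]

# 4D (batch, n_kv_heads, seq_len, head_dim) -- the layout assumed before the fix
kv_4d = np.zeros((1, 8, 128, 64))
# 3D (n_kv_heads, seq_len, head_dim) -- Qwen3.5-style layout
kv_3d = np.zeros((8, 128, 64))

print(slice_seq_range(kv_4d, 16, 48).shape)  # (1, 8, 32, 64)
print(slice_seq_range(kv_3d, 16, 48).shape)  # (8, 32, 64)
```

For the 4D case the tuple is equivalent to the old `[:, :, start:end, :]` indexing, so existing models are unaffected.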
Qwen3.5 models (both MoE and dense) produce 3D KV cache tensors `(n_kv_heads, seq_len, head_dim)` instead of the expected 4D. Three locations fixed:

1. `_extract_block_tensor_slice()` - tensor slicing
2. `reconstruct_cache()` - concatenation axis
3. `SimpleKVCache.__init__()` - offset shape index

Relates to waybarrios#142
Pull request overview
Fixes prefix-cache tensor slicing/concatenation for Qwen3.5-style KV caches where per-layer KV tensors are 3D (n_kv_heads, seq_len, head_dim) instead of the 4D (batch, n_kv_heads, seq_len, head_dim) shape assumed previously, restoring prefix-cache hits and performance.
Changes:
- Make `_extract_block_tensor_slice()` build a dynamic slice tuple based on `ndim`, so seq slicing works for both 3D and 4D KV tensors.
- Make `reconstruct_cache()` concatenate along the dynamically detected sequence axis instead of hardcoding `axis=2`.
- Make the reconstructed cache `offset`/`seq_len` derive from `shape[ndim - 2]` rather than `shape[2]`.
```python
# Determine the sequence dimension dynamically
ndim = keys.ndim if hasattr(keys, "ndim") else len(keys.shape)
seq_axis = ndim - 2  # seq_len is always second-to-last
seq_len = keys.shape[seq_axis] if hasattr(keys, "shape") else 0
```
ndim = keys.ndim if hasattr(keys, "ndim") else len(keys.shape) can raise if keys lacks ndim and also has no shape attribute. Previously this function explicitly tolerated missing shape by treating seq_len as 0 and skipping slicing. Consider computing shape = getattr(keys, "shape", None) first and, if shape is None or len(shape) < 3, log/skip the layer (or return None) rather than throwing and disabling prefix caching.
Suggested change:

```python
# Determine the sequence dimension dynamically, tolerating
# missing or unexpected shape by skipping the layer instead
shape = getattr(keys, "shape", None)
if shape is None:
    logger.debug("Prefix cache: keys tensor has no 'shape'; skipping layer")
    continue
if hasattr(keys, "ndim"):
    ndim = keys.ndim
else:
    ndim = len(shape)
# seq_len is always second-to-last: (..., seq_len, head_dim)
seq_axis = ndim - 2
if seq_axis < 0 or seq_axis >= len(shape):
    logger.debug(
        f"Prefix cache: invalid seq_axis {seq_axis} for shape {shape}; skipping layer"
    )
    continue
seq_len = shape[seq_axis]
```
```diff
+            # Build a dynamic slice that works for both 3D and 4D tensors
+            slices = tuple(
+                slice(start_idx, actual_end) if i == seq_axis else slice(None)
+                for i in range(ndim)
+            )
+            keys_slice = keys[slices]
+            values_slice = values[slices]
         else:
-            keys_slice = keys[:, :, start_idx:end_idx, :]
-            values_slice = values[:, :, start_idx:end_idx, :]
+            slices = tuple(
+                slice(start_idx, end_idx) if i == seq_axis else slice(None)
+                for i in range(ndim)
+            )
+            keys_slice = keys[slices]
+            values_slice = values[slices]
```
The dynamic slicing logic is duplicated in both branches. You can reduce the chance of inconsistencies by computing a single end = min(end_idx, seq_len) and building the slices tuple once, then applying it to both keys and values. This keeps the slicing behavior identical for the in-range vs out-of-range cases.
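The deduplication the reviewer suggests could look like the sketch below (NumPy stands in for MLX; the function name and early-return convention are assumptions, the surrounding variables come from the diff):

```python
import numpy as np

def slice_kv(keys, values, start_idx, end_idx):
    """Build the seq-axis slice once and apply it to both tensors,
    clamping the end to the available sequence length."""
    ndim = keys.ndim
    seq_axis = ndim - 2
    seq_len = keys.shape[seq_axis]
    end = min(end_idx, seq_len)  # one clamp replaces the in-range/out-of-range branches
    if start_idx >= end:
        return None, None  # nothing to slice for this block
    slices = tuple(
        slice(start_idx, end) if i == seq_axis else slice(None)
        for i in range(ndim)
    )
    return keys[slices], values[slices]

k = np.ones((8, 100, 64))
v = np.ones((8, 100, 64))
ks, vs = slice_kv(k, v, 90, 120)  # end_idx past seq_len, clamped to 100
print(ks.shape)  # (8, 10, 64)
```

Because both branches of the original code differ only in the end index, folding them into one `min()` keeps the slicing behavior identical while removing the duplicated tuple construction.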
```python
ndim = layer_keys[0].ndim if hasattr(layer_keys[0], "ndim") else len(layer_keys[0].shape)
seq_axis = ndim - 2  # seq_len is always second-to-last
concat_keys = mx.concatenate(layer_keys, axis=seq_axis)
concat_values = mx.concatenate(layer_values, axis=seq_axis)
```
ndim = layer_keys[0].ndim if hasattr(layer_keys[0], "ndim") else len(layer_keys[0].shape) has the same unguarded len(...shape) fallback as _extract_block_tensor_slice, which can raise if an unexpected cache representation leaks through. Consider using shape = getattr(layer_keys[0], "shape", None) and validating shape is not None and len(shape) >= 3 before computing seq_axis = len(shape) - 2, otherwise fail reconstruction cleanly.
```diff
@@ -665,11 +669,20 @@ def _extract_block_tensor_slice(
             actual_end = min(end_idx, seq_len)
             if start_idx >= actual_end:
                 continue
-            keys_slice = keys[:, :, start_idx:actual_end, :]
-            values_slice = values[:, :, start_idx:actual_end, :]
+            # Build a dynamic slice that works for both 3D and 4D tensors
+            slices = tuple(
+                slice(start_idx, actual_end) if i == seq_axis else slice(None)
+                for i in range(ndim)
+            )
+            keys_slice = keys[slices]
+            values_slice = values[slices]
         else:
-            keys_slice = keys[:, :, start_idx:end_idx, :]
-            values_slice = values[:, :, start_idx:end_idx, :]
+            slices = tuple(
+                slice(start_idx, end_idx) if i == seq_axis else slice(None)
+                for i in range(ndim)
+            )
+            keys_slice = keys[slices]
+            values_slice = values[slices]
```
This PR adds support for 3D KV tensors, but there doesn’t appear to be test coverage exercising the 3D path (no tests reference _extract_block_tensor_slice/reconstruct_cache, and existing paged-cache tests use placeholder non-tensor cache_data). Adding a unit/integration test that stores and reconstructs a small 3D (heads, seq, dim) KV state would prevent regressions for Qwen3.5-style caches.
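A test along the lines the reviewer asks for might look like this sketch. NumPy stands in for MLX, and `slice_seq_range` is a hypothetical helper mirroring the fixed slicing logic, since the real `_extract_block_tensor_slice()` takes the project's cache objects:

```python
import numpy as np

def slice_seq_range(tensor, start, end):
    # Mirrors the fixed logic: the seq axis is always second-to-last.
    seq_axis = tensor.ndim - 2
    slices = tuple(
        slice(start, end) if i == seq_axis else slice(None)
        for i in range(tensor.ndim)
    )
    return tensor[slices]

def test_3d_kv_roundtrip():
    """Split a small 3D (heads, seq, dim) KV state into blocks and
    reconstruct it, checking the round trip is exact."""
    heads, seq, dim = 4, 32, 8
    keys = np.arange(heads * seq * dim, dtype=np.float32).reshape(heads, seq, dim)
    block_size = 16
    blocks = [
        slice_seq_range(keys, s, s + block_size)
        for s in range(0, seq, block_size)
    ]
    rebuilt = np.concatenate(blocks, axis=keys.ndim - 2)
    assert rebuilt.shape == keys.shape
    assert np.array_equal(rebuilt, keys)

test_3d_kv_roundtrip()
print("ok")
```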
…5 models

Qwen3.5 MoE models produce 3D KV tensors `(n_kv_heads, seq_len, head_dim)` instead of 4D. The prefix cache silently fell back to disabled. This fix dynamically detects tensor dimensionality and computes `seq_axis = ndim - 2`.

Result: 63% faster on cache hits for Qwen3.5-122B-A10B-4bit.

Cherry-picked from: waybarrios#144
Original author: Jason Irish <jsirish>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Tried this with Qwen 3 Coder next, got a different error:
…efix cache

The prefix cache hardcoded 4D KV tensor assumptions that crash on hybrid models like Qwen3.5, which mix standard attention (KVCache) and linear attention (ArraysCache/GatedDeltaNet) layers. Three issues fixed:

1. `_extract_block_tensor_slice()` used `keys[:, :, start:end, :]` (4-index) on 3D tensors `(n_kv_heads, seq_len, head_dim)` → "Too many indices" error. Now uses dynamic `ndim-2` for the sequence axis.
2. ArraysCache layers (SSM/linear-attention) have cumulative state that cannot be split into sequence-position blocks. Now detected via `_is_kv_cache_layer()` and skipped during block extraction. Full state stored separately per prefix for exact-match reconstruction.
3. `reconstruct_cache()` assumed all layers are KVCache. Now handles mixed caches: KV layers concatenated from blocks, non-KV layers restored via `from_state()`. Returns None when non-KV states are unavailable (partial prefix on hybrid model) to prevent incorrect output.

Also fixes:
- Memory leak: `release_cache()` now cleans up non-KV states
- COW correctness: `fork_cache()` propagates non-KV states
- State mutation safety: non-KV states deep-copied on store

Fixes waybarrios#136, waybarrios#142. Supersedes waybarrios#144.
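The layer-type detection in point 2 above could be sketched as follows. This is a structural heuristic with a hypothetical name and signature; the real `_is_kv_cache_layer()` likely inspects the project's cache class types rather than raw state shapes:

```python
import numpy as np

def is_kv_cache_layer(state):
    """Heuristic sketch: a KV layer's state is a (keys, values) pair of
    rank-3 or rank-4 arrays with a sequence axis; SSM/linear-attention
    state (e.g. a dict of recurrent buffers) is anything else and must
    not be split into sequence-position blocks."""
    if not (isinstance(state, tuple) and len(state) == 2):
        return False
    keys, values = state
    return all(getattr(t, "ndim", 0) in (3, 4) for t in (keys, values))

kv_state = (np.zeros((8, 64, 128)), np.zeros((8, 64, 128)))
ssm_state = {"conv": np.zeros((8, 4)), "recurrent": np.zeros((8, 16, 16))}
print(is_kv_cache_layer(kv_state))   # True
print(is_kv_cache_layer(ssm_state))  # False
```

Layers that fail the check would then be stored whole per prefix, matching the exact-match reconstruction strategy the commit describes.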
Critical fixes for prefix caching silently failing on Qwen3.5 and memory growing unbounded over time.

Fixes:
- `_validate_cache`: accept 3D KV tensors (n_kv_heads, seq, dim) used by Qwen3.5 — previously shape[0]=n_kv_heads was compared to 1, always failing and falling back to full prefill (issue waybarrios#144)
- Handle QuantizedKVCache tuple keys without crashing the shape check
- `_reconstruct_cache_from_states`: use shape[1] as seq_len for 3D tensors, not shape[2] (which is head_dim)
- `BlockAwarePrefixCache.reconstruct_cache`: select concat axis from tensor rank (axis=2 for 4D, axis=1 for 3D)
- `_extract_block_tensor_slice`: same seq_axis fix for block slicing
- `evict_lru_blocks` phase 2: when the free queue is empty, force-evict LRU prefix-cached allocated blocks (ref_count=1) so memory is actually reclaimed under pressure
- Upgrade silent cache validation debug logs to warnings

New endpoints:
- GET /v1/prefix-cache/stats — hit/miss/memory for debugging
- DELETE /v1/prefix-cache — clear cache + flush Metal buffers

Tests: 39 unit tests in tests/test_kv_shape_compat.py covering all model families (Llama/Mistral 4D, Qwen3.5 3D, GQA, hybrid Mamba, QuantizedKV) — no model download required.
Good fix — the root cause analysis is solid.

> Qwen3.5 models (both dense and MoE variants) produce 3D KV tensors

I can corroborate the performance numbers. We run Qwen3.5-122B-A10B on M2 Ultra 128GB, and the prefix cache speedup on warm requests is significant — your 63% improvement on M3 Ultra is consistent with what we'd expect.

Interaction with our PRs: Our PR #165 also modifies

Regarding @hlibr's comment about the 3D/4D concatenation error with Qwen 3 Coder — that could happen if different layers in the same model produce tensors with different ndim. If Qwen 3 Coder has some layers returning 4D and others 3D, the

+1 for merge. Clean fix with good test data.
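The mixed-rank scenario hypothesized above (some layers 4D, others 3D in one model) would be an assumption about the Qwen 3 Coder failure, not a confirmed diagnosis, but it is easy to guard against: compute the sequence axis per layer instead of once from layer 0, as this NumPy sketch shows:

```python
import numpy as np

def concat_per_layer(layers):
    """Concatenate each layer's blocks along that layer's *own* seq axis,
    rather than reusing layer 0's ndim for every layer."""
    out = []
    for blocks in layers:
        seq_axis = blocks[0].ndim - 2  # recompute per layer
        out.append(np.concatenate(blocks, axis=seq_axis))
    return out

mixed = [
    [np.zeros((1, 8, 16, 64)), np.zeros((1, 8, 16, 64))],  # 4D layer
    [np.zeros((8, 16, 64)), np.zeros((8, 16, 64))],        # 3D layer
]
k4, k3 = concat_per_layer(mixed)
print(k4.shape, k3.shape)  # (1, 8, 32, 64) (8, 32, 64)
```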
Problem
Qwen3.5 models (both MoE and dense variants) produce 3D KV cache tensors with shape `(n_kv_heads, seq_len, head_dim)` instead of the 4D shape `(batch, n_kv_heads, seq_len, head_dim)` that `prefix_cache.py` assumes.

This causes `_extract_block_tensor_slice()` to fail with "Too many indices for array with 3 dimensions". Prefix caching silently falls back to disabled, losing significant performance.
Fix
Three locations in `prefix_cache.py` hardcoded `axis=2` / `shape[2]` / `[:, :, start:end, :]` for 4D tensors. The fix dynamically detects tensor dimensionality via `ndim` and computes `seq_axis = ndim - 2` (sequence length is always the second-to-last dimension):

| Location | Before | After |
| --- | --- | --- |
| `_extract_block_tensor_slice()` | `keys[:, :, start:end, :]` | dynamic slice tuple over `range(ndim)` |
| `reconstruct_cache()` | `mx.concatenate(..., axis=2)` | `mx.concatenate(..., axis=seq_axis)` |
| `SimpleKVCache.__init__()` | `keys.shape[2]` | `keys.shape[ndim - 2]` |

This is backward compatible — 4D tensors still use axis=2 (since 4-2=2), while 3D tensors correctly use axis=1.

Test Results (Qwen3.5-122B-A10B-4bit on M3 Ultra)
Zero errors in server logs. Prefix cache now works correctly for Qwen3.5 models.
Related Issues

Fixes waybarrios#142