
fix: handle 3D KV tensors in prefix cache for Qwen3.5 models #144

Open
jsirish wants to merge 2 commits into waybarrios:main from dynamic:fix/qwen35-prefix-cache-3d-kv-tensors

Conversation


@jsirish jsirish commented Mar 6, 2026

Problem

Qwen3.5 models (both MoE and dense variants) produce 3D KV cache tensors with shape (n_kv_heads, seq_len, head_dim) instead of the 4D shape (batch, n_kv_heads, seq_len, head_dim) that prefix_cache.py assumes.

This causes _extract_block_tensor_slice() to fail with:

Too many indices for array with 3 dimensions

Prefix caching silently falls back to disabled, losing significant performance.

Fix

Three locations in prefix_cache.py hardcoded axis=2 / shape[2] / [:, :, start:end, :] for 4D tensors. The fix dynamically detects tensor dimensionality via ndim and computes seq_axis = ndim - 2 (sequence length is always the second-to-last dimension):

| Location | Before | After |
| --- | --- | --- |
| _extract_block_tensor_slice() | keys[:, :, start:end, :] | Dynamic slice tuple via range(ndim) |
| reconstruct_cache() | mx.concatenate(..., axis=2) | mx.concatenate(..., axis=seq_axis) |
| SimpleKVCache.__init__() | keys.shape[2] | keys.shape[ndim - 2] |

This is backward compatible — 4D tensors still use axis=2 (since 4-2=2), while 3D tensors correctly use axis=1.
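As a sketch of the approach (using NumPy arrays as a stand-in for mx.array, since the slicing semantics are the same; slice_seq is an illustrative helper, not the actual prefix_cache.py function):

```python
import numpy as np

def slice_seq(tensor, start, end):
    """Slice the sequence dimension (second-to-last) of a 3D or 4D KV tensor."""
    ndim = tensor.ndim
    seq_axis = ndim - 2  # seq_len is always second-to-last
    slices = tuple(
        slice(start, end) if i == seq_axis else slice(None)
        for i in range(ndim)
    )
    return tensor[slices]

# 4D: (batch, n_kv_heads, seq_len, head_dim) -> seq_axis = 2
kv_4d = np.zeros((1, 8, 128, 64))
assert slice_seq(kv_4d, 16, 32).shape == (1, 8, 16, 64)

# 3D: (n_kv_heads, seq_len, head_dim) -> seq_axis = 1
kv_3d = np.zeros((8, 128, 64))
assert slice_seq(kv_3d, 16, 32).shape == (8, 16, 64)
```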

Test Results (Qwen3.5-122B-A10B-4bit on M3 Ultra)

| Request | tok/s | Notes |
| --- | --- | --- |
| 1st (cold) | 27.9 | No cache |
| 2nd (same prompt) | 45.4 | ✅ Cache hit — 63% faster |
| 3rd (different user msg) | 23.4 | Partial cache miss (expected) |

Zero errors in server logs. Prefix cache now works correctly for Qwen3.5 models.

Related Issues

jsirish added 2 commits March 6, 2026 11:13
Qwen3.5 models (both MoE and dense variants) produce 3D KV cache
tensors with shape (n_kv_heads, seq_len, head_dim) instead of the
expected 4D shape (batch, n_kv_heads, seq_len, head_dim).

This caused _extract_block_tensor_slice() to fail with:
  "Too many indices for array with 3 dimensions"

The fix dynamically detects tensor dimensionality (ndim) and computes
the sequence axis as ndim-2 (always second-to-last dimension), then
builds slicing tuples accordingly. This works for both 3D and 4D
tensors without breaking existing model support.

Also fixes reconstruct_cache() which had the same hardcoded axis=2
assumption for concatenation and shape indexing.

Fixes waybarrios#142
Qwen3.5 models (both MoE and dense) produce 3D KV cache tensors
(n_kv_heads, seq_len, head_dim) instead of expected 4D.

Three locations fixed:
1. _extract_block_tensor_slice() - tensor slicing
2. reconstruct_cache() - concatenation axis
3. SimpleKVCache.__init__() - offset shape index

Relates to waybarrios#142
Copilot AI review requested due to automatic review settings March 6, 2026 17:24

Copilot AI left a comment


Pull request overview

Fixes prefix-cache tensor slicing/concatenation for Qwen3.5-style KV caches where per-layer KV tensors are 3D (n_kv_heads, seq_len, head_dim) instead of the 4D (batch, n_kv_heads, seq_len, head_dim) shape assumed previously, restoring prefix-cache hits and performance.

Changes:

  • Make _extract_block_tensor_slice() build a dynamic slice tuple based on ndim, so seq slicing works for both 3D and 4D KV tensors.
  • Make reconstruct_cache() concatenate along the dynamically-detected sequence axis instead of hardcoding axis=2.
  • Make reconstructed cache offset/seq_len derived from shape[ndim - 2] rather than shape[2].


Comment on lines +658 to 662
# Determine the sequence dimension dynamically
ndim = keys.ndim if hasattr(keys, "ndim") else len(keys.shape)
seq_axis = ndim - 2 # seq_len is always second-to-last
seq_len = keys.shape[seq_axis] if hasattr(keys, "shape") else 0


Copilot AI Mar 6, 2026


ndim = keys.ndim if hasattr(keys, "ndim") else len(keys.shape) can raise if keys lacks ndim and also has no shape attribute. Previously this function explicitly tolerated missing shape by treating seq_len as 0 and skipping slicing. Consider computing shape = getattr(keys, "shape", None) first and, if shape is None or len(shape) < 3, log/skip the layer (or return None) rather than throwing and disabling prefix caching.

Suggested change:

```diff
-    # Determine the sequence dimension dynamically
-    ndim = keys.ndim if hasattr(keys, "ndim") else len(keys.shape)
-    seq_axis = ndim - 2  # seq_len is always second-to-last
-    seq_len = keys.shape[seq_axis] if hasattr(keys, "shape") else 0
+    # Determine the sequence dimension dynamically, tolerating
+    # missing or unexpected shape by skipping the layer instead
+    shape = getattr(keys, "shape", None)
+    if shape is None:
+        logger.debug("Prefix cache: keys tensor has no 'shape'; skipping layer")
+        continue
+    if hasattr(keys, "ndim"):
+        ndim = keys.ndim
+    else:
+        ndim = len(shape)
+    # seq_len is always second-to-last: (..., seq_len, head_dim)
+    seq_axis = ndim - 2
+    if seq_axis < 0 or seq_axis >= len(shape):
+        logger.debug(
+            f"Prefix cache: invalid seq_axis {seq_axis} for shape {shape}; skipping layer"
+        )
+        continue
+    seq_len = shape[seq_axis]
```

Comment on lines +672 to +685
```diff
+            # Build a dynamic slice that works for both 3D and 4D tensors
+            slices = tuple(
+                slice(start_idx, actual_end) if i == seq_axis else slice(None)
+                for i in range(ndim)
+            )
+            keys_slice = keys[slices]
+            values_slice = values[slices]
         else:
-            keys_slice = keys[:, :, start_idx:end_idx, :]
-            values_slice = values[:, :, start_idx:end_idx, :]
+            slices = tuple(
+                slice(start_idx, end_idx) if i == seq_axis else slice(None)
+                for i in range(ndim)
+            )
+            keys_slice = keys[slices]
+            values_slice = values[slices]
```

Copilot AI Mar 6, 2026


The dynamic slicing logic is duplicated in both branches. You can reduce the chance of inconsistencies by computing a single end = min(end_idx, seq_len) and building the slices tuple once, then applying it to both keys and values. This keeps the slicing behavior identical for the in-range vs out-of-range cases.

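A minimal sketch of the deduplication the reviewer suggests (NumPy stands in for mx.array; extract_slice and its signature are hypothetical, not the actual prefix_cache.py API):

```python
import numpy as np

def extract_slice(keys, values, start_idx, end_idx, seq_len, seq_axis, ndim):
    """Compute one clamped end and one slice tuple, then apply it to both
    keys and values, so the in-range and out-of-range cases share one path."""
    end = min(end_idx, seq_len)
    if start_idx >= end:
        return None, None
    slices = tuple(
        slice(start_idx, end) if i == seq_axis else slice(None)
        for i in range(ndim)
    )
    return keys[slices], values[slices]

k = np.arange(8 * 10 * 4).reshape(8, 10, 4)
v = k.copy()
# end_idx=20 exceeds seq_len=10, so end is clamped to 10
ks, vs = extract_slice(k, v, 4, 20, 10, 1, 3)
assert ks.shape == (8, 6, 4)
```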
Comment on lines +839 to +842
ndim = layer_keys[0].ndim if hasattr(layer_keys[0], "ndim") else len(layer_keys[0].shape)
seq_axis = ndim - 2 # seq_len is always second-to-last
concat_keys = mx.concatenate(layer_keys, axis=seq_axis)
concat_values = mx.concatenate(layer_values, axis=seq_axis)

Copilot AI Mar 6, 2026


ndim = layer_keys[0].ndim if hasattr(layer_keys[0], "ndim") else len(layer_keys[0].shape) has the same unguarded len(...shape) fallback as _extract_block_tensor_slice, which can raise if an unexpected cache representation leaks through. Consider using shape = getattr(layer_keys[0], "shape", None) and validating shape is not None and len(shape) >= 3 before computing seq_axis = len(shape) - 2, otherwise fail reconstruction cleanly.

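The validation the reviewer describes could be factored into a small helper; this is a sketch under the stated assumptions (safe_seq_axis is hypothetical, and NumPy stands in for mx.array):

```python
import numpy as np

def safe_seq_axis(tensor):
    """Validate a KV tensor's shape before computing the sequence axis.
    Returns None to signal a clean failure instead of raising."""
    shape = getattr(tensor, "shape", None)
    if shape is None or len(shape) < 3:
        return None
    return len(shape) - 2

assert safe_seq_axis(np.zeros((1, 8, 16, 64))) == 2  # 4D -> axis 2
assert safe_seq_axis(np.zeros((8, 16, 64))) == 1     # 3D -> axis 1
assert safe_seq_axis(np.zeros((16, 64))) is None     # too few dims
assert safe_seq_axis(object()) is None               # no shape at all
```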
Comment on lines 655 to +685
```diff
@@ -665,11 +669,20 @@ def _extract_block_tensor_slice(
             actual_end = min(end_idx, seq_len)
             if start_idx >= actual_end:
                 continue
-            keys_slice = keys[:, :, start_idx:actual_end, :]
-            values_slice = values[:, :, start_idx:actual_end, :]
+            # Build a dynamic slice that works for both 3D and 4D tensors
+            slices = tuple(
+                slice(start_idx, actual_end) if i == seq_axis else slice(None)
+                for i in range(ndim)
+            )
+            keys_slice = keys[slices]
+            values_slice = values[slices]
         else:
-            keys_slice = keys[:, :, start_idx:end_idx, :]
-            values_slice = values[:, :, start_idx:end_idx, :]
+            slices = tuple(
+                slice(start_idx, end_idx) if i == seq_axis else slice(None)
+                for i in range(ndim)
+            )
+            keys_slice = keys[slices]
+            values_slice = values[slices]
```

Copilot AI Mar 6, 2026


This PR adds support for 3D KV tensors, but there doesn’t appear to be test coverage exercising the 3D path (no tests reference _extract_block_tensor_slice/reconstruct_cache, and existing paged-cache tests use placeholder non-tensor cache_data). Adding a unit/integration test that stores and reconstructs a small 3D (heads, seq, dim) KV state would prevent regressions for Qwen3.5-style caches.

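A sketch of the missing test the reviewer asks for: slice a small 3D (heads, seq, dim) KV tensor into blocks and reconstruct it, checking the round trip is lossless. NumPy stands in for mx.array, and the block logic mirrors the PR's dynamic seq-axis approach rather than calling the real _extract_block_tensor_slice/reconstruct_cache:

```python
import numpy as np

def test_3d_kv_roundtrip():
    """Store a 3D KV tensor as fixed-size sequence blocks, then
    reconstruct it by concatenating along the dynamic sequence axis."""
    heads, seq, dim, block = 4, 8, 16, 2
    keys = np.random.rand(heads, seq, dim)
    seq_axis = keys.ndim - 2  # == 1 for 3D
    blocks = []
    for start in range(0, seq, block):
        slices = tuple(
            slice(start, start + block) if i == seq_axis else slice(None)
            for i in range(keys.ndim)
        )
        blocks.append(keys[slices])
    rebuilt = np.concatenate(blocks, axis=seq_axis)
    assert rebuilt.shape == keys.shape
    assert np.array_equal(rebuilt, keys)

test_3d_kv_roundtrip()
```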
jackzampolin added a commit to jackzampolin/vllm-mlx that referenced this pull request Mar 11, 2026
…5 models

Qwen3.5 MoE models produce 3D KV tensors (n_kv_heads, seq_len, head_dim)
instead of 4D. The prefix cache silently fell back to disabled. This fix
dynamically detects tensor dimensionality and computes seq_axis = ndim-2.

Result: 63% faster on cache hits for Qwen3.5-122B-A10B-4bit.

Cherry-picked from: waybarrios#144
Original author: Jason Irish <jsirish>

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

hlibr commented Mar 13, 2026

Tried this with qwen 3 coder next, got a different error:

WARNING:vllm_mlx.prefix_cache:Failed to reconstruct cache: [concatenate] All the input arrays must have the same number of dimensions. However, got arrays with dimensions 3 and 4.

seanpianka added a commit to seanpianka/vllm-mlx that referenced this pull request Mar 14, 2026

Qwen3.5 produces 3D tensors (n_kv_heads, seq_len, head_dim) instead of 4D.
Thump604 added a commit to Thump604/vllm-mlx that referenced this pull request Mar 16, 2026
…efix cache

The prefix cache hardcoded 4D KV tensor assumptions that crash on hybrid
models like Qwen3.5 which mix standard attention (KVCache) and linear
attention (ArraysCache/GatedDeltaNet) layers.

Three issues fixed:

1. _extract_block_tensor_slice() used `keys[:, :, start:end, :]` (4-index)
   on 3D tensors `(n_kv_heads, seq_len, head_dim)` → "Too many indices"
   error. Now uses dynamic `ndim-2` for the sequence axis.

2. ArraysCache layers (SSM/linear-attention) have cumulative state that
   cannot be split into sequence-position blocks. Now detected via
   _is_kv_cache_layer() and skipped during block extraction. Full state
   stored separately per prefix for exact-match reconstruction.

3. reconstruct_cache() assumed all layers are KVCache. Now handles mixed
   caches: KV layers concatenated from blocks, non-KV layers restored
   via from_state(). Returns None when non-KV states are unavailable
   (partial prefix on hybrid model) to prevent incorrect output.

Also fixes:
- Memory leak: release_cache() now cleans up non-KV states
- COW correctness: fork_cache() propagates non-KV states
- State mutation safety: non-KV states deep-copied on store

Fixes waybarrios#136, waybarrios#142. Supersedes waybarrios#144.
rishuriya added a commit to rishuriya/vllm-mlx that referenced this pull request Mar 18, 2026
Critical fixes for prefix caching silently failing on Qwen3.5 and
memory growing unbounded over time.

Fixes:
- _validate_cache: accept 3D KV tensors (n_kv_heads, seq, dim) used
  by Qwen3.5 — previously shape[0]=n_kv_heads was compared to 1,
  always failing and falling back to full prefill (issue waybarrios#144)
- Handle QuantizedKVCache tuple keys without crashing shape check
- _reconstruct_cache_from_states: use shape[1] as seq_len for 3D
  tensors, not shape[2] (which is head_dim)
- BlockAwarePrefixCache.reconstruct_cache: select concat axis from
  tensor rank (axis=2 for 4D, axis=1 for 3D)
- _extract_block_tensor_slice: same seq_axis fix for block slicing
- evict_lru_blocks phase 2: when free queue is empty, force-evict
  LRU prefix-cached allocated blocks (ref_count=1) so memory is
  actually reclaimed under pressure
- Upgrade silent cache validation debug logs to warnings

New endpoints:
- GET /v1/prefix-cache/stats  — hit/miss/memory for debugging
- DELETE /v1/prefix-cache     — clear cache + flush Metal buffers

Tests: 39 unit tests in tests/test_kv_shape_compat.py covering all
model families (Llama/Mistral 4D, Qwen3.5 3D, GQA, hybrid Mamba,
QuantizedKV) — no model download required.
@Thump604
Collaborator

Good fix — the root cause analysis is solid. Qwen3.5 models (both dense and MoE variants) produce 3D KV tensors (n_kv_heads, seq_len, head_dim) from their attention layers, and the hardcoded axis=2 / shape[2] / [:, :, start:end, :] slicing silently breaks prefix cache for these models. The dynamic ndim - 2 approach is correct and backward-compatible.

I can corroborate the performance numbers. We run Qwen3.5-122B-A10B on M2 Ultra 128GB and the prefix cache speedup on warm requests is significant — your 63% improvement on M3 Ultra is consistent with what we'd expect.

Interaction with our PRs: Our PR #165 also modifies prefix_cache.py, but for a different issue — it adds hybrid model support (separating KV layers from non-positional ArraysCache/SSM layers so they aren't incorrectly block-sliced). PR #165 currently keeps the hardcoded axis=2 for KV layers. These two fixes are complementary: yours fixes the tensor dimensionality assumption, ours fixes the cache-type discrimination. Whoever merges second would need a straightforward rebase, but there are no logical conflicts.

Regarding @hlibr's comment about the 3D/4D concatenation error with Qwen 3 Coder — that could happen if different layers in the same model produce tensors with different ndim. If Qwen 3 Coder has some layers returning 4D and others 3D, the reconstruct_cache concatenation would fail when mixing them. Might be worth adding a guard that checks ndim consistency per-layer, or squeezing/unsqueezing to normalize before concatenation. But that's an edge case for a follow-up, not a blocker for this PR.

+1 for merge. Clean fix with good test data.
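The ndim-consistency guard floated above could look something like this sketch (normalize_kv is hypothetical, NumPy stands in for mx.array, and np.expand_dims mirrors mx.expand_dims): promote any 3D (heads, seq, dim) tensor to 4D (1, heads, seq, dim) so every layer concatenates along the same axis.

```python
import numpy as np

def normalize_kv(tensors):
    """Promote 3D KV tensors to 4D by adding a leading batch dim,
    so mixed-ndim layers can be concatenated along one seq axis."""
    out = []
    for t in tensors:
        if t.ndim == 3:
            t = np.expand_dims(t, axis=0)  # (heads, seq, dim) -> (1, heads, seq, dim)
        elif t.ndim != 4:
            raise ValueError(f"unexpected KV ndim: {t.ndim}")
        out.append(t)
    return out

# the mixed 3D/4D case that made reconstruct_cache's concatenate fail
mixed = [np.zeros((8, 4, 64)), np.zeros((1, 8, 4, 64))]
norm = normalize_kv(mixed)
cat = np.concatenate(norm, axis=2)  # seq axis is 2 once everything is 4D
assert cat.shape == (1, 8, 8, 64)
```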



Development

Successfully merging this pull request may close these issues.

Prefix cache tensor slicing fails for Qwen3.5 MoE models (3D KV cache)
