Skip to content

Commit 87efd78

Browse files
committed
fix: handle 3D KV tensors in prefix cache for Qwen3.5 (PR waybarrios#144)
Qwen3.5 produces 3D tensors (n_kv_heads, seq_len, head_dim) instead of 4D.
1 parent 7203a0c commit 87efd78

1 file changed

Lines changed: 29 additions & 13 deletions

File tree

vllm_mlx/prefix_cache.py

Lines changed: 29 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -652,9 +652,13 @@ def _extract_block_tensor_slice(
652652

653653
keys, values = layer_state["state"]
654654

655-
# KV cache shape: (batch, n_kv_heads, seq_len, head_dim)
656-
# Slice along seq_len dimension (axis 2)
657-
seq_len = keys.shape[2] if hasattr(keys, "shape") else 0
655+
# KV cache shape varies by model architecture:
656+
# 4D: (batch, n_kv_heads, seq_len, head_dim) — most models
657+
# 3D: (n_kv_heads, seq_len, head_dim) — e.g. Qwen3.5
658+
# Determine the sequence dimension dynamically
659+
ndim = keys.ndim if hasattr(keys, "ndim") else len(keys.shape)
660+
seq_axis = ndim - 2 # seq_len is always second-to-last
661+
seq_len = keys.shape[seq_axis] if hasattr(keys, "shape") else 0
658662

659663
if end_idx > seq_len:
660664
# Requested range extends beyond available data
@@ -665,11 +669,20 @@ def _extract_block_tensor_slice(
665669
actual_end = min(end_idx, seq_len)
666670
if start_idx >= actual_end:
667671
continue
668-
keys_slice = keys[:, :, start_idx:actual_end, :]
669-
values_slice = values[:, :, start_idx:actual_end, :]
672+
# Build a dynamic slice that works for both 3D and 4D tensors
673+
slices = tuple(
674+
slice(start_idx, actual_end) if i == seq_axis else slice(None)
675+
for i in range(ndim)
676+
)
677+
keys_slice = keys[slices]
678+
values_slice = values[slices]
670679
else:
671-
keys_slice = keys[:, :, start_idx:end_idx, :]
672-
values_slice = values[:, :, start_idx:end_idx, :]
680+
slices = tuple(
681+
slice(start_idx, end_idx) if i == seq_axis else slice(None)
682+
for i in range(ndim)
683+
)
684+
keys_slice = keys[slices]
685+
values_slice = values[slices]
673686

674687
block_slices.append((keys_slice, values_slice))
675688

@@ -821,10 +834,12 @@ def reconstruct_cache(
821834
if not layer_keys:
822835
continue
823836

824-
# Concatenate along sequence dimension (axis 2)
825-
# Shape: (batch, n_kv_heads, seq_len, head_dim)
826-
concat_keys = mx.concatenate(layer_keys, axis=2)
827-
concat_values = mx.concatenate(layer_values, axis=2)
837+
# Concatenate along sequence dimension
838+
# Shape varies: 4D (batch, heads, seq, dim) or 3D (heads, seq, dim)
839+
ndim = layer_keys[0].ndim if hasattr(layer_keys[0], "ndim") else len(layer_keys[0].shape)
840+
seq_axis = ndim - 2 # seq_len is always second-to-last
841+
concat_keys = mx.concatenate(layer_keys, axis=seq_axis)
842+
concat_values = mx.concatenate(layer_values, axis=seq_axis)
828843

829844
# Create KVCache object
830845
# Try to use mlx_lm's KVCache.from_state if available
@@ -833,7 +848,7 @@ def reconstruct_cache(
833848

834849
# Create new cache and set its state
835850
cache = KVCache()
836-
seq_len = concat_keys.shape[2]
851+
seq_len = concat_keys.shape[seq_axis]
837852

838853
# Set internal state directly
839854
# KVCache stores keys/values and offset
@@ -849,7 +864,8 @@ class SimpleKVCache:
849864
def __init__(self, keys, values):
850865
self.keys = keys
851866
self.values = values
852-
self.offset = keys.shape[2]
867+
ndim = keys.ndim if hasattr(keys, "ndim") else len(keys.shape)
868+
self.offset = keys.shape[ndim - 2]
853869

854870
@property
855871
def state(self):

0 commit comments

Comments (0)