@@ -652,9 +652,13 @@ def _extract_block_tensor_slice(
 
         keys, values = layer_state["state"]
 
-        # KV cache shape: (batch, n_kv_heads, seq_len, head_dim)
-        # Slice along seq_len dimension (axis 2)
-        seq_len = keys.shape[2] if hasattr(keys, "shape") else 0
+        # KV cache shape varies by model architecture:
+        #   4D: (batch, n_kv_heads, seq_len, head_dim) — most models
+        #   3D: (n_kv_heads, seq_len, head_dim) — e.g. Qwen3.5
+        # Determine the sequence dimension dynamically.
+        ndim = keys.ndim if hasattr(keys, "ndim") else len(keys.shape)
+        seq_axis = ndim - 2  # seq_len is always second-to-last
+        seq_len = keys.shape[seq_axis] if hasattr(keys, "shape") else 0
 
         if end_idx > seq_len:
             # Requested range extends beyond available data
@@ -665,11 +669,20 @@ def _extract_block_tensor_slice(
             actual_end = min(end_idx, seq_len)
             if start_idx >= actual_end:
                 continue
-            keys_slice = keys[:, :, start_idx:actual_end, :]
-            values_slice = values[:, :, start_idx:actual_end, :]
+            # Build a dynamic slice that works for both 3D and 4D tensors
+            slices = tuple(
+                slice(start_idx, actual_end) if i == seq_axis else slice(None)
+                for i in range(ndim)
+            )
+            keys_slice = keys[slices]
+            values_slice = values[slices]
         else:
-            keys_slice = keys[:, :, start_idx:end_idx, :]
-            values_slice = values[:, :, start_idx:end_idx, :]
+            slices = tuple(
+                slice(start_idx, end_idx) if i == seq_axis else slice(None)
+                for i in range(ndim)
+            )
+            keys_slice = keys[slices]
+            values_slice = values[slices]
 
         block_slices.append((keys_slice, values_slice))
 
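A minimal standalone sketch of the dynamic-slice technique the hunk above introduces; the `slice_seq` helper and the shapes are illustrative, not part of the patch:

import mlx.core as mx

def slice_seq(tensor, start, end):
    # seq_len is the second-to-last axis in both 3D and 4D layouts
    seq_axis = tensor.ndim - 2
    slices = tuple(
        slice(start, end) if i == seq_axis else slice(None)
        for i in range(tensor.ndim)
    )
    return tensor[slices]

keys_4d = mx.zeros((1, 8, 128, 64))  # (batch, n_kv_heads, seq_len, head_dim)
keys_3d = mx.zeros((8, 128, 64))     # (n_kv_heads, seq_len, head_dim)

assert tuple(slice_seq(keys_4d, 16, 48).shape) == (1, 8, 32, 64)
assert tuple(slice_seq(keys_3d, 16, 48).shape) == (8, 32, 64)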
@@ -821,10 +834,12 @@ def reconstruct_cache(
         if not layer_keys:
             continue
 
-        # Concatenate along sequence dimension (axis 2)
-        # Shape: (batch, n_kv_heads, seq_len, head_dim)
-        concat_keys = mx.concatenate(layer_keys, axis=2)
-        concat_values = mx.concatenate(layer_values, axis=2)
+        # Concatenate along the sequence dimension.
+        # Shape varies: 4D (batch, heads, seq, dim) or 3D (heads, seq, dim)
+        ndim = layer_keys[0].ndim if hasattr(layer_keys[0], "ndim") else len(layer_keys[0].shape)
+        seq_axis = ndim - 2  # seq_len is always second-to-last
+        concat_keys = mx.concatenate(layer_keys, axis=seq_axis)
+        concat_values = mx.concatenate(layer_values, axis=seq_axis)
 
         # Create KVCache object
         # Try to use mlx_lm's KVCache.from_state if available
@@ -833,7 +848,7 @@ def reconstruct_cache(
 
             # Create new cache and set its state
             cache = KVCache()
-            seq_len = concat_keys.shape[2]
+            seq_len = concat_keys.shape[seq_axis]
 
             # Set internal state directly
             # KVCache stores keys/values and offset
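A quick sketch (illustrative shapes, not part of the patch) of how the rank-agnostic concatenation reassembles per-block slices for either cache layout:

import mlx.core as mx

# Two per-block key slices for a 3D (n_kv_heads, seq_len, head_dim) cache
blocks = [mx.zeros((8, 32, 64)), mx.zeros((8, 16, 64))]

# seq_len is the second-to-last axis for both 3D and 4D layouts
seq_axis = blocks[0].ndim - 2
merged = mx.concatenate(blocks, axis=seq_axis)

assert tuple(merged.shape) == (8, 48, 64)
assert merged.shape[seq_axis] == 48  # what reconstruct_cache reads as seq_len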
@@ -849,7 +864,8 @@ class SimpleKVCache:
     def __init__(self, keys, values):
         self.keys = keys
         self.values = values
-        self.offset = keys.shape[2]
+        ndim = keys.ndim if hasattr(keys, "ndim") else len(keys.shape)
+        self.offset = keys.shape[ndim - 2]
 
     @property
     def state(self):
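Assuming the `SimpleKVCache` class above is in scope, a small check (hypothetical shapes) that the offset now tracks the cached token count for both layouts:

import mlx.core as mx

cache_4d = SimpleKVCache(mx.zeros((1, 8, 40, 64)), mx.zeros((1, 8, 40, 64)))
cache_3d = SimpleKVCache(mx.zeros((8, 40, 64)), mx.zeros((8, 40, 64)))

# offset is read from the second-to-last axis regardless of rank
assert cache_4d.offset == 40
assert cache_3d.offset == 40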