1212
1313import mlx .core as mx
1414
15- # MambaCache was removed in mlx-lm 0.30.6 - make import conditional
15+ # MambaCache was removed in mlx-lm 0.30.6, fall back to ArraysCache
1616try :
1717 from mlx_lm .models .cache import MambaCache
18-
19- HAS_MAMBA_CACHE = True
2018except ImportError :
21- # Fallback for mlx-lm >= 0.30.6 where MambaCache was removed
2219 from mlx_lm .models .cache import ArraysCache as MambaCache
2320
24- HAS_MAMBA_CACHE = False
25-
2621logger = logging .getLogger (__name__ )
2722
2823
@@ -42,10 +37,9 @@ def __init__(self, left_padding: Optional[List[int]] = None, size: int = 2):
4237 left_padding: Amount of left padding for each sequence in batch
4338 size: Number of state arrays (default 2 for Mamba models)
4439 """
45- if HAS_MAMBA_CACHE :
46- super ().__init__ (left_padding = left_padding )
47- else :
48- super ().__init__ (size = size , left_padding = left_padding )
40+ # Always pass size - ArraysCache requires it, and MambaCache
41+ # (if it exists) inherits from ArraysCache
42+ super ().__init__ (size = size , left_padding = left_padding )
4943 self ._batch_size = len (left_padding ) if left_padding else 0
5044
5145 def extract (self , idx : int ) -> MambaCache :
@@ -59,10 +53,7 @@ def extract(self, idx: int) -> MambaCache:
5953 A new MambaCache with the extracted state
6054 """
6155 size = len (self .cache )
62- if HAS_MAMBA_CACHE :
63- cache = MambaCache ()
64- else :
65- cache = MambaCache (size = size )
56+ cache = MambaCache (size = size )
6657 # Extract the state arrays for this index
6758 cache .cache = [
6859 mx .contiguous (c [idx : idx + 1 ]) if c is not None else None
0 commit comments