26 changes: 26 additions & 0 deletions mlx_lm/generate.py
@@ -203,6 +203,20 @@ def setup_arg_parser():
type=int,
default=DEFAULT_QUANTIZED_KV_START,
)
parser.add_argument(
"--turbo-kv-bits",
type=int,
help="TurboQuant KV cache compression bits (1-4). "
"3-bit gives 4.6x compression. Default: no compression.",
default=None,
)
parser.add_argument(
"--turbo-fp16-layers",
type=int,
help="Number of first/last layers to keep in FP16 "
"when using --turbo-kv-bits. Default: 1.",
default=1,
)
parser.add_argument(
"--draft-model",
type=str,
@@ -300,6 +314,7 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)



def generate_step(
prompt: mx.array,
model: nn.Module,
@@ -313,6 +328,8 @@ def generate_step(
kv_bits: Optional[int] = None,
kv_group_size: int = 64,
quantized_kv_start: int = 0,
turbo_kv_bits: Optional[int] = None,
turbo_fp16_layers: int = 1,
prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
input_embeddings: Optional[mx.array] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
@@ -339,6 +356,11 @@ def generate_step(
kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
quantized_kv_start (int): Step to begin using a quantized KV cache
when ``kv_bits`` is non-None. Default: ``0``.
turbo_kv_bits (int, optional): TurboQuant KV cache compression bits (1-4).
Uses PolarQuant with Hadamard rotation. 3-bit gives 4.6x compression.
None implies no TurboQuant. Default: ``None``.
turbo_fp16_layers (int): Number of first/last layers to keep in FP16 when
using TurboQuant. Default: ``1``.
prompt_progress_callback (Callable[[int, int], None]): A callback which takes the
prompt tokens processed so far and the total number of prompt tokens.
input_embeddings (mx.array, optional): Input embeddings to use instead of or in
@@ -368,6 +390,8 @@ def generate_step(
prompt_cache = cache.make_prompt_cache(
model,
max_kv_size=max_kv_size,
turbo_kv_bits=turbo_kv_bits,
turbo_fp16_layers=turbo_fp16_layers,
)

prompt_progress_callback = prompt_progress_callback or (lambda *_: None)
@@ -1526,6 +1550,8 @@ def main():
kv_bits=args.kv_bits,
kv_group_size=args.kv_group_size,
quantized_kv_start=args.quantized_kv_start,
turbo_kv_bits=args.turbo_kv_bits,
turbo_fp16_layers=args.turbo_fp16_layers,
draft_model=draft_model,
num_draft_tokens=args.num_draft_tokens,
)
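On the `generate.py` side, the two new flags flow from `argparse` through `main()` and `generate_step` into `cache.make_prompt_cache`. Note that the quoted 4.6x for 3-bit is below the raw 16/3 ≈ 5.3x against an FP16 baseline, consistent with roughly half a bit per element of rotation/scale metadata (16 / 4.6 ≈ 3.5 effective bits). A minimal sketch of the same path through the Python API, assuming `generate` forwards extra keyword arguments to `generate_step` the way it does for the existing `kv_bits` options (the model name is illustrative):

```python
from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Llama-3.2-1B-Instruct-4bit")

# 3-bit TurboQuant on the middle layers; the first and last layer stay FP16,
# matching the CLI defaults (--turbo-kv-bits 3 --turbo-fp16-layers 1).
text = generate(
    model,
    tokenizer,
    prompt="Explain KV cache quantization in one paragraph.",
    turbo_kv_bits=3,
    turbo_fp16_layers=1,
)
print(text)
```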
50 changes: 49 additions & 1 deletion mlx_lm/models/cache.py
@@ -15,6 +15,8 @@
def make_prompt_cache(
model: nn.Module,
max_kv_size: Optional[int] = None,
turbo_kv_bits: Optional[int] = None,
turbo_fp16_layers: int = 1,
) -> List[Any]:
"""
Construct the model's cache for use in generation.
@@ -27,11 +29,39 @@ def make_prompt_cache(
max_kv_size (Optional[int]): If provided and the model does not have a
``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
size of ``max_kv_size``.
turbo_kv_bits (Optional[int]): If provided, use TurboQuant KV cache
compression at the given bit width (1-4). 3-bit gives 4.6x
compression. Default: ``None`` (no compression).
turbo_fp16_layers (int): Number of first/last layers to keep in FP16
when using TurboQuant. Default: ``1``.
"""
if hasattr(model, "make_cache"):
return model.make_cache()
default_cache = model.make_cache()
if turbo_kv_bits is not None:
# Check compatibility
if not isinstance(default_cache[0], KVCache):
raise ValueError(
f"[TurboQuant] Incompatible cache type: "
f"{type(default_cache[0]).__name__}. "
f"TurboQuant only works with standard multi-head "
f"attention (KVCache)."
)
else:
return default_cache

num_layers = len(model.layers)

if turbo_kv_bits is not None:
from mlx_lm.models.turboquant_cache import TurboQuantKVCache

caches = []
for i in range(num_layers):
if i < turbo_fp16_layers or i >= num_layers - turbo_fp16_layers:
caches.append(KVCache())
else:
caches.append(TurboQuantKVCache(bits=turbo_kv_bits))
return caches

if max_kv_size is not None:
return [
RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
@@ -76,6 +106,13 @@ def load_prompt_cache(file_name, return_metadata=False):
arrays = tree_unflatten(list(arrays.items()))
cache_metadata = tree_unflatten(list(cache_metadata.items()))
info, metadata, classes = cache_metadata

# Ensure TurboQuantKVCache is in globals for deserialization
if "TurboQuantKVCache" in classes and "TurboQuantKVCache" not in globals():
from mlx_lm.models.turboquant_cache import TurboQuantKVCache

globals()["TurboQuantKVCache"] = TurboQuantKVCache

cache = [
globals()[c].from_state(state, meta_state)
for c, state, meta_state in zip(classes, arrays, info)
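The lazy import above is what lets a saved mixed cache round-trip, since `load_prompt_cache` resolves cache classes by name from `globals()`. A minimal sketch using the existing save/load helpers; the tensor shape is illustrative, and in a real run the cache would be populated by a prefill pass rather than dummy data:

```python
import mlx.core as mx

from mlx_lm.models.cache import (
    make_prompt_cache,
    save_prompt_cache,
    load_prompt_cache,
)

# Build a mixed FP16/TurboQuant cache for `model` (any standard-KVCache model).
caches = make_prompt_cache(model, turbo_kv_bits=3, turbo_fp16_layers=1)

# Push a dummy (batch, heads, seq, head_dim) key/value pair into each layer
# so there is state to serialize.
k = v = mx.zeros((1, 8, 16, 64), dtype=mx.float16)
for c in caches:
    c.update_and_fetch(k, v)

save_prompt_cache("prompt.safetensors", caches)

# TurboQuantKVCache is imported into cache.py's globals() on demand during
# loading, so the per-layer classes reconstruct without an explicit import.
restored = load_prompt_cache("prompt.safetensors")
print([type(c).__name__ for c in restored])
```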
@@ -390,6 +427,17 @@ def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
)
return quant_cache

def to_turbo_quantized(self, bits: int = 3):
from mlx_lm.models.turboquant_cache import TurboQuantKVCache

tq_cache = TurboQuantKVCache(bits=bits)
if self.keys is not None:
tq_cache.update_and_fetch(
self.keys[..., : self.offset, :],
self.values[..., : self.offset, :],
)
return tq_cache

def make_mask(self, *args, **kwargs):
return create_attention_mask(*args, offset=self.offset, **kwargs)

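`to_turbo_quantized` replays the stored FP16 keys/values into a fresh `TurboQuantKVCache`, so an already-populated cache can also be converted after the fact. A minimal sketch that mirrors the first/last-layer FP16 policy of `make_prompt_cache`; the helper name is hypothetical:

```python
from mlx_lm.models.cache import KVCache

def turbo_quantize_existing(caches, bits=3, fp16_layers=1):
    # Hypothetical helper: convert only the middle layers, keeping the first
    # and last fp16_layers untouched, matching the layout make_prompt_cache
    # produces when turbo_kv_bits is set.
    n = len(caches)
    return [
        c.to_turbo_quantized(bits=bits)
        if fp16_layers <= i < n - fp16_layers and isinstance(c, KVCache)
        else c
        for i, c in enumerate(caches)
    ]
```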