Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 56 additions & 5 deletions mlx_engine/cache_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Any
from typing import cast, List, Optional, Any
import logging
from mlx_lm.models.cache import (
make_prompt_cache,
Expand Down Expand Up @@ -60,6 +60,9 @@ def __init__(
kv_group_size: Optional[int] = None,
quantized_kv_start: Optional[int] = None,
chunk_size: int,
turboquant: bool = False,
turboquant_fused: bool = True,
turboquant_fp16_layers: int = 4,
):
"""
Initialize the CacheWrapper.
Expand All @@ -68,18 +71,40 @@ def __init__(
model (nn.Module): The model to be cached.
max_kv_size (Optional[int]): Maximum size of the key-value cache.
chunk_size (int): Number of tokens per prefill chunk.
turboquant (bool): Use TurboQuant KV cache compression.
turboquant_fused (bool): Use fused attention path for TurboQuant.
turboquant_fp16_layers (int): Number of first and last layers to keep in FP16.
"""
# utilize a simple ordered list of tokens processed so far for cache invalidation checking
self.tokens: Optional[mx.array] = None
self.cache: List[Any] = make_prompt_cache(model, max_kv_size)
self.model = model

if turboquant:
from turboquant_mlx.adaptive import make_adaptive_cache
from turboquant_mlx.patch import apply_patch

apply_patch()

self.cache = make_adaptive_cache(
num_layers=len(model.layers),
bits=kv_bits or 3,
fp16_layers=turboquant_fp16_layers,
fused=turboquant_fused,
model=model,
)
else:
self.cache: List[Any] = make_prompt_cache(model, max_kv_size)

self.draft_model: Optional[nn.Module] = None
self.max_kv_size = max_kv_size
self.verbose = verbose
self.kv_cache_qtn_params = dict(
kv_bits=kv_bits,
kv_group_size=kv_group_size,
quantized_kv_start=quantized_kv_start,
turboquant=turboquant,
turboquant_fused=turboquant_fused,
turboquant_fp16_layers=turboquant_fp16_layers,
)
self.chunk_size = chunk_size

Expand Down Expand Up @@ -219,7 +244,11 @@ def _prefill(
current_chunk = remaining_tokens[:current_chunk_size]

model(current_chunk[None], cache=cache)
maybe_quantize_kv_cache(prompt_cache=cache, **self.kv_cache_qtn_params)
qtn_params = self.kv_cache_qtn_params.copy()
qtn_params.pop("turboquant", None)
qtn_params.pop("turboquant_fused", None)
qtn_params.pop("turboquant_fp16_layers", None)
maybe_quantize_kv_cache(prompt_cache=cache, **qtn_params)
mx.eval([c.state for c in cache])

remaining_tokens = remaining_tokens[current_chunk_size:]
Expand All @@ -240,7 +269,19 @@ def _prefill(
)
num_tokens_in_cache = None
if num_tokens_in_cache is None:
self.cache = make_prompt_cache(self.model, self.max_kv_size)
if self.kv_cache_qtn_params.get("turboquant"):
from turboquant_mlx.adaptive import make_adaptive_cache

self.cache = make_adaptive_cache(
model=self.model,
num_layers=len(self.model.layers),
bits=self.kv_cache_qtn_params.get("kv_bits") or 3,
fused=cast(bool, self.kv_cache_qtn_params.get("turboquant_fused", True)),
fp16_layers=cast(int, self.kv_cache_qtn_params.get("turboquant_fp16_layers", 4)),
)
else:
self.cache = make_prompt_cache(self.model, self.max_kv_size)

self.tokens = None
else:
# Remember which tokens were processed so far, so that we can continue processing at a later point
Expand Down Expand Up @@ -274,7 +315,17 @@ def set_draft_model(self, draft_model: nn.Module):
# https://github.com/ml-explore/mlx-examples/blob/514502da22f0dc4c1ac439bdf78c07d5ec41acf7/llms/mlx_lm/utils.py#L381-L382
logger.info("Clearing current prompt cache and adding draft model to the cache")
self.tokens = None
self.cache: List[Any] = make_prompt_cache(self.model)
if self.kv_cache_qtn_params.get("turboquant"):
from turboquant_mlx.adaptive import make_adaptive_cache
self.cache = make_adaptive_cache(
model=self.model,
num_layers=len(self.model.layers),
bits=self.kv_cache_qtn_params.get("kv_bits") or 3,
fused=cast(bool, self.kv_cache_qtn_params.get("turboquant_fused", True)),
fp16_layers=cast(int, self.kv_cache_qtn_params.get("turboquant_fp16_layers", 4)),
)
else:
self.cache: List[Any] = make_prompt_cache(self.model)
if draft_model is not None:
self.cache += make_prompt_cache(draft_model)
self.draft_model = draft_model
Expand Down
20 changes: 15 additions & 5 deletions mlx_engine/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ def load_model(
kv_group_size: Optional[int] = None,
quantized_kv_start: Optional[int] = None,
prefill_step_size: Optional[int] = None,
turboquant: bool = False,
turboquant_fused: bool = True,
turboquant_fp16_layers: int = 4,
) -> ModelKit | VisionModelKit:
"""
Load a language model or vision-language model from the specified path.
Expand All @@ -149,6 +152,9 @@ def load_model(
quantized_kv_start (Optional[int]): Step to begin KV cache quantization when enabled.
prefill_step_size (Optional[int]): Number of tokens to process per prefill chunk.
Defaults to PROMPT_PROCESSING_CHUNK_SIZE when None.
turboquant (bool): Use TurboQuant KV cache compression.
turboquant_fused (bool): Use fused attention path for TurboQuant.
turboquant_fp16_layers (int): Number of first and last layers to keep in FP16.

Returns:
ModelKit | VisionModelKit: An initialized model instance:
Expand Down Expand Up @@ -183,9 +189,9 @@ def warn_if_parallel(reason: str) -> None:
if "vision_config" in config_json and not ModelKit.is_supported_vision_arch(
model_type
):
if any([kv_bits, kv_group_size, quantized_kv_start]):
if any([kv_bits, kv_group_size, quantized_kv_start, turboquant]):
raise ValueError(
"MLX vision models do not currently support KV cache quantization"
"MLX vision models do not currently support KV cache quantization or TurboQuant"
)
if parallel_requested:
raise ValueError(
Expand All @@ -204,6 +210,7 @@ def warn_if_parallel(reason: str) -> None:
kv_bits,
kv_group_size,
quantized_kv_start,
turboquant=turboquant,
)

def is_batchable() -> bool:
Expand All @@ -221,10 +228,10 @@ def is_batchable() -> bool:
"this model architecture does not support continuous batching"
)
return False
# 2. KV cache quantization is not compatible with batching yet
if kv_bits is not None:
# 2. KV cache quantization or TurboQuant is not compatible with batching yet
if kv_bits is not None or turboquant:
warn_if_parallel(
"concurrency is not supported with KV Cache Quantization"
"concurrency is not supported with KV Cache Quantization or TurboQuant"
)
return False
# 3. Vision models are not compatible with batching yet
Expand Down Expand Up @@ -258,6 +265,9 @@ def is_batchable() -> bool:
kv_bits=kv_bits,
kv_group_size=kv_group_size,
quantized_kv_start=quantized_kv_start,
turboquant=turboquant,
turboquant_fused=turboquant_fused,
turboquant_fp16_layers=turboquant_fp16_layers,
)
sanitize_eos_tokens(model_kit)
model_kit.start()
Expand Down
21 changes: 18 additions & 3 deletions mlx_engine/model_kit/model_kit.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,15 @@ def _full_model_init(
kv_bits: Optional[int] = None,
kv_group_size: Optional[int] = None,
quantized_kv_start: Optional[int] = None,
turboquant: bool = False,
turboquant_fp16_layers: int = 4,
turboquant_fused: bool = True,
):
if kv_bits and max_kv_size is not None:
# Quantized KV cache is only supported for non-rotating KV cache
logger.warning("max_kv_size is ignored when using KV cache quantization")
if (kv_bits or turboquant) and max_kv_size is not None:
# Quantized KV cache or TurboQuant is only supported for non-rotating KV cache
logger.warning(
"max_kv_size is ignored when using KV cache quantization or TurboQuant"
)
max_kv_size = None
self.model_path = model_path
logger.info(f"Loading model from {model_path}...")
Expand All @@ -105,11 +110,15 @@ def _full_model_init(
kv_bits=kv_bits,
kv_group_size=kv_group_size,
quantized_kv_start=quantized_kv_start,
turboquant=turboquant,
turboquant_fp16_layers=turboquant_fp16_layers,
turboquant_fused=turboquant_fused,
chunk_size=prefill_step_size,
)
self.kv_bits = kv_bits
self.kv_group_size = kv_group_size
self.quantized_kv_start = quantized_kv_start
self.turboquant = turboquant
vision_add_on_class = self.VISION_ADD_ON_MAP.get(self.model_type)
should_load_vision_add_on = (
vision_add_on_class is not None and "vision_config" in config_json
Expand All @@ -127,6 +136,9 @@ def __init__(
kv_bits: Optional[int] = None,
kv_group_size: Optional[int] = None,
quantized_kv_start: Optional[int] = None,
turboquant: bool = False,
turboquant_fp16_layers: int = 4,
turboquant_fused: bool = True,
):
self.generation_lock = threading.Lock()
self.pending_requests = {}
Expand All @@ -143,6 +155,9 @@ def __init__(
kv_bits=kv_bits,
kv_group_size=kv_group_size,
quantized_kv_start=quantized_kv_start,
turboquant=turboquant,
turboquant_fp16_layers=turboquant_fp16_layers,
turboquant_fused=turboquant_fused,
)

def start(self):
Expand Down
12 changes: 12 additions & 0 deletions mlx_engine/utils/kv_cache_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@
# https://github.com/ml-explore/mlx/blob/f288db8d34c0bcfa0867b6458ab0277c5e86ed45/mlx/fast.cpp#L775
VALID_KV_GROUP_SIZE = (32, 64, 128)

# Bit widths accepted for kv_bits when TurboQuant is enabled
VALID_TURBOQUANT_BITS = (1, 2, 3, 4)


def get_kv_cache_quantization_params(
kv_bits: Optional[int],
kv_group_size: Optional[int],
quantized_kv_start: Optional[int],
turboquant: bool = False,
) -> Tuple[Optional[int], Optional[int], Optional[int]]:
"""
Validates and processes KV cache quantization parameters.
Expand All @@ -18,6 +22,7 @@ def get_kv_cache_quantization_params(
kv_bits: Number of bits for quantization. If None, disables quantization.
kv_group_size: Group size for quantization. Defaults to 64 if quantization enabled.
quantized_kv_start: Step to begin quantization. Defaults to 0 if quantization enabled.
turboquant: Whether TurboQuant is being used.

Returns:
Tuple of (kv_bits, kv_group_size, quantized_kv_start), all None if quantization disabled.
Expand All @@ -31,6 +36,13 @@ def get_kv_cache_quantization_params(
if kv_bits is None:
return None, None, None

if turboquant:
if kv_bits not in VALID_TURBOQUANT_BITS:
raise ValueError(
f"Invalid TurboQuant kv_bits value. Must be one of {VALID_TURBOQUANT_BITS}"
)
return kv_bits, None, None

# defaults taken from here:
# https://github.com/ml-explore/mlx-examples/blob/3d793ec/llms/mlx_lm/utils.py#L352-L353
if kv_group_size is None:
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,4 @@ typer==0.24.1
typing-inspection==0.4.2
urllib3==2.6.3
xxhash==3.6.0
turboquant-mlx @ git+https://github.com/arozanov/turboquant-mlx.git@cffcbf0beae92b2cc02de55a9aded19a96a484f6
Loading
Loading