diff --git a/mlx_lm/server.py b/mlx_lm/server.py
index 6c69ea74c..123282a43 100644
--- a/mlx_lm/server.py
+++ b/mlx_lm/server.py
@@ -671,6 +671,8 @@ def _make_state_machine(
         return sm, sequences
 
     def _is_batchable(self, args):
+        if getattr(self.cli_args, "kv_bits", None) is not None:
+            return False
         return self.model_provider.is_batchable and args.seed is None
 
     def _generate(self):
@@ -953,6 +955,11 @@ def progress(tokens_processed, tokens_total):
             cache += make_prompt_cache(self.model_provider.draft_model)
 
         # Process the prompt and generate tokens
+        kv_kwargs = {}
+        if getattr(self.cli_args, "kv_bits", None) is not None:
+            kv_kwargs["kv_bits"] = self.cli_args.kv_bits
+            kv_kwargs["kv_group_size"] = self.cli_args.kv_group_size
+            kv_kwargs["quantized_kv_start"] = self.cli_args.quantized_kv_start
         for gen in stream_generate(
             model=model,
             tokenizer=tokenizer,
@@ -965,6 +972,7 @@
             num_draft_tokens=args.num_draft_tokens,
             prompt_progress_callback=progress,
             prefill_step_size=self.cli_args.prefill_step_size,
+            **kv_kwargs,
         ):
             finish_reason = gen.finish_reason
             sm_state, match_sequence, current_state = sm.match(sm_state, gen.token)
@@ -1861,6 +1869,24 @@ def main():
         default=2048,
         help="Step size for prefill processing (default: 2048)",
     )
+    parser.add_argument(
+        "--kv-bits",
+        type=int,
+        default=None,
+        help="Number of bits for KV cache quantization. None means no quantization.",
+    )
+    parser.add_argument(
+        "--kv-group-size",
+        type=int,
+        default=64,
+        help="Group size for KV cache quantization (default: 64)",
+    )
+    parser.add_argument(
+        "--quantized-kv-start",
+        type=int,
+        default=0,
+        help="Step to begin quantizing the KV cache (default: 0)",
+    )
     parser.add_argument(
         "--prompt-cache-size",
         type=int,