Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions tests/test_chat_template_kwargs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for chat template kwargs forwarding."""

from unittest.mock import MagicMock, patch

from fastapi.testclient import TestClient

import vllm_mlx.server as srv
from vllm_mlx.engine.base import GenerationOutput


def test_chat_completion_request_preserves_chat_template_kwargs():
    """A ChatCompletionRequest keeps chat_template_kwargs intact after validation."""
    template_kwargs = {"enable_thinking": False}
    req = srv.ChatCompletionRequest(
        model="test-model",
        messages=[srv.Message(role="user", content="Hello")],
        chat_template_kwargs=template_kwargs,
    )

    assert req.chat_template_kwargs == {"enable_thinking": False}


def test_batched_engine_applies_chat_template_kwargs():
    """BatchedEngine forwards chat_template_kwargs into tokenizer.apply_chat_template."""
    with patch("vllm_mlx.engine.batched.is_mllm_model", return_value=False):
        from vllm_mlx.engine.batched import BatchedEngine

        engine = BatchedEngine("test-model")
        fake_tokenizer = MagicMock()
        fake_tokenizer.apply_chat_template.return_value = "prompt"
        engine._tokenizer = fake_tokenizer

        rendered = engine._apply_chat_template(
            [{"role": "user", "content": "Hello"}],
            chat_template_kwargs={"enable_thinking": False},
        )

        assert rendered == "prompt"
        fake_tokenizer.apply_chat_template.assert_called_once()
        # The extra kwarg must reach the tokenizer call unchanged.
        forwarded = fake_tokenizer.apply_chat_template.call_args.kwargs
        assert forwarded["enable_thinking"] is False


def test_chat_completion_endpoint_forwards_chat_template_kwargs():
    """The /v1/chat/completions route passes chat_template_kwargs through to the engine."""
    captured = {}

    class FakeEngine:
        # Minimal attribute surface the endpoint reads off the engine.
        model_name = "test-model"
        is_mllm = False
        preserve_native_tool_format = False

        async def chat(self, messages, **kwargs):
            captured["messages"] = messages
            captured["kwargs"] = kwargs
            return GenerationOutput(
                text="ORBIT",
                prompt_tokens=4,
                completion_tokens=1,
                finish_reason="stop",
            )

    client = TestClient(srv.app)
    # Swap in the fake engine, restoring module state even if the request fails.
    saved_engine = srv._engine
    saved_model_name = srv._model_name
    srv._engine = FakeEngine()
    srv._model_name = "test-model"
    payload = {
        "model": "test-model",
        "messages": [{"role": "user", "content": "Reply with ORBIT."}],
        "max_tokens": 8,
        "chat_template_kwargs": {"enable_thinking": False},
    }
    try:
        response = client.post("/v1/chat/completions", json=payload)
    finally:
        srv._engine = saved_engine
        srv._model_name = saved_model_name

    assert response.status_code == 200
    assert captured["kwargs"]["chat_template_kwargs"] == {"enable_thinking": False}
    body = response.json()
    assert body["choices"][0]["message"]["content"] == "ORBIT"
90 changes: 90 additions & 0 deletions tests/test_paged_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,3 +725,93 @@ def test_clear(self):
stats = cache.get_stats()
# After clear, null block is still allocated (vLLM style)
assert stats["allocated_blocks"] == 1 # only null block

def test_reconstructs_hybrid_cache_from_boundary_snapshot(self):
    """A hybrid (KV + recurrent) cache round-trips through store/reconstruct."""
    import mlx.core as mx
    from mlx_lm.models.cache import ArraysCache, KVCache

    from vllm_mlx.paged_cache import PagedCacheManager
    from vllm_mlx.prefix_cache import BlockAwarePrefixCache

    manager = PagedCacheManager(block_size=4, max_blocks=10)
    prefix_cache = BlockAwarePrefixCache(model=None, paged_cache_manager=manager)

    token_ids = list(range(8))
    keys = mx.arange(1 * 2 * 8 * 3).reshape(1, 2, 8, 3)
    values = mx.arange(1000, 1000 + (1 * 2 * 8 * 3)).reshape(1, 2, 8, 3)
    recurrent_state = [
        mx.arange(1 * 3 * 8).reshape(1, 3, 8),
        mx.arange(2000, 2000 + (1 * 2 * 4 * 4)).reshape(1, 2, 4, 4),
    ]
    layer_entries = [
        {
            "state": (keys, values),
            "meta_state": "",
            "class_ref": KVCache,
            "class_name": "KVCache",
        },
        {
            "state": recurrent_state,
            "meta_state": "",
            "class_ref": ArraysCache,
            "class_name": "ArraysCache",
        },
    ]

    table = prefix_cache.store_cache("req-1", token_ids, layer_entries)
    head_block = manager.allocated_blocks[table.block_ids[0]]
    tail_block = manager.allocated_blocks[table.block_ids[-1]]

    # KV data lands on every block; the recurrent-layer snapshot is stored
    # only on the final (boundary) block.
    assert head_block.cache_data[0] is not None
    assert head_block.cache_data[1] is None
    assert tail_block.cache_data[1] is not None

    rebuilt = prefix_cache.reconstruct_cache(table)

    assert rebuilt is not None
    assert isinstance(rebuilt[0], KVCache)
    assert isinstance(rebuilt[1], ArraysCache)
    assert rebuilt[0].state[0].tolist() == keys.tolist()
    assert rebuilt[0].state[1].tolist() == values.tolist()
    assert rebuilt[1].state[0].tolist() == recurrent_state[0].tolist()
    assert rebuilt[1].state[1].tolist() == recurrent_state[1].tolist()

def test_rejects_hybrid_prefix_without_boundary_snapshot(self):
    """Reconstructing a prefix slice of a hybrid cache yields None when the
    selected blocks carry no recurrent-layer boundary snapshot."""
    import mlx.core as mx
    from mlx_lm.models.cache import ArraysCache, KVCache

    from vllm_mlx.paged_cache import BlockTable, PagedCacheManager
    from vllm_mlx.prefix_cache import BlockAwarePrefixCache

    manager = PagedCacheManager(block_size=4, max_blocks=10)
    prefix_cache = BlockAwarePrefixCache(model=None, paged_cache_manager=manager)

    layer_entries = [
        {
            "state": (
                mx.arange(1 * 2 * 8 * 3).reshape(1, 2, 8, 3),
                mx.arange(1000, 1000 + (1 * 2 * 8 * 3)).reshape(1, 2, 8, 3),
            ),
            "meta_state": "",
            "class_ref": KVCache,
            "class_name": "KVCache",
        },
        {
            "state": [
                mx.arange(1 * 3 * 8).reshape(1, 3, 8),
                mx.arange(2000, 2000 + (1 * 2 * 4 * 4)).reshape(1, 2, 4, 4),
            ],
            "meta_state": "",
            "class_ref": ArraysCache,
            "class_name": "ArraysCache",
        },
    ]

    full_table = prefix_cache.store_cache("req-1", list(range(8)), layer_entries)
    # A table covering only the first block: its boundary snapshot lives on
    # the last block, which this prefix table does not include.
    partial_table = BlockTable(
        request_id="req-prefix",
        block_ids=[full_table.block_ids[0]],
        num_tokens=4,
    )

    assert prefix_cache.reconstruct_cache(partial_table) is None
Loading