Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tests/entrypoints/openai/test_chat_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,10 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
# Call the function and get the result
result = apply_hf_chat_template(
tokenizer,
trust_remote_code=True,
conversation=mock_request.messages,
chat_template=mock_request.chat_template or template_content,
tools=None,
add_generation_prompt=mock_request.add_generation_prompt,
continue_final_message=mock_request.continue_final_message,
)
Expand Down
4 changes: 2 additions & 2 deletions tests/entrypoints/openai/test_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ async def test_single_chat_session_video(client: openai.AsyncOpenAI,
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=6299, total_tokens=6309)
completion_tokens=10, prompt_tokens=6287, total_tokens=6297)

message = choice.message
message = chat_completion.choices[0].message
Expand Down Expand Up @@ -180,7 +180,7 @@ async def test_single_chat_session_video_base64encoded(
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=6299, total_tokens=6309)
completion_tokens=10, prompt_tokens=6287, total_tokens=6297)

message = choice.message
message = chat_completion.choices[0].message
Expand Down
58 changes: 58 additions & 0 deletions tests/entrypoints/test_chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"
LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"
COMMAND_R_MODEL_ID = "CohereForAI/c4ai-command-r7b-12-2024"


@pytest.fixture(scope="function")
Expand Down Expand Up @@ -703,19 +705,72 @@ def get_conversation(is_hf: bool):

vllm_result = apply_hf_chat_template(
tokenizer,
trust_remote_code=model_config.trust_remote_code,
conversation=conversation,
chat_template=None,
tools=None,
add_generation_prompt=True,
)

assert hf_result == vllm_result


@pytest.mark.parametrize(
"model",
[
QWEN2VL_MODEL_ID, # tokenizer.chat_template is of type str
COMMAND_R_MODEL_ID, # tokenizer.chat_template is of type dict
])
@pytest.mark.parametrize("use_tools", [True, False])
def test_apply_hf_chat_template(sample_json_schema, model, use_tools):
"""checks that chat_template is a dict type for HF models."""

# Build the tokenizer group and grab the underlying tokenizer
tokenizer_group = TokenizerGroup(
model,
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
)
tokenizer = tokenizer_group.tokenizer

conversation = [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "Hello, how are you?"
},
]

tools = [{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": sample_json_schema
}
}] if use_tools else None

chat_template = apply_hf_chat_template(
tokenizer,
conversation=conversation,
# test that chat_template is None. use default chat_template.
chat_template=None,
tools=tools,
add_generation_prompt=True,
)
assert isinstance(chat_template, str)


# yapf: disable
@pytest.mark.parametrize(
("model", "expected_format"),
[(PHI3V_MODEL_ID, "string"),
(QWEN2VL_MODEL_ID, "openai"),
(QWEN25VL_MODEL_ID, "openai"),
(ULTRAVOX_MODEL_ID, "string"),
(MLLAMA_MODEL_ID, "openai"),
(LLAMA_GUARD_MODEL_ID, "openai")],
Expand All @@ -740,8 +795,10 @@ def test_resolve_content_format_hf_defined(model, expected_format):

resolved_format = resolve_chat_template_content_format(
None, # Test detecting the tokenizer's chat_template
None,
"auto",
tokenizer,
trust_remote_code=True,
)

assert resolved_format == expected_format
Expand Down Expand Up @@ -793,6 +850,7 @@ def test_resolve_content_format_examples(template_path, expected_format):
chat_template,
"auto",
dummy_tokenizer,
trust_remote_code=True,
)

assert resolved_format == expected_format
5 changes: 4 additions & 1 deletion tests/tool_use/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ def ensure_system_prompt(messages: list[dict[str, Any]],

# universal args for all models go here. also good if you need to test locally
# and change type or KV cache quantization or something.
ARGS: list[str] = ["--enable-auto-tool-choice", "--max-model-len", "1024"]
ARGS: list[str] = [
"--enable-auto-tool-choice", "--max-model-len", "1024", "--max-num-seqs",
"256"
]

CONFIGS: dict[str, ServerConfig] = {
"hermes": {
Expand Down
144 changes: 105 additions & 39 deletions vllm/entrypoints/chat_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

import asyncio
import codecs
import json
from abc import ABC, abstractmethod
from collections import defaultdict, deque
Expand Down Expand Up @@ -30,7 +29,8 @@
InputAudio)
# yapf: enable
# pydantic needs the TypedDict from typing_extensions
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers import (PreTrainedTokenizer, PreTrainedTokenizerFast,
ProcessorMixin)
from typing_extensions import Required, TypeAlias, TypedDict

from vllm.config import ModelConfig
Expand Down Expand Up @@ -306,24 +306,62 @@ def _detect_content_format(
return "openai"


def _resolve_hf_chat_template(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
*,
trust_remote_code: bool,
) -> Optional[str]:
# 1st priority: The given chat template
if chat_template is not None:
return chat_template

# 2nd priority: AutoProcessor chat template, unless tool calling is enabled
if tools is None:
try:
processor = cached_get_processor(
tokenizer.name_or_path,
processor_cls=(PreTrainedTokenizer, PreTrainedTokenizerFast,
ProcessorMixin),
trust_remote_code=trust_remote_code,
)
if processor.chat_template is not None:
return processor.chat_template
except Exception:
logger.debug("Failed to load AutoProcessor chat template for %s",
tokenizer.name_or_path, exc_info=True)

# 3rd priority: AutoTokenizer chat template
try:
return tokenizer.get_chat_template(chat_template, tools=tools)
except Exception:
logger.debug("Failed to load AutoTokenizer chat template for %s",
tokenizer.name_or_path, exc_info=True)

return None


def _resolve_chat_template_content_format(
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
given_format: ChatTemplateContentFormatOption,
tokenizer: AnyTokenizer,
*,
trust_remote_code: bool,
) -> _ChatTemplateContentFormat:
if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
tokenizer_chat_template = tokenizer.chat_template
else:
tokenizer_chat_template = None

jinja_text: Optional[str]
if isinstance(tokenizer_chat_template, str) and chat_template is None:
jinja_text = tokenizer_chat_template
elif (isinstance(tokenizer_chat_template, dict)
and chat_template in tokenizer_chat_template):
jinja_text = tokenizer_chat_template[chat_template]
hf_chat_template = _resolve_hf_chat_template(
tokenizer,
chat_template=chat_template,
trust_remote_code=trust_remote_code,
tools=tools,
)
else:
jinja_text = load_chat_template(chat_template, is_literal=True)
hf_chat_template = None

jinja_text = (hf_chat_template if isinstance(hf_chat_template, str)
else load_chat_template(chat_template, is_literal=True))

detected_format = ("string" if jinja_text is None else
_detect_content_format(jinja_text, default="string"))
Expand All @@ -332,17 +370,11 @@ def _resolve_chat_template_content_format(


@lru_cache
def resolve_chat_template_content_format(
def _log_chat_template_content_format(
chat_template: Optional[str],
given_format: ChatTemplateContentFormatOption,
tokenizer: AnyTokenizer,
) -> _ChatTemplateContentFormat:
detected_format = _resolve_chat_template_content_format(
chat_template,
given_format,
tokenizer,
)

detected_format: ChatTemplateContentFormatOption,
):
logger.info(
"Detected the chat template content format to be '%s'. "
"You can set `--chat-template-content-format` to override this.",
Expand All @@ -360,6 +392,29 @@ def resolve_chat_template_content_format(
detected_format,
)


def resolve_chat_template_content_format(
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
given_format: ChatTemplateContentFormatOption,
tokenizer: AnyTokenizer,
*,
trust_remote_code: bool = False,
) -> _ChatTemplateContentFormat:
detected_format = _resolve_chat_template_content_format(
chat_template,
tools,
given_format,
tokenizer,
trust_remote_code=trust_remote_code,
)

_log_chat_template_content_format(
chat_template,
given_format=given_format,
detected_format=detected_format,
)

return detected_format


Expand Down Expand Up @@ -711,7 +766,7 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]):
f"{type(chat_template)} is not a valid chat template type")


def load_chat_template(
def _load_chat_template(
chat_template: Optional[Union[Path, str]],
*,
is_literal: bool = False,
Expand All @@ -724,7 +779,7 @@ def load_chat_template(
raise TypeError("chat_template is expected to be read directly "
"from its value")

return codecs.decode(chat_template, "unicode_escape")
return chat_template

try:
with open(chat_template) as f:
Expand All @@ -742,7 +797,18 @@ def load_chat_template(

# If opening a file fails, set chat template to be args to
# ensure we decode so our escape are interpreted correctly
return load_chat_template(chat_template, is_literal=True)
return _load_chat_template(chat_template, is_literal=True)


_cached_load_chat_template = lru_cache(_load_chat_template)


def load_chat_template(
chat_template: Optional[Union[Path, str]],
*,
is_literal: bool = False,
) -> Optional[str]:
return _cached_load_chat_template(chat_template, is_literal=is_literal)


# TODO: Let user specify how to insert multimodal tokens into prompt
Expand Down Expand Up @@ -1067,31 +1133,29 @@ def apply_hf_chat_template(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
conversation: list[ConversationMessage],
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
*,
trust_remote_code: bool = False,
tokenize: bool = False, # Different from HF's default
**kwargs: Any,
) -> str:
if chat_template is None:
chat_template = tokenizer.chat_template

# FIXME: Temporary workaround for
# https://huggingface.co/mistral-community/pixtral-12b/discussions/31
if chat_template is None:
try:
processor = cached_get_processor(tokenizer.name_or_path)
chat_template = processor.chat_template
except Exception:
pass
hf_chat_template = _resolve_hf_chat_template(
tokenizer,
chat_template=chat_template,
tools=tools,
trust_remote_code=trust_remote_code,
)

if chat_template is None:
if hf_chat_template is None:
raise ValueError(
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one.")

return tokenizer.apply_chat_template(
conversation=conversation, # type: ignore[arg-type]
chat_template=chat_template,
tools=tools, # type: ignore[arg-type]
chat_template=hf_chat_template,
tokenize=tokenize,
**kwargs,
)
Expand All @@ -1100,7 +1164,8 @@ def apply_hf_chat_template(
def apply_mistral_chat_template(
tokenizer: MistralTokenizer,
messages: list[ChatCompletionMessageParam],
chat_template: Optional[str] = None,
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
**kwargs: Any,
) -> list[int]:
if chat_template is not None:
Expand All @@ -1117,5 +1182,6 @@ def apply_mistral_chat_template(

return tokenizer.apply_chat_template(
messages=messages,
tools=tools,
**kwargs,
)
Loading