Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tests/entrypoints/test_chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"
LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"

Expand Down Expand Up @@ -716,6 +717,7 @@ def get_conversation(is_hf: bool):
("model", "expected_format"),
[(PHI3V_MODEL_ID, "string"),
(QWEN2VL_MODEL_ID, "openai"),
(QWEN25VL_MODEL_ID, "openai"),
(ULTRAVOX_MODEL_ID, "string"),
(MLLAMA_MODEL_ID, "openai"),
(LLAMA_GUARD_MODEL_ID, "openai")],
Expand Down
30 changes: 15 additions & 15 deletions vllm/entrypoints/chat_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

import asyncio
import codecs
import json
from abc import ABC, abstractmethod
from collections import defaultdict, deque
Expand Down Expand Up @@ -312,16 +311,21 @@ def _resolve_chat_template_content_format(
tokenizer: AnyTokenizer,
) -> _ChatTemplateContentFormat:
if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
tokenizer_chat_template = tokenizer.chat_template
try:
# Prioritize processor's chat template for multi-modal models
processor = cached_get_processor(tokenizer.name_or_path)
hf_chat_template = processor.chat_template
except Exception:
hf_chat_template = tokenizer.chat_template
else:
tokenizer_chat_template = None
hf_chat_template = None

jinja_text: Optional[str]
if isinstance(tokenizer_chat_template, str) and chat_template is None:
jinja_text = tokenizer_chat_template
elif (isinstance(tokenizer_chat_template, dict)
and chat_template in tokenizer_chat_template):
jinja_text = tokenizer_chat_template[chat_template]
if isinstance(hf_chat_template, str) and chat_template is None:
jinja_text = hf_chat_template
elif (isinstance(hf_chat_template, dict)
and chat_template in hf_chat_template):
jinja_text = hf_chat_template[chat_template]
else:
jinja_text = load_chat_template(chat_template, is_literal=True)

Expand Down Expand Up @@ -724,7 +728,7 @@ def load_chat_template(
raise TypeError("chat_template is expected to be read directly "
"from its value")

return codecs.decode(chat_template, "unicode_escape")
return chat_template

try:
with open(chat_template) as f:
Expand Down Expand Up @@ -1071,17 +1075,13 @@ def apply_hf_chat_template(
tokenize: bool = False, # Different from HF's default
**kwargs: Any,
) -> str:
if chat_template is None:
chat_template = tokenizer.chat_template

# FIXME: Temporary workaround for
# https://huggingface.co/mistral-community/pixtral-12b/discussions/31
if chat_template is None:
try:
# Prioritize processor's chat template for multi-modal models
processor = cached_get_processor(tokenizer.name_or_path)
chat_template = processor.chat_template
except Exception:
pass
chat_template = tokenizer.chat_template

if chat_template is None:
raise ValueError(
Expand Down