Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 97 additions & 1 deletion python/sglang/srt/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
import dataclasses
from enum import IntEnum, auto
from typing import Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union

from sglang.srt.openai_api.protocol import ChatCompletionRequest

Expand Down Expand Up @@ -407,6 +407,7 @@ def dict(self):

# A global registry for all conversation templates
chat_templates: Dict[str, Conversation] = {}
matching_function_registry: List[Callable] = []


def register_conv_template(template: Conversation, override: bool = False):
Expand All @@ -419,6 +420,18 @@ def register_conv_template(template: Conversation, override: bool = False):
chat_templates[template.name] = template


def register_conv_template_matching_function(func: Callable) -> Callable:
    """Decorator that registers *func* as a model-path → template-name matcher.

    Registered functions are tried in registration order by
    ``get_conv_template_by_model_path``. Each matcher receives a model path
    string and returns a chat-template name, or ``None`` if it does not match.

    Returns the function unchanged so the decorated module-level name stays
    bound to the function instead of ``None``.
    """
    matching_function_registry.append(func)
    return func


def get_conv_template_by_model_path(model_path):
    """Infer a chat-template name from *model_path*.

    Runs every registered matching function in registration order and
    returns the first non-``None`` template name, or ``None`` when no
    matcher recognizes the path.
    """
    results = (matcher(model_path) for matcher in matching_function_registry)
    return next((name for name in results if name is not None), None)


def chat_template_exists(template_name: str) -> bool:
return template_name in chat_templates

Expand Down Expand Up @@ -792,3 +805,86 @@ def generate_chat_conv(
audio_token="(<audio>./</audio>)",
)
)


@register_conv_template_matching_function
def match_llama_3_vision(model_path: str):
    """Match Llama-3.2-Vision checkpoints to the "llama_3_vision" template.

    Renamed from a copy-pasted ``match_deepseek_janus_pro``, which collided
    with (and was shadowed by) the actual Janus-Pro matcher defined below.
    """
    lowered = model_path.lower()
    # All three markers must be present, e.g. "Llama-3.2-11B-Vision-Instruct".
    if "llama" in lowered and "3.2" in lowered and "vision" in lowered:
        return "llama_3_vision"


@register_conv_template_matching_function
def match_deepseek_janus_pro(model_path: str):
    """Match DeepSeek Janus-Pro checkpoints to the "janus-pro" template."""
    return "janus-pro" if "janus" in model_path.lower() else None


@register_conv_template_matching_function
def match_vicuna(model_path: str):
    """Match Vicuna-family checkpoints to the "vicuna_v1.1" template."""
    lowered = model_path.lower()
    # LLaVA v1.5 and LLaVA-NeXT-Video-7B are Vicuna-based chat models.
    vicuna_markers = ("vicuna", "llava-v1.5", "llava-next-video-7b")
    if any(marker in lowered for marker in vicuna_markers):
        return "vicuna_v1.1"


@register_conv_template_matching_function
def match_llama2_chat(model_path: str):
    """Match Llama-2-style chat/instruct checkpoints to the "llama-2" template."""
    lowered = model_path.lower()
    if "llama-2" in lowered and "chat" in lowered:
        return "llama-2"
    # Mistral/Mixtral and CodeLlama instruct variants share the Llama-2 format.
    is_mistral_family = "mistral" in lowered or "mixtral" in lowered
    if "instruct" in lowered and (is_mistral_family or "codellama" in lowered):
        return "llama-2"


@register_conv_template_matching_function
def match_deepseek_vl(model_path: str):
    """Match DeepSeek-VL2 checkpoints to the "deepseek-vl2" template."""
    lowered = model_path.lower()
    if all(marker in lowered for marker in ("deepseek", "vl2")):
        return "deepseek-vl2"


@register_conv_template_matching_function
def match_chat_ml(model_path: str):
    """Match ChatML-style VLM checkpoints to their chat templates.

    Order matters: the GME-Qwen2-VL check must run before the generic
    Qwen-VL check, since a GME path also contains "qwen" and "vl".
    (Removed a leftover commented-out ``pdb`` debugging line.)
    """
    model_path = model_path.lower()
    # Note: the suffix for Qwen2 chat models is "instruct".
    if "gme" in model_path and "qwen" in model_path and "vl" in model_path:
        return "gme-qwen2-vl"
    if "qwen" in model_path and "vl" in model_path:
        return "qwen2-vl"
    if (
        "llava-v1.6-34b" in model_path
        or "llava-v1.6-yi-34b" in model_path
        or "llava-next-video-34b" in model_path
        or "llava-onevision-qwen2" in model_path
    ):
        return "chatml-llava"


@register_conv_template_matching_function
def match_gemma_it(model_path: str):
    """Match Gemma instruction-tuned checkpoints to the "gemma-it" template.

    Fixes the original ordering bug: the generic ``"gemma" and "it"`` branch
    ran first and matched "gemma-3-1b-it", making the intended gemma-3-1b
    exclusion unreachable. The Gemma-3 check now runs first.
    """
    model_path = model_path.lower()
    if "gemma-3" in model_path:
        # gemma-3-1b(-it) is a completion model, not a chat model.
        if "1b" in model_path:
            return None
        return "gemma-it"
    if "gemma" in model_path and "it" in model_path:
        return "gemma-it"


@register_conv_template_matching_function
def match_openbmb_minicpm(model_path: str):
    """Match OpenBMB MiniCPM-V/O checkpoints to their chat templates."""
    lowered = model_path.lower()
    # Checked in order; "minicpm-v" takes precedence, mirroring the if/elif.
    for marker, template_name in (("minicpm-v", "minicpmv"), ("minicpm-o", "minicpmo")):
        if marker in lowered:
            return template_name
7 changes: 6 additions & 1 deletion python/sglang/srt/entrypoints/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api.adapter import load_chat_template_for_openai_api
from sglang.srt.openai_api.adapter import (
guess_chat_template_name_from_model_path,
load_chat_template_for_openai_api,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import (
Expand Down Expand Up @@ -575,6 +578,8 @@ def _launch_subprocesses(
load_chat_template_for_openai_api(
tokenizer_manager, server_args.chat_template, server_args.model_path
)
else:
guess_chat_template_name_from_model_path(server_args.model_path)

if server_args.completion_template:
load_completion_template_for_openai_api(server_args.completion_template)
Expand Down
13 changes: 9 additions & 4 deletions python/sglang/srt/openai_api/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
chat_template_exists,
generate_chat_conv,
generate_embedding_convs,
get_conv_template_by_model_path,
register_conv_template,
)
from sglang.srt.function_call_parser import FunctionCallParser
Expand Down Expand Up @@ -163,10 +164,14 @@ def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg, mode
else:
chat_template_name = chat_template_arg

# Check chat-template
# TODO:
# 1. Do not import any code from sglang.lang
# 2. For VLM, when chat_template_arg is None, set it automatically by guessing from model_path.

def guess_chat_template_name_from_model_path(model_path):
global chat_template_name
chat_template_name = get_conv_template_by_model_path(model_path)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi. If the model_path from server_args is a local path, the guess might be problematic. Can you strengthen the guess utilizing additional info from hf_config, e.g. model_type or architectures?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @adarshxs, can you adapt this from your last PR?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure will do the same

Copy link
Copy Markdown
Collaborator

@mickqian mickqian Jun 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hi @adarshxs, how is it going? can we prioritize this? thanks

if chat_template_name is not None:
logger.info(
f"Infer the chat template name from the model path and obtain the result: {chat_template_name}."
)


async def v1_files_create(
Expand Down
25 changes: 0 additions & 25 deletions test/srt/test_vision_openai_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,6 @@ def setUpClass(cls):
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
other_args=[
"--chat-template",
"chatml-llava",
# "--log-requests",
],
)
cls.base_url += "/v1"

Expand Down Expand Up @@ -475,8 +470,6 @@ def setUpClass(cls):
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
other_args=[
"--chat-template",
"qwen2-vl",
"--mem-fraction-static",
"0.4",
],
Expand All @@ -496,8 +489,6 @@ def setUpClass(cls):
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
other_args=[
"--chat-template",
"qwen2-vl",
"--mem-fraction-static",
"0.4",
],
Expand All @@ -517,8 +508,6 @@ def setUpClass(cls):
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
other_args=[
"--chat-template",
"qwen2-vl",
"--context-length",
"300",
"--mem-fraction-static=0.80",
Expand Down Expand Up @@ -573,10 +562,6 @@ def setUpClass(cls):
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
other_args=[
"--chat-template",
"llama_3_vision",
],
)
cls.base_url += "/v1"

Expand All @@ -596,8 +581,6 @@ def setUpClass(cls):
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--chat-template",
"minicpmv",
"--mem-fraction-static",
"0.4",
],
Expand All @@ -617,8 +600,6 @@ def setUpClass(cls):
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--chat-template",
"minicpmo",
"--mem-fraction-static",
"0.7",
],
Expand All @@ -642,8 +623,6 @@ def setUpClass(cls):
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--chat-template",
"deepseek-vl2",
"--context-length",
"4096",
],
Expand Down Expand Up @@ -690,8 +669,6 @@ def setUpClass(cls):
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--chat-template",
"janus-pro",
"--mem-fraction-static",
"0.4",
],
Expand Down Expand Up @@ -744,8 +721,6 @@ def setUpClass(cls):
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--chat-template",
"gemma-it",
"--mem-fraction-static",
"0.75",
"--enable-multimodal",
Expand Down
Loading