Skip to content

Commit 2c3ea29

Browse files
authored
[Feature] support auto chat template (#4949)
1 parent 5bb0acc commit 2c3ea29

File tree

4 files changed

+112
-31
lines changed

4 files changed

+112
-31
lines changed

python/sglang/srt/conversation.py

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
1818
import dataclasses
1919
from enum import IntEnum, auto
20-
from typing import Dict, List, Optional, Tuple, Union
20+
from typing import Callable, Dict, List, Optional, Tuple, Union
2121

2222
from sglang.srt.openai_api.protocol import ChatCompletionRequest
2323

@@ -407,6 +407,7 @@ def dict(self):
407407

408408
# A global registry for all conversation templates
409409
chat_templates: Dict[str, Conversation] = {}
410+
matching_function_registry: List[Callable] = []
410411

411412

412413
def register_conv_template(template: Conversation, override: bool = False):
@@ -419,6 +420,18 @@ def register_conv_template(template: Conversation, override: bool = False):
419420
chat_templates[template.name] = template
420421

421422

423+
def register_conv_template_matching_function(func):
    """Register *func* as a model-path -> template-name matcher.

    Each matcher receives a model path string and returns a conversation
    template name, or ``None`` when it does not recognize the path.

    Returns *func* unchanged so that using this as a decorator does not
    rebind the decorated module-level name to ``None`` (the original
    returned nothing, so every decorated matcher name became ``None``).
    """
    matching_function_registry.append(func)
    return func
425+
426+
427+
def get_conv_template_by_model_path(model_path):
    """Return the first template name produced by a registered matcher.

    Matchers are consulted in registration order and the search stops at
    the first one that yields a non-None name. Returns None when no
    matcher recognizes *model_path*.
    """
    candidates = (fn(model_path) for fn in matching_function_registry)
    return next((name for name in candidates if name is not None), None)
433+
434+
422435
def chat_template_exists(template_name: str) -> bool:
    """Return True if *template_name* is registered in the global ``chat_templates`` registry."""
    return template_name in chat_templates
424437

@@ -792,3 +805,86 @@ def generate_chat_conv(
792805
audio_token="(<audio>./</audio>)",
793806
)
794807
)
808+
809+
810+
@register_conv_template_matching_function
def match_llama_3_vision(model_path: str):
    """Match Llama 3.2 Vision models to the "llama_3_vision" template.

    NOTE(review): this matcher was previously (mis)named
    ``match_deepseek_janus_pro``, colliding with the Janus-Pro matcher
    defined below; the later definition shadowed this one at module level.
    The registry itself was unaffected (registration happens at decoration
    time), but the rename removes the duplicate-name defect.
    """
    lowered = model_path.lower()
    if "llama" in lowered and "3.2" in lowered and "vision" in lowered:
        return "llama_3_vision"
818+
819+
820+
@register_conv_template_matching_function
def match_deepseek_janus_pro(model_path: str):
    """Match DeepSeek Janus-Pro models to the "janus-pro" template."""
    return "janus-pro" if "janus" in model_path.lower() else None
824+
825+
826+
@register_conv_template_matching_function
def match_vicuna(model_path: str):
    """Match Vicuna models (and LLaVA variants built on Vicuna) to "vicuna_v1.1"."""
    lowered = model_path.lower()
    vicuna_markers = ("vicuna", "llava-v1.5", "llava-next-video-7b")
    if any(marker in lowered for marker in vicuna_markers):
        return "vicuna_v1.1"
834+
835+
836+
@register_conv_template_matching_function
def match_llama2_chat(model_path: str):
    """Match Llama-2 chat and Mistral/Mixtral/CodeLlama instruct models to "llama-2"."""
    lowered = model_path.lower()
    is_llama2_chat = "llama-2" in lowered and "chat" in lowered
    is_mistral_instruct = (
        "mistral" in lowered or "mixtral" in lowered
    ) and "instruct" in lowered
    is_codellama_instruct = "codellama" in lowered and "instruct" in lowered
    if is_llama2_chat or is_mistral_instruct or is_codellama_instruct:
        return "llama-2"
847+
848+
849+
@register_conv_template_matching_function
def match_deepseek_vl(model_path: str):
    """Match DeepSeek-VL2 models to the "deepseek-vl2" template."""
    lowered = model_path.lower()
    if all(tag in lowered for tag in ("deepseek", "vl2")):
        return "deepseek-vl2"
854+
855+
856+
@register_conv_template_matching_function
def match_chat_ml(model_path: str):
    """Match ChatML-family models (Qwen VL variants and ChatML LLaVA builds).

    Checks run in priority order: the GME embedding variant of Qwen2-VL
    first, then generic Qwen VL models, then the LLaVA models that use the
    ChatML conversation format.

    (Removed a leftover commented-out ``pdb.set_trace()`` debug line.)
    """
    model_path = model_path.lower()
    # Note: the suffix for qwen2 chat models is "instruct".
    if "gme" in model_path and "qwen" in model_path and "vl" in model_path:
        return "gme-qwen2-vl"
    if "qwen" in model_path and "vl" in model_path:
        return "qwen2-vl"
    if (
        "llava-v1.6-34b" in model_path
        or "llava-v1.6-yi-34b" in model_path
        or "llava-next-video-34b" in model_path
        or "llava-onevision-qwen2" in model_path
    ):
        return "chatml-llava"
872+
873+
874+
@register_conv_template_matching_function
def match_gemma_it(model_path: str):
    """Match Gemma instruction-tuned models to the "gemma-it" template.

    gemma-3-1b(-it) is a text-completion model, so it must never match.
    Previously the generic ``"gemma" and "it"`` substring check matched
    "gemma-3-1b-it" anyway, making the intended exclusion dead code; the
    explicit guard below makes it effective.
    """
    model_path = model_path.lower()
    if "gemma-3" in model_path and "1b" in model_path:
        # gemma-3-1b-it is a completion model
        return None
    if "gemma" in model_path and "it" in model_path:
        return "gemma-it"
    if "gemma-3" in model_path:
        # gemma-3 chat models may lack an explicit "it" suffix
        return "gemma-it"
882+
883+
884+
@register_conv_template_matching_function
def match_openbmb_minicpm(model_path: str):
    """Match OpenBMB MiniCPM-V / MiniCPM-o models to their templates."""
    lowered = model_path.lower()
    for marker, template_name in (
        ("minicpm-v", "minicpmv"),
        ("minicpm-o", "minicpmo"),
    ):
        if marker in lowered:
            return template_name

python/sglang/srt/entrypoints/engine.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,10 @@
5858
)
5959
from sglang.srt.managers.scheduler import run_scheduler_process
6060
from sglang.srt.managers.tokenizer_manager import TokenizerManager
61-
from sglang.srt.openai_api.adapter import load_chat_template_for_openai_api
61+
from sglang.srt.openai_api.adapter import (
62+
guess_chat_template_name_from_model_path,
63+
load_chat_template_for_openai_api,
64+
)
6265
from sglang.srt.server_args import PortArgs, ServerArgs
6366
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
6467
from sglang.srt.utils import (
@@ -584,6 +587,8 @@ def _launch_subprocesses(
584587
load_chat_template_for_openai_api(
585588
tokenizer_manager, server_args.chat_template, server_args.model_path
586589
)
590+
else:
591+
guess_chat_template_name_from_model_path(server_args.model_path)
587592

588593
if server_args.completion_template:
589594
load_completion_template_for_openai_api(server_args.completion_template)

python/sglang/srt/openai_api/adapter.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
chat_template_exists,
3737
generate_chat_conv,
3838
generate_embedding_convs,
39+
get_conv_template_by_model_path,
3940
register_conv_template,
4041
)
4142
from sglang.srt.function_call_parser import FunctionCallParser
@@ -163,10 +164,14 @@ def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg, mode
163164
else:
164165
chat_template_name = chat_template_arg
165166

166-
# Check chat-template
167-
# TODO:
168-
# 1. Do not import any code from sglang.lang
169-
# 2. For VLM, when chat_template_arg is None, set it automatically by guessing from model_path.
167+
168+
def guess_chat_template_name_from_model_path(model_path):
    """Infer the chat template name from *model_path* and store it globally.

    Runs the registered model-path matchers and assigns the result to the
    module-level ``chat_template_name`` (``None`` when no matcher applies).
    Logs the inferred name only when a match was found.
    """
    global chat_template_name
    chat_template_name = get_conv_template_by_model_path(model_path)
    if chat_template_name is None:
        return
    logger.info(
        f"Infer the chat template name from the model path and obtain the result: {chat_template_name}."
    )
170175

171176

172177
async def v1_files_create(

test/srt/test_vision_openai_server.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,6 @@ def setUpClass(cls):
4747
cls.base_url,
4848
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
4949
api_key=cls.api_key,
50-
other_args=[
51-
"--chat-template",
52-
"chatml-llava",
53-
# "--log-requests",
54-
],
5550
)
5651
cls.base_url += "/v1"
5752

@@ -475,8 +470,6 @@ def setUpClass(cls):
475470
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
476471
api_key=cls.api_key,
477472
other_args=[
478-
"--chat-template",
479-
"qwen2-vl",
480473
"--mem-fraction-static",
481474
"0.4",
482475
],
@@ -496,8 +489,6 @@ def setUpClass(cls):
496489
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
497490
api_key=cls.api_key,
498491
other_args=[
499-
"--chat-template",
500-
"qwen2-vl",
501492
"--mem-fraction-static",
502493
"0.4",
503494
],
@@ -517,8 +508,6 @@ def setUpClass(cls):
517508
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
518509
api_key=cls.api_key,
519510
other_args=[
520-
"--chat-template",
521-
"qwen2-vl",
522511
"--context-length",
523512
"300",
524513
"--mem-fraction-static=0.80",
@@ -573,10 +562,6 @@ def setUpClass(cls):
573562
cls.base_url,
574563
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
575564
api_key=cls.api_key,
576-
other_args=[
577-
"--chat-template",
578-
"llama_3_vision",
579-
],
580565
)
581566
cls.base_url += "/v1"
582567

@@ -596,8 +581,6 @@ def setUpClass(cls):
596581
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
597582
other_args=[
598583
"--trust-remote-code",
599-
"--chat-template",
600-
"minicpmv",
601584
"--mem-fraction-static",
602585
"0.4",
603586
],
@@ -617,8 +600,6 @@ def setUpClass(cls):
617600
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
618601
other_args=[
619602
"--trust-remote-code",
620-
"--chat-template",
621-
"minicpmo",
622603
"--mem-fraction-static",
623604
"0.7",
624605
],
@@ -642,8 +623,6 @@ def setUpClass(cls):
642623
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
643624
other_args=[
644625
"--trust-remote-code",
645-
"--chat-template",
646-
"deepseek-vl2",
647626
"--context-length",
648627
"4096",
649628
],
@@ -690,8 +669,6 @@ def setUpClass(cls):
690669
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
691670
other_args=[
692671
"--trust-remote-code",
693-
"--chat-template",
694-
"janus-pro",
695672
"--mem-fraction-static",
696673
"0.4",
697674
],
@@ -744,8 +721,6 @@ def setUpClass(cls):
744721
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
745722
other_args=[
746723
"--trust-remote-code",
747-
"--chat-template",
748-
"gemma-it",
749724
"--mem-fraction-static",
750725
"0.75",
751726
"--enable-multimodal",

0 commit comments

Comments
 (0)