Skip to content

Commit 7158138

Browse files
CatherineSuexwu-intel
authored and committed
bugfix(OAI): Fix image_data processing for jinja chat templates (sgl-project#6877)
1 parent 26c8f51 commit 7158138

File tree

4 files changed

+442
-16
lines changed

4 files changed

+442
-16
lines changed

python/sglang/srt/openai_api/adapter.py

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,22 @@
7575
TopLogprob,
7676
UsageInfo,
7777
)
78+
from sglang.srt.openai_api.utils import (
79+
detect_template_content_format,
80+
process_content_for_template_format,
81+
)
7882
from sglang.srt.reasoning_parser import ReasoningParser
7983
from sglang.utils import convert_json_schema_to_str, get_exception_traceback
8084

8185
logger = logging.getLogger(__name__)
8286

8387
chat_template_name = None
8488

89+
# Global cache for template content format detection (one model/template per instance)
90+
# NOTE: A better approach would be to initialize the chat template format when the endpoint is created
91+
_cached_chat_template = None
92+
_cached_template_format = None
93+
8594

8695
class FileMetadata:
8796
def __init__(self, filename: str, purpose: str):
@@ -1000,23 +1009,42 @@ def v1_chat_generate_request(
10001009

10011010
if chat_template_name is None:
10021011
openai_compatible_messages = []
1012+
image_data = []
1013+
audio_data = []
1014+
modalities = []
1015+
1016+
# Detect template content format by analyzing the jinja template (cached globally)
1017+
global _cached_chat_template, _cached_template_format
1018+
current_template = tokenizer_manager.tokenizer.chat_template
1019+
1020+
if current_template != _cached_chat_template:
1021+
# Template changed or first time - analyze it
1022+
_cached_chat_template = current_template
1023+
_cached_template_format = detect_template_content_format(
1024+
current_template
1025+
)
1026+
logger.info(
1027+
f"Detected chat template content format: {_cached_template_format}"
1028+
)
1029+
1030+
template_content_format = _cached_template_format
10031031

10041032
for message in request.messages:
10051033
if message.content is None:
10061034
message.content = ""
1007-
msg_dict = message.dict()
1008-
if isinstance(msg_dict.get("content"), list):
1009-
for chunk in msg_dict["content"]:
1010-
if isinstance(chunk, dict) and chunk.get("type") == "text":
1011-
new_msg = msg_dict.copy()
1012-
new_msg["content"] = chunk["text"]
1013-
new_msg = {
1014-
k: v for k, v in new_msg.items() if v is not None
1015-
}
1016-
openai_compatible_messages.append(new_msg)
1017-
else:
1018-
msg_dict = {k: v for k, v in msg_dict.items() if v is not None}
1019-
openai_compatible_messages.append(msg_dict)
1035+
msg_dict = message.model_dump()
1036+
1037+
# Process content based on detected template format
1038+
processed_msg = process_content_for_template_format(
1039+
msg_dict,
1040+
template_content_format,
1041+
image_data,
1042+
audio_data,
1043+
modalities,
1044+
)
1045+
openai_compatible_messages.append(processed_msg)
1046+
1047+
# Handle assistant prefix for continue_final_message
10201048
if (
10211049
openai_compatible_messages
10221050
and openai_compatible_messages[-1]["role"] == "assistant"
@@ -1070,9 +1098,9 @@ def v1_chat_generate_request(
10701098
if is_multimodal:
10711099
prompt = tokenizer_manager.tokenizer.decode(prompt_ids)
10721100
stop = request.stop
1073-
image_data = None
1074-
audio_data = None
1075-
modalities = []
1101+
image_data = image_data if image_data else None
1102+
audio_data = audio_data if audio_data else None
1103+
modalities = modalities if modalities else []
10761104
else:
10771105
conv = generate_chat_conv(request, chat_template_name)
10781106
# If we should continue the final assistant message, adjust the conversation.
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
"""
2+
Utility functions for OpenAI API adapter.
3+
"""
4+
5+
import logging
from typing import Dict, List, Optional

import jinja2.nodes
import transformers.utils.chat_template_utils as hf_chat_utils
11+
logger = logging.getLogger(__name__)
12+
13+
# ============================================================================
14+
# JINJA TEMPLATE CONTENT FORMAT DETECTION
15+
# ============================================================================
16+
#
17+
# This adapts vLLM's approach for detecting chat template content format:
18+
# https://github.com/vllm-project/vllm/blob/02f0c7b220422792f5e53de2a7d51d2d3ff2df28/vllm/entrypoints/chat_utils.py#L296-L313
19+
# - Analyzes Jinja template AST to detect content iteration patterns
20+
# - 'openai' format: templates with {%- for content in message['content'] -%} loops
21+
# - 'string' format: templates that expect simple string content
22+
# - Processes content accordingly to match template expectations
23+
24+
25+
def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
    """Return True when *node* is a bare variable load such as ``{{ varname }}``."""
    if not isinstance(node, jinja2.nodes.Name):
        return False
    # Only "load" context counts: we want reads, not assignment targets.
    return node.name == varname and node.ctx == "load"
30+
31+
32+
def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
    """Return True when *node* reads ``varname['key']`` or ``varname.key``."""
    # Dotted form: {{ varname.key }}
    if isinstance(node, jinja2.nodes.Getattr):
        return node.attr == key and _is_var_access(node.node, varname)

    # Subscript form with a constant string key: {{ varname['key'] }}
    if isinstance(node, jinja2.nodes.Getitem):
        subscript = node.arg
        return (
            isinstance(subscript, jinja2.nodes.Const)
            and subscript.value == key
            and _is_var_access(node.node, varname)
        )

    return False
45+
46+
47+
def _is_var_or_elems_access(
    node: jinja2.nodes.Node,
    varname: str,
    key: Optional[str] = None,
) -> bool:
    """Check if node accesses varname or varname[key], seen through wrappers.

    Unwraps Jinja filters, tests, and slice subscripts so that expressions
    like ``message['content'] | trim`` or ``message['content'][1:]`` still
    match the underlying variable access.

    Args:
        node: Jinja AST node to inspect.
        varname: Template variable name to match.
        key: Optional item/attribute key; when given, match ``varname[key]``
            instead of the bare variable. (Annotation fixed: the default is
            ``None``, so the parameter is ``Optional[str]``, not ``str``.)

    Returns:
        True if the (unwrapped) node is the requested access.
    """
    # Unwrap filter application, e.g. {{ message['content'] | trim }}.
    if isinstance(node, jinja2.nodes.Filter):
        return node.node is not None and _is_var_or_elems_access(
            node.node, varname, key
        )
    # Unwrap tests, e.g. {% if message['content'] is string %}.
    if isinstance(node, jinja2.nodes.Test):
        return _is_var_or_elems_access(node.node, varname, key)

    # Unwrap slice subscripts, e.g. message['content'][1:].
    if isinstance(node, jinja2.nodes.Getitem) and isinstance(
        node.arg, jinja2.nodes.Slice
    ):
        return _is_var_or_elems_access(node.node, varname, key)

    return _is_attr_access(node, varname, key) if key else _is_var_access(node, varname)
66+
67+
68+
def _try_extract_ast(chat_template: str):
    """Parse *chat_template* into a Jinja AST; return None on any failure."""
    try:
        # Compile via transformers' helper so the environment matches the one
        # actually used to render chat templates, then re-parse for the AST.
        compiled = hf_chat_utils._compile_jinja_template(chat_template)
        env = compiled.environment
        return env.parse(chat_template)
    except Exception as exc:
        logger.debug(f"Error when compiling Jinja template: {exc}")
        return None
76+
77+
78+
def detect_template_content_format(chat_template: str) -> str:
    """
    Detect whether a chat template expects 'string' or 'openai' content format.

    - 'string': content is a simple string (like DeepSeek templates)
    - 'openai': content is a list of structured dicts (like Llama4 templates)

    A template counts as 'openai' when it iterates over message content,
    e.g. {%- for content in message['content'] -%}; otherwise 'string'.
    Any parse or analysis failure conservatively falls back to 'string'.
    """
    ast = _try_extract_ast(chat_template)
    if ast is None:
        return "string"

    try:
        # Scan every for-loop for iteration over message['content'].
        has_content_loop = any(
            _is_var_or_elems_access(loop_node.iter, "message", "content")
            for loop_node in ast.find_all(jinja2.nodes.For)
        )
    except Exception as exc:
        logger.debug(f"Error when parsing AST of Jinja template: {exc}")
        return "string"

    return "openai" if has_content_loop else "string"
106+
107+
108+
def process_content_for_template_format(
    msg_dict: dict,
    content_format: str,
    image_data: list,
    audio_data: list,
    modalities: list,
) -> dict:
    """
    Process message content based on detected template format.

    Args:
        msg_dict: Message dictionary with content
        content_format: 'string' or 'openai' (detected via AST analysis)
        image_data: List to append extracted image URLs
        audio_data: List to append extracted audio URLs
        modalities: List to append modalities

    Returns:
        Processed message dictionary (None-valued keys removed)
    """
    content = msg_dict.get("content")
    if not isinstance(content, list):
        # Content is already a plain string (or None): just drop null fields.
        return {k: v for k, v in msg_dict.items() if v is not None}

    if content_format == "openai":
        # Keep the structured list, but siphon media URLs into the side
        # channels and normalize media chunks to bare type markers so the
        # template only sees {"type": "image"} / {"type": "audio"}.
        normalized_parts = []
        for part in content:
            if not isinstance(part, dict):
                continue  # non-dict chunks are dropped, matching template expectations
            part_type = part.get("type")
            if part_type == "image_url":
                image_data.append(part["image_url"]["url"])
                part_modalities = part.get("modalities")
                if part_modalities:
                    modalities.append(part_modalities)
                normalized_parts.append({"type": "image"})
            elif part_type == "audio_url":
                audio_data.append(part["audio_url"]["url"])
                normalized_parts.append({"type": "audio"})
            else:
                # Text and any other chunk types pass through untouched.
                normalized_parts.append(part)

        result = {
            k: v for k, v in msg_dict.items() if v is not None and k != "content"
        }
        result["content"] = normalized_parts
        return result

    # 'string' format: collapse the list down to its text chunks only.
    # Images/audio are intentionally dropped here — the template expects a
    # plain string and cannot render structured multimodal placeholders.
    texts = [
        part["text"]
        for part in content
        if isinstance(part, dict) and part.get("type") == "text"
    ]
    result = dict(msg_dict)
    result["content"] = " ".join(texts) if texts else ""
    return {k: v for k, v in result.items() if v is not None}

test/srt/run_suite.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class TestFile:
5656
TestFile("test_mla_fp8.py", 93),
5757
TestFile("test_no_chunked_prefill.py", 108),
5858
TestFile("test_no_overlap_scheduler.py", 234),
59+
TestFile("test_openai_adapter.py", 1),
5960
TestFile("test_openai_function_calling.py", 60),
6061
TestFile("test_openai_server.py", 149),
6162
TestFile("test_penalty.py", 41),

0 commit comments

Comments
 (0)