60 changes: 44 additions & 16 deletions python/sglang/srt/openai_api/adapter.py
@@ -75,13 +75,22 @@
TopLogprob,
UsageInfo,
)
+from sglang.srt.openai_api.utils import (
+    detect_template_content_format,
+    process_content_for_template_format,
+)
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.utils import convert_json_schema_to_str, get_exception_traceback

logger = logging.getLogger(__name__)

chat_template_name = None

+# Global cache for template content format detection (one model/template per instance)
+# NOTE: A better approach would be to initialize the chat template format when the endpoint is created
+_cached_chat_template = None
+_cached_template_format = None


class FileMetadata:
def __init__(self, filename: str, purpose: str):
@@ -1000,23 +1009,42 @@ def v1_chat_generate_request(

if chat_template_name is None:
openai_compatible_messages = []
+            image_data = []
+            audio_data = []
+            modalities = []
+
+            # Detect template content format by analyzing the Jinja template (cached globally)
+            global _cached_chat_template, _cached_template_format
+            current_template = tokenizer_manager.tokenizer.chat_template
+
+            if current_template != _cached_chat_template:
+                # Template changed or first request: analyze it
+                _cached_chat_template = current_template
+                _cached_template_format = detect_template_content_format(
+                    current_template
+                )
+                logger.info(
+                    f"Detected chat template content format: {_cached_template_format}"
+                )
+
+            template_content_format = _cached_template_format

for message in request.messages:
if message.content is None:
message.content = ""
-                msg_dict = message.dict()
-                if isinstance(msg_dict.get("content"), list):
-                    for chunk in msg_dict["content"]:
-                        if isinstance(chunk, dict) and chunk.get("type") == "text":
-                            new_msg = msg_dict.copy()
-                            new_msg["content"] = chunk["text"]
-                            new_msg = {
-                                k: v for k, v in new_msg.items() if v is not None
-                            }
-                            openai_compatible_messages.append(new_msg)
-                else:
-                    msg_dict = {k: v for k, v in msg_dict.items() if v is not None}
-                    openai_compatible_messages.append(msg_dict)
+                msg_dict = message.model_dump()
+
+                # Process content based on detected template format
+                processed_msg = process_content_for_template_format(
+                    msg_dict,
+                    template_content_format,
+                    image_data,
+                    audio_data,
+                    modalities,
+                )
+                openai_compatible_messages.append(processed_msg)

# Handle assistant prefix for continue_final_message
if (
openai_compatible_messages
and openai_compatible_messages[-1]["role"] == "assistant"
@@ -1070,9 +1098,9 @@ def v1_chat_generate_request(
if is_multimodal:
prompt = tokenizer_manager.tokenizer.decode(prompt_ids)
stop = request.stop
-        image_data = None
-        audio_data = None
-        modalities = []
+        image_data = image_data if image_data else None
+        audio_data = audio_data if audio_data else None
+        modalities = modalities if modalities else []
else:
conv = generate_chat_conv(request, chat_template_name)
# If we should continue the final assistant message, adjust the conversation.
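To make the adapter change concrete, here is a minimal sketch (assuming this branch of sglang is installed; the message and URL are illustrative placeholders, not fixtures from this PR) of how the new helper handles one multimodal message under each detected format:

# Sketch: one multimodal message processed under both content formats.
# Assumes this branch of sglang is installed; the message and URL below
# are illustrative placeholders.
from sglang.srt.openai_api.utils import process_content_for_template_format

message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "Describe this image."},
        {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
    ],
}

# 'openai' format: the structured list is preserved; the image chunk is
# normalized to {"type": "image"} and its URL is collected into image_data.
image_data, audio_data, modalities = [], [], []
openai_msg = process_content_for_template_format(
    dict(message), "openai", image_data, audio_data, modalities
)
assert openai_msg["content"] == [
    {"type": "text", "text": "Describe this image."},
    {"type": "image"},
]
assert image_data == ["https://example.com/cat.png"]

# 'string' format: content is flattened to plain text; media URLs are
# dropped because the template cannot render structured parts.
image_data, audio_data, modalities = [], [], []
string_msg = process_content_for_template_format(
    dict(message), "string", image_data, audio_data, modalities
)
assert string_msg["content"] == "Describe this image."
assert image_data == []

This mirrors what v1_chat_generate_request now does per message once the template format has been detected and cached.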
172 changes: 172 additions & 0 deletions python/sglang/srt/openai_api/utils.py
@@ -0,0 +1,172 @@
"""
Utility functions for OpenAI API adapter.
"""

import logging
from typing import Optional

import jinja2.nodes
import transformers.utils.chat_template_utils as hf_chat_utils

logger = logging.getLogger(__name__)

# ============================================================================
# JINJA TEMPLATE CONTENT FORMAT DETECTION
# ============================================================================
#
# This adapts vLLM's approach for detecting chat template content format:
# https://github.com/vllm-project/vllm/blob/02f0c7b220422792f5e53de2a7d51d2d3ff2df28/vllm/entrypoints/chat_utils.py#L296-L313
# - Analyzes Jinja template AST to detect content iteration patterns
# - 'openai' format: templates with {%- for content in message['content'] -%} loops
# - 'string' format: templates that expect simple string content
# - Processes content accordingly to match template expectations
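#
# For illustration, two simplified (not real) template shapes and the format
# each is detected as:
#
#   'openai': the template iterates over structured content parts, e.g.
#       {%- for content in message['content'] -%}
#           {%- if content['type'] == 'text' -%}{{ content['text'] }}{%- endif -%}
#       {%- endfor -%}
#
#   'string': the template renders content directly as one string, e.g.
#       {{ message['content'] }}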


def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
"""Check if node is a variable access like {{ varname }}"""
if isinstance(node, jinja2.nodes.Name):
return node.ctx == "load" and node.name == varname
return False


def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
"""Check if node is an attribute access like {{ varname['key'] }} or {{ varname.key }}"""
if isinstance(node, jinja2.nodes.Getitem):
return (
_is_var_access(node.node, varname)
and isinstance(node.arg, jinja2.nodes.Const)
and node.arg.value == key
)

if isinstance(node, jinja2.nodes.Getattr):
return _is_var_access(node.node, varname) and node.attr == key

return False


def _is_var_or_elems_access(
node: jinja2.nodes.Node,
varname: str,
    key: Optional[str] = None,
) -> bool:
"""Check if node accesses varname or varname[key] with filters/tests"""
if isinstance(node, jinja2.nodes.Filter):
return node.node is not None and _is_var_or_elems_access(
node.node, varname, key
)
if isinstance(node, jinja2.nodes.Test):
return _is_var_or_elems_access(node.node, varname, key)

if isinstance(node, jinja2.nodes.Getitem) and isinstance(
node.arg, jinja2.nodes.Slice
):
return _is_var_or_elems_access(node.node, varname, key)

return _is_attr_access(node, varname, key) if key else _is_var_access(node, varname)


def _try_extract_ast(chat_template: str):
"""Try to parse the Jinja template into an AST"""
try:
jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template)
return jinja_compiled.environment.parse(chat_template)
except Exception as e:
logger.debug(f"Error when compiling Jinja template: {e}")
return None


def detect_template_content_format(chat_template: str) -> str:
"""
Detect whether a chat template expects 'string' or 'openai' content format.

- 'string': content is a simple string (like DeepSeek templates)
- 'openai': content is a list of structured dicts (like Llama4 templates)

Detection logic:
- If template has loops like {%- for content in message['content'] -%} → 'openai'
- Otherwise → 'string'
"""
jinja_ast = _try_extract_ast(chat_template)
if jinja_ast is None:
return "string"

try:
# Look for patterns like: {%- for content in message['content'] -%}
for loop_ast in jinja_ast.find_all(jinja2.nodes.For):
loop_iter = loop_ast.iter

# Check if iterating over message['content'] or similar
if _is_var_or_elems_access(loop_iter, "message", "content"):
return "openai" # Found content iteration → openai format

return "string" # No content loops found → string format
except Exception as e:
logger.debug(f"Error when parsing AST of Jinja template: {e}")
return "string"


def process_content_for_template_format(
msg_dict: dict,
content_format: str,
image_data: list,
audio_data: list,
modalities: list,
) -> dict:
"""
Process message content based on detected template format.

Args:
msg_dict: Message dictionary with content
content_format: 'string' or 'openai' (detected via AST analysis)
image_data: List to append extracted image URLs
audio_data: List to append extracted audio URLs
modalities: List to append modalities

Returns:
Processed message dictionary
"""
if not isinstance(msg_dict.get("content"), list):
# Already a string or None, no processing needed
return {k: v for k, v in msg_dict.items() if v is not None}

if content_format == "openai":
# OpenAI format: preserve structured content list, normalize types
processed_content_parts = []
for chunk in msg_dict["content"]:
if isinstance(chunk, dict):
chunk_type = chunk.get("type")

if chunk_type == "image_url":
image_data.append(chunk["image_url"]["url"])
if chunk.get("modalities"):
modalities.append(chunk.get("modalities"))
# Normalize to simple 'image' type for template compatibility
processed_content_parts.append({"type": "image"})
elif chunk_type == "audio_url":
audio_data.append(chunk["audio_url"]["url"])
# Normalize to simple 'audio' type
processed_content_parts.append({"type": "audio"})
else:
# Keep other content as-is (text, etc.)
processed_content_parts.append(chunk)

new_msg = {
k: v for k, v in msg_dict.items() if v is not None and k != "content"
}
new_msg["content"] = processed_content_parts
return new_msg

else: # content_format == "string"
# String format: flatten to text only (for templates like DeepSeek)
text_parts = []
for chunk in msg_dict["content"]:
if isinstance(chunk, dict) and chunk.get("type") == "text":
text_parts.append(chunk["text"])
# Note: For string format, we ignore images/audio since the template
# doesn't expect structured content - multimodal placeholders would
# need to be inserted differently

new_msg = msg_dict.copy()
new_msg["content"] = " ".join(text_parts) if text_parts else ""
new_msg = {k: v for k, v in new_msg.items() if v is not None}
return new_msg
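As a quick sanity check of the detector itself, a sketch in the spirit of the test_openai_adapter.py suite registered below (the toy templates here are illustrative, not the PR's actual test cases):

# Sketch: AST-based detection on two toy templates (illustrative only).
from sglang.srt.openai_api.utils import detect_template_content_format

OPENAI_STYLE = (
    "{%- for message in messages -%}"
    "{%- for content in message['content'] -%}"
    "{{ content['text'] }}"
    "{%- endfor -%}"
    "{%- endfor -%}"
)
STRING_STYLE = "{%- for message in messages -%}{{ message['content'] }}{%- endfor -%}"

assert detect_template_content_format(OPENAI_STYLE) == "openai"
assert detect_template_content_format(STRING_STYLE) == "string"

# A template that fails to parse falls back to 'string' instead of raising.
assert detect_template_content_format("{% broken") == "string"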
1 change: 1 addition & 0 deletions test/srt/run_suite.py
@@ -56,6 +56,7 @@ class TestFile:
TestFile("test_mla_fp8.py", 93),
TestFile("test_no_chunked_prefill.py", 108),
TestFile("test_no_overlap_scheduler.py", 234),
TestFile("test_openai_adapter.py", 1),
TestFile("test_openai_function_calling.py", 60),
TestFile("test_openai_server.py", 149),
TestFile("test_penalty.py", 41),