Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 103 additions & 82 deletions python/sglang/srt/function_call/base_format_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(self):
) # map what has been streamed for each tool so far to a list
self.bot_token = ""
self.eot_token = ""
self.tool_call_separator = ", "

def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]:
tool_indices = {
Expand All @@ -50,7 +51,7 @@ def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]:
if name and name in tool_indices:
results.append(
ToolCallItem(
tool_index=tool_indices[name],
tool_index=-1, # Caller should update this based on the actual tools array called
name=name,
parameters=json.dumps(
act.get("parameters") or act.get("arguments", {}),
Expand Down Expand Up @@ -106,7 +107,17 @@ def parse_streaming_increment(
# Append new text to buffer
self._buffer += new_text
current_text = self._buffer
if not (self.bot_token in current_text or current_text.startswith("{")):

# current_text is treated as (part of) a tool call when it contains the bot token or starts with "{" (the start of a new tool call sequence),
# or when a previous tool call exists and it starts with the tool call separator followed by "{" (the start of the next tool call)
if not (
self.bot_token in current_text
or current_text.startswith("{")
or (
self.current_tool_id > 0
and current_text.startswith(self.tool_call_separator + "{")
)
):
# Only clear buffer if we're sure no tool call is starting
if not self._ends_with_partial_token(self._buffer, self.bot_token):
normal_text = self._buffer
Expand All @@ -127,91 +138,73 @@ def parse_streaming_increment(
}

flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR

try:
tool_call_arr = []
is_complete = []
try:
start_idx = (
len(self.bot_token)
if current_text.startswith(self.bot_token)
else 0
if current_text.startswith(self.bot_token):
start_idx = len(self.bot_token)
elif self.current_tool_id > 0 and current_text.startswith(
self.tool_call_separator
):
start_idx = len(self.tool_call_separator)
else:
start_idx = 0

if start_idx >= len(current_text):
return StreamingParseResult()

(obj, end_idx) = _partial_json_loads(current_text[start_idx:], flags)

is_current_complete = _is_complete_json(
current_text[start_idx : start_idx + end_idx]
)
while start_idx < len(current_text):
(obj, end_idx) = _partial_json_loads(
current_text[start_idx:], flags
)
is_complete.append(
_is_complete_json(current_text[start_idx : start_idx + end_idx])
)
start_idx += end_idx + len("; ")

# Validate tool name if present
if "name" in obj and obj["name"] not in self._tool_indices:
# Invalid tool name - reset state
self._buffer = ""
self.current_tool_id = -1
self.current_tool_name_sent = False
if self.streamed_args_for_tool:
self.streamed_args_for_tool.pop()
return StreamingParseResult()

# Handle parameters/arguments consistency
if "parameters" in obj:
assert (
"arguments" not in obj
), "model generated both parameters and arguments"
obj["arguments"] = obj["parameters"]
tool_call_arr.append(obj)
# Validate tool name if present
if "name" in obj and obj["name"] not in self._tool_indices:
# Invalid tool name - reset state
self._buffer = ""
self.current_tool_id = -1
self.current_tool_name_sent = False
if self.streamed_args_for_tool:
self.streamed_args_for_tool.pop()
return StreamingParseResult()

# Handle parameters/arguments consistency
# NOTE: we assume here that obj is always the (possibly partial) JSON of a single tool call
if "parameters" in obj:
assert (
"arguments" not in obj
), "model generated both parameters and arguments"
obj["arguments"] = obj["parameters"]

current_tool_call = obj

except MalformedJSON:
return StreamingParseResult()

if len(tool_call_arr) == 0:
if not current_tool_call:
return StreamingParseResult()

current_tool_call: Dict = (
tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
)

# Handle new tool in array
if len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1:
if self.current_tool_id >= 0:
cur_arguments = current_tool_call.get("arguments")
if cur_arguments:
cur_args_json = json.dumps(cur_arguments)
sent = len(self.streamed_args_for_tool[self.current_tool_id])
argument_diff = cur_args_json[sent:]

res = StreamingParseResult(
calls=[
ToolCallItem(
tool_index=self.current_tool_id,
name="",
parameters=argument_diff,
)
],
)
self.streamed_args_for_tool[
self.current_tool_id
] += argument_diff
else:
res = StreamingParseResult()
else:
res = StreamingParseResult()

self.current_tool_id = len(tool_call_arr) - 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
return res

# Handle tool name
elif not self.current_tool_name_sent:
# Case 1: Handle tool name streaming
# This happens when we encounter a tool but haven't sent its name yet
if not self.current_tool_name_sent:
function_name = current_tool_call.get("name")

if function_name and function_name in self._tool_indices:
# If this is a new tool (current_tool_id was -1), initialize it
if self.current_tool_id == -1:
self.current_tool_id = 0
self.streamed_args_for_tool.append("")
# If this is a subsequent tool, ensure streamed_args_for_tool is large enough
elif self.current_tool_id >= len(self.streamed_args_for_tool):
while len(self.streamed_args_for_tool) <= self.current_tool_id:
self.streamed_args_for_tool.append("")

# Send the tool name with empty parameters
res = StreamingParseResult(
calls=[
ToolCallItem(
tool_index=self._tool_indices[function_name],
tool_index=self.current_tool_id,
name=function_name,
parameters="",
)
Expand All @@ -221,47 +214,75 @@ def parse_streaming_increment(
else:
res = StreamingParseResult()

# Handle streaming arguments
# Case 2: Handle streaming arguments
# This happens when we've already sent the tool name and now need to stream arguments incrementally
else:
cur_arguments = current_tool_call.get("arguments")
res = StreamingParseResult()

if cur_arguments:
# Calculate how much of the arguments we've already streamed
sent = len(self.streamed_args_for_tool[self.current_tool_id])
cur_args_json = json.dumps(cur_arguments)
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
"arguments"
)
prev_arguments = None
if self.current_tool_id < len(self.prev_tool_call_arr):
prev_arguments = self.prev_tool_call_arr[
self.current_tool_id
].get("arguments")

argument_diff = None
if is_complete[self.current_tool_id]:

# If the current tool's JSON is complete, send all remaining arguments
if is_current_complete:
argument_diff = cur_args_json[sent:]
self._buffer = ""
self.prev_tool_call_arr[self.current_tool_id].clear()
completing_tool_id = (
self.current_tool_id
) # Save the ID of the tool that's completing

# Only remove the processed portion, keep unprocessed content
self._buffer = current_text[start_idx + end_idx :]

if self.current_tool_id < len(self.prev_tool_call_arr):
self.prev_tool_call_arr[self.current_tool_id].clear()
self.current_tool_name_sent = False
self.streamed_args_for_tool[self.current_tool_id] = ""
self.current_tool_id += 1

# If the tool is still being parsed, send incremental changes
elif prev_arguments:
prev_args_json = json.dumps(prev_arguments)
if cur_args_json != prev_args_json:
prefix = _find_common_prefix(prev_args_json, cur_args_json)
argument_diff = prefix[sent:]

# Send the argument diff if there's something new
if argument_diff is not None:
# Use the correct tool_index: completing_tool_id for completed tools, current_tool_id for ongoing
tool_index_to_use = (
completing_tool_id
if is_current_complete
else self.current_tool_id
)
res = StreamingParseResult(
calls=[
ToolCallItem(
tool_index=self.current_tool_id,
tool_index=tool_index_to_use,
parameters=argument_diff,
)
],
)
if not is_complete[self.current_tool_id]:
if not is_current_complete:
self.streamed_args_for_tool[
self.current_tool_id
] += argument_diff

self.prev_tool_call_arr = tool_call_arr
# Update prev_tool_call_arr with current state
if self.current_tool_id >= 0:
# Ensure prev_tool_call_arr is large enough
while len(self.prev_tool_call_arr) <= self.current_tool_id:
self.prev_tool_call_arr.append({})
self.prev_tool_call_arr[self.current_tool_id] = current_tool_call

return res

except Exception as e:
Expand Down
13 changes: 11 additions & 2 deletions python/sglang/srt/function_call/llama32_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ class Llama32Detector(BaseFormatDetector):
def __init__(self):
super().__init__()
self.bot_token = "<|python_tag|>"
# NOTE: technically, Llama 3.2 does not support parallel tool calls well;
# it needs specific prompt engineering to produce them.
# Here we use ';' as the separator, which might cause compatibility issues
# if users specify a different separator in their prompt
self.tool_call_separator = ";"

def has_tool_call(self, text: str) -> bool:
"""Check if the text contains a Llama 3.2 format tool call."""
Expand All @@ -42,7 +47,11 @@ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult
normal_text, action_text = "", text

# Split by the tool call separator (';' for Llama 3.2) and process each part
json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
json_parts = [
part.strip()
for part in action_text.split(self.tool_call_separator)
if part.strip()
]
all_actions = []
for part in json_parts:
try:
Expand Down Expand Up @@ -70,5 +79,5 @@ def build_ebnf(self, tools: List[Tool]):
return EBNFComposer.build_ebnf(
tools,
function_format="json",
tool_call_separator=",",
tool_call_separator=self.tool_call_separator,
)
3 changes: 2 additions & 1 deletion python/sglang/srt/function_call/mistral_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(self):
self.bot_token = "[TOOL_CALLS] ["
self.eot_token = "]"
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
self.tool_call_separator = ", "

def has_tool_call(self, text: str) -> bool:
"""Check if the text contains a Mistral format tool call."""
Expand Down Expand Up @@ -126,5 +127,5 @@ def build_ebnf(self, tools: List[Tool]):
sequence_start_token=self.bot_token,
sequence_end_token=self.eot_token,
function_format="json",
tool_call_separator=", ",
tool_call_separator=self.tool_call_separator,
)
2 changes: 1 addition & 1 deletion python/sglang/srt/function_call/qwen25_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(self):
super().__init__()
self.bot_token = "<tool_call>\n"
self.eot_token = "\n</tool_call>"
self.tool_call_separator = "\n"
self._normal_text_buffer = "" # Buffer for handling partial end tokens

def has_tool_call(self, text: str) -> bool:
Expand Down Expand Up @@ -104,7 +105,6 @@ def parse_streaming_increment(
return result

def structure_info(self) -> _GetInfoFunc:
# TODO: Update the begin and end tokens with '\n' if necessary
return lambda name: StructureInfo(
begin='<tool_call>\n{"name":"' + name + '", "arguments":',
end="}\n</tool_call>",
Expand Down
17 changes: 17 additions & 0 deletions python/sglang/srt/function_call/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,23 @@ def _find_common_prefix(s1: str, s2: str) -> str:


def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
"""
Parse incomplete or partial JSON strings commonly encountered during streaming.

Args:
input_str (str): The potentially incomplete JSON string to parse.
flags (Allow): Bitwise flags controlling what types of partial data are allowed.
Common flags include:
- Allow.STR: Allow partial strings (e.g., '"hello wo' -> 'hello wo')
- Allow.OBJ: Allow partial objects (e.g., '{"key":' -> {'key': None})
- Allow.ARR: Allow partial arrays (e.g., '[1, 2,' -> [1, 2])
- Allow.ALL: Allow all types of partial data

Returns:
Tuple[Any, int]: A tuple containing:
- parsed_object: The Python object parsed from the JSON
- consumed_length: Number of characters consumed from input_str
"""
try:
return (partial_json_parser.loads(input_str, flags), len(input_str))
except JSONDecodeError as e:
Expand Down
1 change: 0 additions & 1 deletion python/sglang/srt/openai_api/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1327,7 +1327,6 @@ def v1_chat_generate_response(
tool_calls = [
ToolCall(
id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}",
index=call_info.tool_index,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Good catch removing the index field here for non-streaming tool calls! This aligns the response with the OpenAI API specification, which does not include an index field for tool_calls items in the chat completion object. This enhances API compliance.

Could you confirm if this index was indeed unused or potentially misleading for consumers expecting strict OpenAI compatibility for non-streaming responses?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is unused. See: https://platform.openai.com/docs/api-reference/chat/object
there is no index here.
Screenshot 2025-05-28 at 11 24 09 AM

function=FunctionResponse(
name=call_info.name, arguments=call_info.parameters
),
Expand Down
Loading
Loading