Skip to content

Commit 4a90d82

Browse files
committed
Fix GLM4 MoE function call parser issue where current_tool_id was not incremented,
and add unit tests
1 parent f6aaaa5 commit 4a90d82

File tree

2 files changed

+145
-5
lines changed

2 files changed

+145
-5
lines changed

python/sglang/srt/function_call/glm4_moe_detector.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,6 @@ def __init__(self):
5252
self.func_call_regex = r'<tool_call>.*?</tool_call>'
5353
self.func_detail_regex = r'<tool_call>([^\n]*)\n(.*)</tool_call>'
5454
self.func_arg_regex = r'<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>'
55-
self._last_arguments = ""
56-
self.current_tool_id = -1
5755

5856
def has_tool_call(self, text: str) -> bool:
5957
"""Check if the text contains a glm-4.5 format tool call."""
@@ -112,7 +110,8 @@ def parse_streaming_increment(
112110
if self.current_tool_id > 0:
113111
current_text = ""
114112
return StreamingParseResult(normal_text=current_text)
115-
end = current_text.rfind(self.eot_token)
113+
# Use find (not rfind) to locate the first self.eot_token, so current_text[:end + len(self.eot_token)] contains at most one tool call
114+
end = current_text.find(self.eot_token)
116115
if end != -1:
117116
# Initialize state if this is the first tool call
118117
if self.current_tool_id == -1:
@@ -131,8 +130,9 @@ def parse_streaming_increment(
131130
"arguments": result.calls[0].parameters
132131
}
133132
self.streamed_args_for_tool[self.current_tool_id] = result.calls[0].parameters
133+
result.calls[0].tool_index = self.current_tool_id
134+
self.current_tool_id += 1
134135
self._buffer = current_text[end + len(self.eot_token):]
135-
self.current_tool_id += 1
136136
return result
137137
normal_text = current_text[:start]
138138
self._buffer = current_text[start:]

test/srt/test_function_call_parser.py

Lines changed: 141 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from sglang.srt.entrypoints.openai.protocol import Function, Tool
77
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
88
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
9-
from sglang.srt.function_call.glm45_detector import Glm4MoeDetector
9+
from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector
1010
from sglang.srt.function_call.kimik2_detector import KimiK2Detector
1111
from sglang.srt.function_call.llama32_detector import Llama32Detector
1212
from sglang.srt.function_call.mistral_detector import MistralDetector
@@ -1485,5 +1485,145 @@ def test_parse_streaming_multiple_tool_calls_with_multi_token_chunk(self):
14851485
self.assertEqual(params2["city"], "Beijing")
14861486

14871487

1488+
class TestGlm4MoeDetector(unittest.TestCase):
    """Tests for ``Glm4MoeDetector`` one-shot and streaming tool-call parsing.

    The tool-call format exercised here is::

        <tool_call>NAME
        <arg_key>KEY</arg_key>
        <arg_value>VALUE</arg_value>
        </tool_call>
    """

    # Serialized arguments the detector is expected to emit for each sample call.
    BEIJING_ARGS = '{"city": "Beijing", "date": "2024-06-27"}'
    SHANGHAI_ARGS = '{"city": "Shanghai", "date": "2024-06-28"}'

    def setUp(self):
        """Create a fresh detector and a single ``get_weather`` tool per test."""
        self.tools = [
            Tool(
                type="function",
                function=Function(
                    name="get_weather",
                    description="Get weather information",
                    parameters={
                        "type": "object",
                        "properties": {
                            "city": {"type": "string", "description": "City name"},
                            "date": {"type": "string", "description": "Date"},
                        },
                        "required": ["city", "date"],
                    },
                ),
            ),
        ]
        # Glm4MoeDetector comes from the module-level import; no local
        # re-import is needed here.
        self.detector = Glm4MoeDetector()

    def _stream(self, chunks):
        """Feed ``chunks`` to the detector one by one and merge emitted calls.

        Args:
            chunks: Text fragments passed, in order, to
                ``parse_streaming_increment``.

        Returns:
            A list indexed by ``tool_index``; each entry is a dict with the
            last non-empty ``name`` and ``parameters`` seen for that index.
            Call items without a ``tool_index`` are ignored.
        """
        tool_calls = []
        for chunk in chunks:
            result = self.detector.parse_streaming_increment(chunk, self.tools)
            for call in result.calls:
                index = getattr(call, "tool_index", None)
                if index is None:
                    continue
                # Grow the list so that ``index`` is addressable.
                while len(tool_calls) <= index:
                    tool_calls.append({"name": "", "parameters": {}})
                entry = tool_calls[index]
                if call.name:
                    entry["name"] = call.name
                if call.parameters:
                    entry["parameters"] = call.parameters
        return tool_calls

    def test_single_tool_call(self):
        """A single complete tool call is parsed in one shot."""
        text = (
            "<tool_call>get_weather\n"
            "<arg_key>city</arg_key>\n<arg_value>Beijing</arg_value>\n"
            "<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n"
            "</tool_call>"
        )
        result = self.detector.detect_and_parse(text, self.tools)
        self.assertEqual(len(result.calls), 1)
        self.assertEqual(result.calls[0].name, "get_weather")
        self.assertEqual(result.calls[0].parameters, self.BEIJING_ARGS)
        self.assertEqual(result.normal_text, "")

    def test_multiple_tool_calls(self):
        """Two back-to-back tool calls are both parsed in one shot."""
        text = (
            "<tool_call>get_weather\n"
            "<arg_key>city</arg_key>\n<arg_value>Beijing</arg_value>\n"
            "<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n"
            "</tool_call>"
            "<tool_call>get_weather\n"
            "<arg_key>city</arg_key>\n<arg_value>Shanghai</arg_value>\n"
            "<arg_key>date</arg_key>\n<arg_value>2024-06-28</arg_value>\n"
            "</tool_call>"
        )
        result = self.detector.detect_and_parse(text, self.tools)
        self.assertEqual(len(result.calls), 2)
        self.assertEqual(result.calls[0].name, "get_weather")
        self.assertEqual(result.calls[0].parameters, self.BEIJING_ARGS)
        self.assertEqual(result.calls[1].name, "get_weather")
        self.assertEqual(result.calls[1].parameters, self.SHANGHAI_ARGS)
        self.assertEqual(result.normal_text, "")

    def test_streaming_tool_call(self):
        """Test streaming incremental parsing of a tool call."""
        chunks = [
            "<tool_call>get_weather\n",
            "<arg_key>city</arg_key>\n<arg_value>Beijing</arg_value>\n",
            "<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n",
            "</tool_call>",
        ]
        tool_calls = self._stream(chunks)
        self.assertEqual(len(tool_calls), 1)
        self.assertEqual(tool_calls[0]["name"], "get_weather")
        self.assertEqual(tool_calls[0]["parameters"], self.BEIJING_ARGS)

    def test_streaming_multiple_tool_calls(self):
        """Test streaming incremental parsing of multiple tool calls."""
        chunks = [
            "<tool_call>get_weather\n",
            "<arg_key>city</arg_key>\n<arg_value>Beijing</arg_value>\n",
            "<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n",
            # The end of the first call and the start of the second arrive
            # in the same chunk.
            "</tool_call><tool_call>get_weather\n",
            "<arg_key>city</arg_key>\n<arg_value>Shanghai</arg_value>\n",
            "<arg_key>date</arg_key>\n<arg_value>2024-06-28</arg_value>\n",
            "</tool_call>",
        ]
        tool_calls = self._stream(chunks)
        self.assertEqual(len(tool_calls), 2)
        self.assertEqual(tool_calls[0]["name"], "get_weather")
        self.assertEqual(tool_calls[0]["parameters"], self.BEIJING_ARGS)
        self.assertEqual(tool_calls[1]["name"], "get_weather")
        self.assertEqual(tool_calls[1]["parameters"], self.SHANGHAI_ARGS)

    def test_tool_call_completion(self):
        """Test that the buffer and state are reset after a tool call is completed."""
        chunks = [
            "<tool_call>get_weather\n",
            "<arg_key>city</arg_key>\n<arg_value>Beijing</arg_value>\n",
            "<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n",
            "</tool_call>",
        ]
        self._stream(chunks)
        # One call completed, so the detector has advanced to the next tool id
        # and nothing remains buffered after the closing tag.
        self.assertEqual(self.detector.current_tool_id, 1)
        self.assertEqual(self.detector._buffer, "")

    def test_invalid_tool_call(self):
        """Test that invalid tool calls are handled correctly."""
        # "invalid_func" is not among the tools registered in setUp.
        text = (
            "<tool_call>invalid_func\n"
            "<arg_key>city</arg_key>\n<arg_value>Beijing</arg_value>\n"
            "</tool_call>"
        )
        result = self.detector.detect_and_parse(text, self.tools)
        self.assertEqual(len(result.calls), 0)

    def test_partial_tool_call(self):
        """Test parsing a partial tool call that spans multiple chunks."""
        text1 = "<tool_call>get_weather\n<arg_key>city</arg_key>\n"
        result1 = self.detector.parse_streaming_increment(text1, self.tools)
        # Nothing is emitted yet; the incomplete call stays in the buffer.
        self.assertEqual(result1.normal_text, "")
        self.assertEqual(result1.calls, [])
        self.assertEqual(self.detector._buffer, text1)

        text2 = (
            "<arg_value>Beijing</arg_value>\n"
            "<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n"
            "</tool_call>"
        )
        result2 = self.detector.parse_streaming_increment(text2, self.tools)
        self.assertEqual(len(result2.calls), 1)
        self.assertEqual(result2.calls[0].name, "get_weather")
        self.assertEqual(result2.calls[0].parameters, self.BEIJING_ARGS)
        self.assertEqual(self.detector._buffer, "")
1626+
1627+
14881628
if __name__ == "__main__":
14891629
unittest.main()

0 commit comments

Comments
 (0)