|
6 | 6 | from sglang.srt.entrypoints.openai.protocol import Function, Tool |
7 | 7 | from sglang.srt.function_call.base_format_detector import BaseFormatDetector |
8 | 8 | from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector |
9 | | -from sglang.srt.function_call.glm45_detector import Glm4MoeDetector |
| 9 | +from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector |
10 | 10 | from sglang.srt.function_call.kimik2_detector import KimiK2Detector |
11 | 11 | from sglang.srt.function_call.llama32_detector import Llama32Detector |
12 | 12 | from sglang.srt.function_call.mistral_detector import MistralDetector |
@@ -1485,5 +1485,145 @@ def test_parse_streaming_multiple_tool_calls_with_multi_token_chunk(self): |
1485 | 1485 | self.assertEqual(params2["city"], "Beijing") |
1486 | 1486 |
|
1487 | 1487 |
|
| 1488 | +class TestGlm4MoeDetector(unittest.TestCase): |
| 1489 | + def setUp(self): |
| 1490 | + self.tools = [ |
| 1491 | + Tool( |
| 1492 | + type="function", |
| 1493 | + function=Function( |
| 1494 | + name="get_weather", |
| 1495 | + description="Get weather information", |
| 1496 | + parameters={ |
| 1497 | + "type": "object", |
| 1498 | + "properties": { |
| 1499 | + "city": {"type": "string", "description": "City name"}, |
| 1500 | + "date": {"type": "string", "description": "Date"}, |
| 1501 | + }, |
| 1502 | + "required": ["city", "date"], |
| 1503 | + }, |
| 1504 | + ), |
| 1505 | + ), |
| 1506 | + ] |
| 1507 | + from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector |
| 1508 | + self.detector = Glm4MoeDetector() |
| 1509 | + |
| 1510 | + def test_single_tool_call(self): |
| 1511 | + text = ( |
| 1512 | + "<tool_call>get_weather\n" |
| 1513 | + "<arg_key>city</arg_key>\n<arg_value>Beijing</arg_value>\n" |
| 1514 | + "<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n" |
| 1515 | + "</tool_call>" |
| 1516 | + ) |
| 1517 | + result = self.detector.detect_and_parse(text, self.tools) |
| 1518 | + self.assertEqual(len(result.calls), 1) |
| 1519 | + self.assertEqual(result.calls[0].name, "get_weather") |
| 1520 | + self.assertEqual(result.calls[0].parameters, '{"city": "Beijing", "date": "2024-06-27"}') |
| 1521 | + self.assertEqual(result.normal_text, "") |
| 1522 | + |
| 1523 | + def test_multiple_tool_calls(self): |
| 1524 | + text = ( |
| 1525 | + "<tool_call>get_weather\n" |
| 1526 | + "<arg_key>city</arg_key>\n<arg_value>Beijing</arg_value>\n" |
| 1527 | + "<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n" |
| 1528 | + "</tool_call>" |
| 1529 | + "<tool_call>get_weather\n" |
| 1530 | + "<arg_key>city</arg_key>\n<arg_value>Shanghai</arg_value>\n" |
| 1531 | + "<arg_key>date</arg_key>\n<arg_value>2024-06-28</arg_value>\n" |
| 1532 | + "</tool_call>" |
| 1533 | + ) |
| 1534 | + result = self.detector.detect_and_parse(text, self.tools) |
| 1535 | + self.assertEqual(len(result.calls), 2) |
| 1536 | + self.assertEqual(result.calls[0].name, "get_weather") |
| 1537 | + self.assertEqual(result.calls[0].parameters, '{"city": "Beijing", "date": "2024-06-27"}') |
| 1538 | + self.assertEqual(result.calls[1].name, "get_weather") |
| 1539 | + self.assertEqual(result.calls[1].parameters, '{"city": "Shanghai", "date": "2024-06-28"}') |
| 1540 | + self.assertEqual(result.normal_text, "") |
| 1541 | + |
| 1542 | + def test_streaming_tool_call(self): |
| 1543 | + """Test streaming incremental parsing of a tool call.""" |
| 1544 | + chunks = [ |
| 1545 | + "<tool_call>get_weather\n", |
| 1546 | + "<arg_key>city</arg_key>\n<arg_value>Beijing</arg_value>\n", |
| 1547 | + "<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n", |
| 1548 | + "</tool_call>", |
| 1549 | + ] |
| 1550 | + tool_calls = [] |
| 1551 | + for chunk in chunks: |
| 1552 | + result = self.detector.parse_streaming_increment(chunk, self.tools) |
| 1553 | + for tool_call_chunk in result.calls: |
| 1554 | + if hasattr(tool_call_chunk, "tool_index") and tool_call_chunk.tool_index is not None: |
| 1555 | + while len(tool_calls) <= tool_call_chunk.tool_index: |
| 1556 | + tool_calls.append({"name": "", "parameters": {}}) |
| 1557 | + tc = tool_calls[tool_call_chunk.tool_index] |
| 1558 | + if tool_call_chunk.name: |
| 1559 | + tc["name"] = tool_call_chunk.name |
| 1560 | + if tool_call_chunk.parameters: |
| 1561 | + tc["parameters"] = tool_call_chunk.parameters |
| 1562 | + self.assertEqual(len(tool_calls), 1) |
| 1563 | + self.assertEqual(tool_calls[0]["name"], "get_weather") |
| 1564 | + self.assertEqual(tool_calls[0]["parameters"], '{"city": "Beijing", "date": "2024-06-27"}') |
| 1565 | + |
| 1566 | + def test_streaming_multiple_tool_calls(self): |
| 1567 | + """Test streaming incremental parsing of multiple tool calls.""" |
| 1568 | + chunks = [ |
| 1569 | + "<tool_call>get_weather\n", |
| 1570 | + "<arg_key>city</arg_key>\n<arg_value>Beijing</arg_value>\n", |
| 1571 | + "<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n", |
| 1572 | + "</tool_call><tool_call>get_weather\n", |
| 1573 | + "<arg_key>city</arg_key>\n<arg_value>Shanghai</arg_value>\n", |
| 1574 | + "<arg_key>date</arg_key>\n<arg_value>2024-06-28</arg_value>\n", |
| 1575 | + "</tool_call>", |
| 1576 | + ] |
| 1577 | + tool_calls = [] |
| 1578 | + for chunk in chunks: |
| 1579 | + result = self.detector.parse_streaming_increment(chunk, self.tools) |
| 1580 | + for tool_call_chunk in result.calls: |
| 1581 | + if hasattr(tool_call_chunk, "tool_index") and tool_call_chunk.tool_index is not None: |
| 1582 | + while len(tool_calls) <= tool_call_chunk.tool_index: |
| 1583 | + tool_calls.append({"name": "", "parameters": {}}) |
| 1584 | + tc = tool_calls[tool_call_chunk.tool_index] |
| 1585 | + if tool_call_chunk.name: |
| 1586 | + tc["name"] = tool_call_chunk.name |
| 1587 | + if tool_call_chunk.parameters: |
| 1588 | + tc["parameters"] = tool_call_chunk.parameters |
| 1589 | + self.assertEqual(len(tool_calls), 2) |
| 1590 | + self.assertEqual(tool_calls[0]["name"], "get_weather") |
| 1591 | + self.assertEqual(tool_calls[0]["parameters"], '{"city": "Beijing", "date": "2024-06-27"}') |
| 1592 | + self.assertEqual(tool_calls[1]["name"], "get_weather") |
| 1593 | + self.assertEqual(tool_calls[1]["parameters"], '{"city": "Shanghai", "date": "2024-06-28"}') |
| 1594 | + |
| 1595 | + def test_tool_call_completion(self): |
| 1596 | + """Test that the buffer and state are reset after a tool call is completed.""" |
| 1597 | + chunks = [ |
| 1598 | + "<tool_call>get_weather\n", |
| 1599 | + "<arg_key>city</arg_key>\n<arg_value>Beijing</arg_value>\n", |
| 1600 | + "<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n", |
| 1601 | + "</tool_call>", |
| 1602 | + ] |
| 1603 | + for chunk in chunks: |
| 1604 | + result = self.detector.parse_streaming_increment(chunk, self.tools) |
| 1605 | + self.assertEqual(self.detector.current_tool_id, 1) |
| 1606 | + |
| 1607 | + def test_invalid_tool_call(self): |
| 1608 | + """Test that invalid tool calls are handled correctly.""" |
| 1609 | + text = '<tool_call>invalid_func\n<arg_key>city</arg_key>\n<arg_value>Beijing</arg_value>\n</tool_call>' |
| 1610 | + result = self.detector.detect_and_parse(text, self.tools) |
| 1611 | + self.assertEqual(len(result.calls), 0) |
| 1612 | + |
| 1613 | + def test_partial_tool_call(self): |
| 1614 | + """Test parsing a partial tool call that spans multiple chunks.""" |
| 1615 | + text1 = "<tool_call>get_weather\n<arg_key>city</arg_key>\n" |
| 1616 | + result1 = self.detector.parse_streaming_increment(text1, self.tools) |
| 1617 | + self.assertEqual(result1.normal_text, "") |
| 1618 | + self.assertEqual(result1.calls, []) |
| 1619 | + self.assertEqual(self.detector._buffer, text1) |
| 1620 | + text2 = "<arg_value>Beijing</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n</tool_call>" |
| 1621 | + result2 = self.detector.parse_streaming_increment(text2, self.tools) |
| 1622 | + self.assertEqual(len(result2.calls), 1) |
| 1623 | + self.assertEqual(result2.calls[0].name, "get_weather") |
| 1624 | + self.assertEqual(result2.calls[0].parameters, '{"city": "Beijing", "date": "2024-06-27"}') |
| 1625 | + self.assertEqual(self.detector._buffer, "") |
| 1626 | + |
| 1627 | + |
1488 | 1628 | if __name__ == "__main__": |
1489 | 1629 | unittest.main() |
0 commit comments