diff --git a/src/server/main.py b/src/server/main.py
index 6f6558a..da0ff35 100644
--- a/src/server/main.py
+++ b/src/server/main.py
@@ -154,31 +154,83 @@ async def global_exception_handler(request: Request, exc: Exception):
 # Tool calling helpers
 #===============================================================#
 
+def _extract_hermes_tool_call_payloads(text: str) -> List[str]:
+    open_tag = "<tool_call>"
+    close_tag = "</tool_call>"
+    payloads: List[str] = []
+    cursor = 0
+
+    while True:
+        start = text.find(open_tag, cursor)
+        if start < 0:
+            break
+
+        payload_start = start + len(open_tag)
+        end = text.find(close_tag, payload_start)
+        if end < 0:
+            # Hermes parsers accept an open tool call until EOS.
+            payload = text[payload_start:].strip()
+            if payload:
+                payloads.append(payload)
+            break
+
+        payload = text[payload_start:end].strip()
+        if payload:
+            payloads.append(payload)
+
+        cursor = end + len(close_tag)
+
+    return payloads
+
+
+def _format_tool_call_arguments(arguments: Any) -> str:
+    if isinstance(arguments, str):
+        try:
+            return json.dumps(json.loads(arguments))
+        except json.JSONDecodeError:
+            return arguments
+    return json.dumps(arguments)
+
+
 def parse_tool_calls(text: str) -> Optional[List[Dict[str, Any]]]:
-    # Find all potential JSON objects
-    pattern = r'\{(?:[^{}]|(?:\{[^{}]*\}))*\}'
-    matches = re.findall(pattern, text, re.DOTALL)
-
-    if not matches:
-        return None
-
-    tool_calls = []
-    for match in matches:
+    tool_calls: List[Dict[str, Any]] = []
+
+    # Hermes format: <tool_call>{...}</tool_call>, with EOS fallback for missing close tag.
+    for payload in _extract_hermes_tool_call_payloads(text):
+        try:
+            data = json.loads(payload)
+            if isinstance(data, dict) and "name" in data and "arguments" in data:
+                tool_calls.append({
+                    "id": f"call_{uuid.uuid4().hex[:24]}",
+                    "type": "function",
+                    "function": {
+                        "name": str(data.get("name", "")),
+                        "arguments": _format_tool_call_arguments(data.get("arguments", {})),
+                    },
+                })
+        except json.JSONDecodeError:
+            continue
+
+    if tool_calls:
+        return tool_calls
+
+    # Backward compatibility for plain JSON tool call outputs without tags.
+    pattern = r"\{(?:[^{}]|(?:\{[^{}]*\}))*\}"
+    for match in re.findall(pattern, text, re.DOTALL):
         try:
             data = json.loads(match)
-            # Check if it has the expected structure
-            if "name" in data and "arguments" in data:
+            if isinstance(data, dict) and "name" in data and "arguments" in data:
                 tool_calls.append({
                     "id": f"call_{uuid.uuid4().hex[:24]}",
                     "type": "function",
                     "function": {
-                        "name": data.get("name", ""),
-                        "arguments": json.dumps(data.get("arguments", {}))
-                    }
+                        "name": str(data.get("name", "")),
+                        "arguments": _format_tool_call_arguments(data.get("arguments", {})),
+                    },
                 })
         except json.JSONDecodeError:
-            pass
-
+            continue
+
     return tool_calls if tool_calls else None
 
 #===============================================================#
@@ -300,6 +352,7 @@ async def event_stream() -> AsyncIterator[bytes]:
             accumulated_text = ""
             metrics_data = None
             tool_call_sent = False
+            tool_call_started = False
             cancel_request_id = None
 
             try:
@@ -320,6 +373,13 @@ async def event_stream() -> AsyncIterator[bytes]:
                     continue
 
                 accumulated_text += item
+                if not tool_call_started:
+                    tool_call_started = (
+                        "<tool_call>" in accumulated_text
+                        or "<tool_call" in accumulated_text
+                    )
+
@@ -352,7 +412,7 @@ async def event_stream() -> AsyncIterator[bytes]:
                         }]
                     }
                     yield (f"data: {json.dumps(tool_call_args)}\n\n").encode()
-                elif not tool_calls:
+                elif not tool_calls and not tool_call_started:
                     # Regular content streaming
                     chunk_payload = {
                         "id": request_id,
diff --git a/src/tests/test_tool_call_parser_unit.py b/src/tests/test_tool_call_parser_unit.py
new file
mode 100644
index 0000000..5d5daf0
--- /dev/null
+++ b/src/tests/test_tool_call_parser_unit.py
@@ -0,0 +1,135 @@
+import json
+from typing import Any, AsyncIterator, Dict, List
+
+import pytest  # type: ignore[import]
+from fastapi.responses import StreamingResponse
+
+import src.server.main as server_main
+from src.server.models.requests_openai import OpenAIChatCompletionRequest
+
+
+class _DummyRequest:
+    async def is_disconnected(self) -> bool:
+        return False
+
+
+def _extract_sse_payloads(chunks: List[bytes]) -> List[str]:
+    payloads: List[str] = []
+    for chunk in chunks:
+        for line in chunk.decode().splitlines():
+            if line.startswith("data: "):
+                payloads.append(line[6:])
+    return payloads
+
+
+def test_parse_tool_calls_supports_hermes_tool_call_tags() -> None:
+    text = (
+        "<tool_call>"
+        '{"name":"search","arguments":{"query":"OpenVINO"}}'
+        "</tool_call>"
+    )
+
+    tool_calls = server_main.parse_tool_calls(text)
+
+    assert tool_calls is not None
+    assert len(tool_calls) == 1
+    assert tool_calls[0]["type"] == "function"
+    assert tool_calls[0]["function"]["name"] == "search"
+    assert json.loads(tool_calls[0]["function"]["arguments"]) == {"query": "OpenVINO"}
+
+
+def test_parse_tool_calls_supports_missing_closing_tag_until_eos() -> None:
+    text = '<tool_call>{"name":"search","arguments":{"query":"vLLM"}}'
+
+    tool_calls = server_main.parse_tool_calls(text)
+
+    assert tool_calls is not None
+    assert len(tool_calls) == 1
+    assert tool_calls[0]["function"]["name"] == "search"
+    assert json.loads(tool_calls[0]["function"]["arguments"]) == {"query": "vLLM"}
+
+
+@pytest.mark.asyncio
+async def test_openai_chat_completions_non_streaming_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
+    class _Workers:
+        async def generate(self, model_name: str, generation_config: Any) -> Dict[str, Any]:
+            return {
+                "text": (
+                    "<tool_call>"
+                    '{"name":"search","arguments":{"query":"OpenArc"}}'
+                    "</tool_call>"
+                ),
+                "metrics": {"input_token": 4, "new_token": 6, "total_token": 10},
+            }
+
+    monkeypatch.setattr(server_main, "_workers", _Workers())
+
+    request = OpenAIChatCompletionRequest(
+        model="demo-model",
+        messages=[{"role": "user", "content": "Find OpenArc docs"}],
+        stream=False,
+    )
+
+    response = await server_main.openai_chat_completions(request, _DummyRequest())
+
+    choice = response["choices"][0]
+    assert choice["finish_reason"] == "tool_calls"
+    assert choice["message"]["content"] is None
+    assert len(choice["message"]["tool_calls"]) == 1
+    assert choice["message"]["tool_calls"][0]["function"]["name"] == "search"
+    assert json.loads(choice["message"]["tool_calls"][0]["function"]["arguments"]) == {
+        "query": "OpenArc"
+    }
+
+
+@pytest.mark.asyncio
+async def test_openai_chat_completions_streaming_hermes_tool_call(monkeypatch: pytest.MonkeyPatch) -> None:
+    class _Workers:
+        async def stream_generate(self, model_name: str, generation_config: Any) -> AsyncIterator[Any]:
+            yield "<tool_call>"
+            yield '{"name":"search","arguments":{"query":"OpenArc"}}'
+            yield "</tool_call>"
+            yield {"metrics": {"input_token": 2, "new_token": 3, "total_token": 5}}
+
+        async def infer_cancel(self, request_id: str) -> None:
+            return None
+
+    monkeypatch.setattr(server_main, "_workers", _Workers())
+
+    request = OpenAIChatCompletionRequest(
+        model="demo-model",
+        messages=[{"role": "user", "content": "Find OpenArc docs"}],
+        stream=True,
+    )
+
+    response = await server_main.openai_chat_completions(request, _DummyRequest())
+    assert isinstance(response, StreamingResponse)
+
+    chunks: List[bytes] = []
+    async for chunk in response.body_iterator:
+        chunks.append(chunk)
+
+    payloads = _extract_sse_payloads(chunks)
+    assert payloads[-1] == "[DONE]"
+
+    json_payloads = [json.loads(p) for p in payloads if p != "[DONE]"]
+
+    content_deltas = [
+        payload
+        for payload in json_payloads
+        if payload["choices"][0]["delta"].get("content")
+    ]
+    assert content_deltas == []
+
+    tool_deltas = [
+        payload
+        for payload in json_payloads
+        if payload["choices"][0]["delta"].get("tool_calls")
+    ]
+    assert len(tool_deltas) >= 2
+    assert tool_deltas[0]["choices"][0]["delta"]["tool_calls"][0]["function"]["name"] == "search"
+    assert json.loads(
+        tool_deltas[1]["choices"][0]["delta"]["tool_calls"][0]["function"]["arguments"]
+    ) == {"query": "OpenArc"}
+
+    assert json_payloads[-1]["choices"][0]["finish_reason"] == "tool_calls"