diff --git a/src/server/main.py b/src/server/main.py
index 6f6558a..da0ff35 100644
--- a/src/server/main.py
+++ b/src/server/main.py
@@ -154,31 +154,83 @@ async def global_exception_handler(request: Request, exc: Exception):
# Tool calling helpers
#===============================================================#
+def _extract_hermes_tool_call_payloads(text: str) -> List[str]:
+ open_tag = "<tool_call>"
+ close_tag = "</tool_call>"
+ payloads: List[str] = []
+ cursor = 0
+
+ while True:
+ start = text.find(open_tag, cursor)
+ if start < 0:
+ break
+
+ payload_start = start + len(open_tag)
+ end = text.find(close_tag, payload_start)
+ if end < 0:
+ # Hermes parsers accept an open tool call until EOS.
+ payload = text[payload_start:].strip()
+ if payload:
+ payloads.append(payload)
+ break
+
+ payload = text[payload_start:end].strip()
+ if payload:
+ payloads.append(payload)
+
+ cursor = end + len(close_tag)
+
+ return payloads
+
+
+def _format_tool_call_arguments(arguments: Any) -> str:
+ if isinstance(arguments, str):
+ try:
+ return json.dumps(json.loads(arguments))
+ except json.JSONDecodeError:
+ return arguments
+ return json.dumps(arguments)
+
+
def parse_tool_calls(text: str) -> Optional[List[Dict[str, Any]]]:
- # Find all potential JSON objects
- pattern = r'\{(?:[^{}]|(?:\{[^{}]*\}))*\}'
- matches = re.findall(pattern, text, re.DOTALL)
-
- if not matches:
- return None
-
- tool_calls = []
- for match in matches:
+ tool_calls: List[Dict[str, Any]] = []
+
+ # Hermes format: <tool_call>{...}</tool_call>, with EOS fallback for missing close tag.
+ for payload in _extract_hermes_tool_call_payloads(text):
+ try:
+ data = json.loads(payload)
+ if isinstance(data, dict) and "name" in data and "arguments" in data:
+ tool_calls.append({
+ "id": f"call_{uuid.uuid4().hex[:24]}",
+ "type": "function",
+ "function": {
+ "name": str(data.get("name", "")),
+ "arguments": _format_tool_call_arguments(data.get("arguments", {})),
+ },
+ })
+ except json.JSONDecodeError:
+ continue
+
+ if tool_calls:
+ return tool_calls
+
+ # Backward compatibility for plain JSON tool call outputs without tags.
+ pattern = r"\{(?:[^{}]|(?:\{[^{}]*\}))*\}"
+ for match in re.findall(pattern, text, re.DOTALL):
try:
data = json.loads(match)
- # Check if it has the expected structure
- if "name" in data and "arguments" in data:
+ if isinstance(data, dict) and "name" in data and "arguments" in data:
tool_calls.append({
"id": f"call_{uuid.uuid4().hex[:24]}",
"type": "function",
"function": {
- "name": data.get("name", ""),
- "arguments": json.dumps(data.get("arguments", {}))
- }
+ "name": str(data.get("name", "")),
+ "arguments": _format_tool_call_arguments(data.get("arguments", {})),
+ },
})
except json.JSONDecodeError:
- pass
-
+ continue
+
return tool_calls if tool_calls else None
#===============================================================#
@@ -300,6 +352,7 @@ async def event_stream() -> AsyncIterator[bytes]:
accumulated_text = ""
metrics_data = None
tool_call_sent = False
+ tool_call_started = False
cancel_request_id = None
try:
@@ -320,6 +373,13 @@ async def event_stream() -> AsyncIterator[bytes]:
continue
accumulated_text += item
+ if not tool_call_started:
+ tool_call_started = (
+ "" in accumulated_text
+ or " AsyncIterator[bytes]:
}]
}
yield (f"data: {json.dumps(tool_call_args)}\n\n").encode()
- elif not tool_calls:
+ elif not tool_calls and not tool_call_started:
# Regular content streaming
chunk_payload = {
"id": request_id,
diff --git a/src/tests/test_tool_call_parser_unit.py b/src/tests/test_tool_call_parser_unit.py
new file mode 100644
index 0000000..5d5daf0
--- /dev/null
+++ b/src/tests/test_tool_call_parser_unit.py
@@ -0,0 +1,135 @@
+import json
+from typing import Any, AsyncIterator, Dict, List
+
+import pytest # type: ignore[import]
+from fastapi.responses import StreamingResponse
+
+import src.server.main as server_main
+from src.server.models.requests_openai import OpenAIChatCompletionRequest
+
+
+class _DummyRequest:
+ async def is_disconnected(self) -> bool:
+ return False
+
+
+def _extract_sse_payloads(chunks: List[bytes]) -> List[str]:
+ payloads: List[str] = []
+ for chunk in chunks:
+ for line in chunk.decode().splitlines():
+ if line.startswith("data: "):
+ payloads.append(line[6:])
+ return payloads
+
+
+def test_parse_tool_calls_supports_hermes_tool_call_tags() -> None:
+ text = (
+ ""
+ '{"name":"search","arguments":{"query":"OpenVINO"}}'
+ ""
+ )
+
+ tool_calls = server_main.parse_tool_calls(text)
+
+ assert tool_calls is not None
+ assert len(tool_calls) == 1
+ assert tool_calls[0]["type"] == "function"
+ assert tool_calls[0]["function"]["name"] == "search"
+ assert json.loads(tool_calls[0]["function"]["arguments"]) == {"query": "OpenVINO"}
+
+
+def test_parse_tool_calls_supports_missing_closing_tag_until_eos() -> None:
+ text = '<tool_call>{"name":"search","arguments":{"query":"vLLM"}}'
+
+ tool_calls = server_main.parse_tool_calls(text)
+
+ assert tool_calls is not None
+ assert len(tool_calls) == 1
+ assert tool_calls[0]["function"]["name"] == "search"
+ assert json.loads(tool_calls[0]["function"]["arguments"]) == {"query": "vLLM"}
+
+
+@pytest.mark.asyncio
+async def test_openai_chat_completions_non_streaming_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
+ class _Workers:
+ async def generate(self, model_name: str, generation_config: Any) -> Dict[str, Any]:
+ return {
+ "text": (
+ ""
+ '{"name":"search","arguments":{"query":"OpenArc"}}'
+ ""
+ ),
+ "metrics": {"input_token": 4, "new_token": 6, "total_token": 10},
+ }
+
+ monkeypatch.setattr(server_main, "_workers", _Workers())
+
+ request = OpenAIChatCompletionRequest(
+ model="demo-model",
+ messages=[{"role": "user", "content": "Find OpenArc docs"}],
+ stream=False,
+ )
+
+ response = await server_main.openai_chat_completions(request, _DummyRequest())
+
+ choice = response["choices"][0]
+ assert choice["finish_reason"] == "tool_calls"
+ assert choice["message"]["content"] is None
+ assert len(choice["message"]["tool_calls"]) == 1
+ assert choice["message"]["tool_calls"][0]["function"]["name"] == "search"
+ assert json.loads(choice["message"]["tool_calls"][0]["function"]["arguments"]) == {
+ "query": "OpenArc"
+ }
+
+
+@pytest.mark.asyncio
+async def test_openai_chat_completions_streaming_hermes_tool_call(monkeypatch: pytest.MonkeyPatch) -> None:
+ class _Workers:
+ async def stream_generate(self, model_name: str, generation_config: Any) -> AsyncIterator[Any]:
+ yield "{"name":"search","arguments":{"query":"OpenArc"}}'
+ yield ""
+ yield {"metrics": {"input_token": 2, "new_token": 3, "total_token": 5}}
+
+ async def infer_cancel(self, request_id: str) -> None:
+ return None
+
+ monkeypatch.setattr(server_main, "_workers", _Workers())
+
+ request = OpenAIChatCompletionRequest(
+ model="demo-model",
+ messages=[{"role": "user", "content": "Find OpenArc docs"}],
+ stream=True,
+ )
+
+ response = await server_main.openai_chat_completions(request, _DummyRequest())
+ assert isinstance(response, StreamingResponse)
+
+ chunks: List[bytes] = []
+ async for chunk in response.body_iterator:
+ chunks.append(chunk)
+
+ payloads = _extract_sse_payloads(chunks)
+ assert payloads[-1] == "[DONE]"
+
+ json_payloads = [json.loads(p) for p in payloads if p != "[DONE]"]
+
+ content_deltas = [
+ payload
+ for payload in json_payloads
+ if payload["choices"][0]["delta"].get("content")
+ ]
+ assert content_deltas == []
+
+ tool_deltas = [
+ payload
+ for payload in json_payloads
+ if payload["choices"][0]["delta"].get("tool_calls")
+ ]
+ assert len(tool_deltas) >= 2
+ assert tool_deltas[0]["choices"][0]["delta"]["tool_calls"][0]["function"]["name"] == "search"
+ assert json.loads(
+ tool_deltas[1]["choices"][0]["delta"]["tool_calls"][0]["function"]["arguments"]
+ ) == {"query": "OpenArc"}
+
+ assert json_payloads[-1]["choices"][0]["finish_reason"] == "tool_calls"