94 changes: 77 additions & 17 deletions src/server/main.py
@@ -154,31 +154,83 @@ async def global_exception_handler(request: Request, exc: Exception):
 # Tool calling helpers
 #===============================================================#
 
+def _extract_hermes_tool_call_payloads(text: str) -> List[str]:
+    open_tag = "<tool_call>"
+    close_tag = "</tool_call>"
+    payloads: List[str] = []
+    cursor = 0
+
+    while True:
+        start = text.find(open_tag, cursor)
+        if start < 0:
+            break
+
+        payload_start = start + len(open_tag)
+        end = text.find(close_tag, payload_start)
+        if end < 0:
+            # Hermes parsers accept an open tool call until EOS.
+            payload = text[payload_start:].strip()
+            if payload:
+                payloads.append(payload)
+            break
+
+        payload = text[payload_start:end].strip()
+        if payload:
+            payloads.append(payload)
+
+        cursor = end + len(close_tag)
+
+    return payloads
+
+
+def _format_tool_call_arguments(arguments: Any) -> str:
+    if isinstance(arguments, str):
+        try:
+            return json.dumps(json.loads(arguments))
+        except json.JSONDecodeError:
+            return arguments
+    return json.dumps(arguments)
+
+
 def parse_tool_calls(text: str) -> Optional[List[Dict[str, Any]]]:
-    # Find all potential JSON objects
-    pattern = r'\{(?:[^{}]|(?:\{[^{}]*\}))*\}'
-    matches = re.findall(pattern, text, re.DOTALL)
-
-    if not matches:
-        return None
-
-    tool_calls = []
-    for match in matches:
+    tool_calls: List[Dict[str, Any]] = []
+
+    # Hermes format: <tool_call>{...}</tool_call>, with EOS fallback for missing close tag.
+    for payload in _extract_hermes_tool_call_payloads(text):
+        try:
+            data = json.loads(payload)
+            if isinstance(data, dict) and "name" in data and "arguments" in data:
+                tool_calls.append({
+                    "id": f"call_{uuid.uuid4().hex[:24]}",
+                    "type": "function",
+                    "function": {
+                        "name": str(data.get("name", "")),
+                        "arguments": _format_tool_call_arguments(data.get("arguments", {})),
+                    },
+                })
+        except json.JSONDecodeError:
+            continue
+
+    if tool_calls:
+        return tool_calls
+
+    # Backward compatibility for plain JSON tool call outputs without tags.
+    pattern = r"\{(?:[^{}]|(?:\{[^{}]*\}))*\}"
+    for match in re.findall(pattern, text, re.DOTALL):
         try:
             data = json.loads(match)
-            # Check if it has the expected structure
-            if "name" in data and "arguments" in data:
+            if isinstance(data, dict) and "name" in data and "arguments" in data:
                 tool_calls.append({
                     "id": f"call_{uuid.uuid4().hex[:24]}",
                     "type": "function",
                     "function": {
-                        "name": data.get("name", ""),
-                        "arguments": json.dumps(data.get("arguments", {}))
-                    }
+                        "name": str(data.get("name", "")),
+                        "arguments": _format_tool_call_arguments(data.get("arguments", {})),
+                    },
                 })
         except json.JSONDecodeError:
-            pass
+            continue
 
     return tool_calls if tool_calls else None
 
 #===============================================================#
@@ -300,6 +352,7 @@ async def event_stream() -> AsyncIterator[bytes]:
         accumulated_text = ""
         metrics_data = None
         tool_call_sent = False
+        tool_call_started = False
         cancel_request_id = None
 
         try:
@@ -320,6 +373,13 @@
                     continue
 
                 accumulated_text += item
+                if not tool_call_started:
+                    tool_call_started = (
+                        "<tool_call>" in accumulated_text
+                        or "<tool_call" in accumulated_text
+                        or "<tool_" in accumulated_text
+                    )
+
                 tool_calls = parse_tool_calls(accumulated_text)
 
                 # If tool call detected and not yet sent, stream tool call deltas
@@ -366,7 +426,7 @@
                         }]
                     }
                     yield (f"data: {json.dumps(tool_call_args)}\n\n").encode()
-                elif not tool_calls:
+                elif not tool_calls and not tool_call_started:
                     # Regular content streaming
                     chunk_payload = {
                         "id": request_id,
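As a quick sanity check of the new parser, a minimal sketch (not part of the diff; it assumes src.server.main imports cleanly outside a running server, and the completion string is a made-up example):

from src.server.main import parse_tool_calls

# One closed Hermes block plus one left open at EOS; both should be recovered.
completion = (
    '<tool_call>{"name": "search", "arguments": {"query": "OpenVINO"}}</tool_call>'
    '<tool_call>{"name": "search", "arguments": {"query": "vLLM"}}'
)

calls = parse_tool_calls(completion)
assert calls is not None and len(calls) == 2
for call in calls:
    # Each entry uses the OpenAI tool-call shape: id / type / function{name, arguments}.
    print(call["id"], call["function"]["name"], call["function"]["arguments"])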
135 changes: 135 additions & 0 deletions src/tests/test_tool_call_parser_unit.py
@@ -0,0 +1,135 @@
import json
from typing import Any, AsyncIterator, Dict, List

import pytest  # type: ignore[import]
from fastapi.responses import StreamingResponse

import src.server.main as server_main
from src.server.models.requests_openai import OpenAIChatCompletionRequest


class _DummyRequest:
    async def is_disconnected(self) -> bool:
        return False


def _extract_sse_payloads(chunks: List[bytes]) -> List[str]:
    payloads: List[str] = []
    for chunk in chunks:
        for line in chunk.decode().splitlines():
            if line.startswith("data: "):
                payloads.append(line[6:])
    return payloads


def test_parse_tool_calls_supports_hermes_tool_call_tags() -> None:
    text = (
        "<tool_call>"
        '{"name":"search","arguments":{"query":"OpenVINO"}}'
        "</tool_call>"
    )

    tool_calls = server_main.parse_tool_calls(text)

    assert tool_calls is not None
    assert len(tool_calls) == 1
    assert tool_calls[0]["type"] == "function"
    assert tool_calls[0]["function"]["name"] == "search"
    assert json.loads(tool_calls[0]["function"]["arguments"]) == {"query": "OpenVINO"}


def test_parse_tool_calls_supports_missing_closing_tag_until_eos() -> None:
    text = '<tool_call>{"name":"search","arguments":{"query":"vLLM"}}'

    tool_calls = server_main.parse_tool_calls(text)

    assert tool_calls is not None
    assert len(tool_calls) == 1
    assert tool_calls[0]["function"]["name"] == "search"
    assert json.loads(tool_calls[0]["function"]["arguments"]) == {"query": "vLLM"}


@pytest.mark.asyncio
async def test_openai_chat_completions_non_streaming_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
    class _Workers:
        async def generate(self, model_name: str, generation_config: Any) -> Dict[str, Any]:
            return {
                "text": (
                    "<tool_call>"
                    '{"name":"search","arguments":{"query":"OpenArc"}}'
                    "</tool_call>"
                ),
                "metrics": {"input_token": 4, "new_token": 6, "total_token": 10},
            }

    monkeypatch.setattr(server_main, "_workers", _Workers())

    request = OpenAIChatCompletionRequest(
        model="demo-model",
        messages=[{"role": "user", "content": "Find OpenArc docs"}],
        stream=False,
    )

    response = await server_main.openai_chat_completions(request, _DummyRequest())

    choice = response["choices"][0]
    assert choice["finish_reason"] == "tool_calls"
    assert choice["message"]["content"] is None
    assert len(choice["message"]["tool_calls"]) == 1
    assert choice["message"]["tool_calls"][0]["function"]["name"] == "search"
    assert json.loads(choice["message"]["tool_calls"][0]["function"]["arguments"]) == {
        "query": "OpenArc"
    }


@pytest.mark.asyncio
async def test_openai_chat_completions_streaming_hermes_tool_call(monkeypatch: pytest.MonkeyPatch) -> None:
    class _Workers:
        async def stream_generate(self, model_name: str, generation_config: Any) -> AsyncIterator[Any]:
            yield "<tool_"
            yield 'call>{"name":"search","arguments":{"query":"OpenArc"}}'
            yield "</tool_call>"
            yield {"metrics": {"input_token": 2, "new_token": 3, "total_token": 5}}

        async def infer_cancel(self, request_id: str) -> None:
            return None

    monkeypatch.setattr(server_main, "_workers", _Workers())

    request = OpenAIChatCompletionRequest(
        model="demo-model",
        messages=[{"role": "user", "content": "Find OpenArc docs"}],
        stream=True,
    )

    response = await server_main.openai_chat_completions(request, _DummyRequest())
    assert isinstance(response, StreamingResponse)

    chunks: List[bytes] = []
    async for chunk in response.body_iterator:
        chunks.append(chunk)

    payloads = _extract_sse_payloads(chunks)
    assert payloads[-1] == "[DONE]"

    json_payloads = [json.loads(p) for p in payloads if p != "[DONE]"]

    content_deltas = [
        payload
        for payload in json_payloads
        if payload["choices"][0]["delta"].get("content")
    ]
    assert content_deltas == []

    tool_deltas = [
        payload
        for payload in json_payloads
        if payload["choices"][0]["delta"].get("tool_calls")
    ]
    assert len(tool_deltas) >= 2
    assert tool_deltas[0]["choices"][0]["delta"]["tool_calls"][0]["function"]["name"] == "search"
    assert json.loads(
        tool_deltas[1]["choices"][0]["delta"]["tool_calls"][0]["function"]["arguments"]
    ) == {"query": "OpenArc"}

    assert json_payloads[-1]["choices"][0]["finish_reason"] == "tool_calls"
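The streaming test above checks the delta sequence chunk by chunk; a client consuming the same SSE stream could reassemble the call roughly as below. This is a minimal sketch under the same assumptions the test asserts (the first tool_calls delta carries the function name, a later delta carries the arguments as a JSON string, and sse_payloads holds the decoded data: payloads as produced by _extract_sse_payloads); it reuses json and the typing imports from the top of the test module.

def assemble_tool_call(sse_payloads: List[str]) -> Dict[str, Any]:
    name = ""
    arguments = ""
    for payload in sse_payloads:
        if payload == "[DONE]":
            break
        delta = json.loads(payload)["choices"][0]["delta"]
        for fragment in delta.get("tool_calls") or []:
            function = fragment.get("function", {})
            name = function.get("name") or name
            arguments += function.get("arguments") or ""
    # Arguments arrive as a JSON string, matching the non-streaming tool_calls format.
    return {"name": name, "arguments": json.loads(arguments) if arguments else {}}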