Skip to content

Commit ce2ca22

Browse files
chaunceyjiang and wendyliu235
authored and committed
[Bugfix] Fix crash when tool_choice=required exceeds max_tokens (vllm-project#36841)
Signed-off-by: chaunceyjiang <[email protected]>
1 parent f4dc8b0 commit ce2ca22

File tree

3 files changed

+36
-9
lines changed

3 files changed

+36
-9
lines changed

tests/entrypoints/openai/test_completion_with_function_calling.py

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -514,3 +514,27 @@ async def test_inconsistent_tool_choice_and_tools(
514514
],
515515
tool_choice={},
516516
)
517+
518+
519+
@pytest.mark.asyncio
520+
async def test_max_tokens_with_tool_choice_required(client: openai.AsyncOpenAI):
521+
""" """
522+
models = await client.models.list()
523+
model_name: str = models.data[0].id
524+
525+
# This combination previously crashed the engine
526+
chat_completion = await client.chat.completions.create(
527+
messages=messages,
528+
temperature=0,
529+
max_completion_tokens=1,
530+
model=model_name,
531+
tools=tools,
532+
tool_choice="required",
533+
)
534+
# When `tool_choice="required"` and the tokens of `tools` exceed `max_tokens`,
535+
# both `tool_calls` and `content` should be empty.
536+
# This behavior should be consistent with OpenAI.
537+
choice = chat_completion.choices[0]
538+
assert choice.finish_reason == "length"
539+
assert len(choice.message.tool_calls) == 0
540+
assert choice.message.content == ""

vllm/entrypoints/openai/chat_completion/serving.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1507,7 +1507,7 @@ async def chat_completion_full_generator(
15071507

15081508
elif request.tool_choice and request.tool_choice == "required":
15091509
tool_call_class_items = []
1510-
assert tool_calls is not None and len(tool_calls) > 0
1510+
tool_calls = tool_calls or []
15111511
for idx, tool_call in enumerate(tool_calls):
15121512
# Use native ID if available,
15131513
# otherwise generate ID with correct id_type

vllm/entrypoints/openai/engine/serving.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
import asyncio
4+
import contextlib
45
import json
56
import time
67
from collections.abc import AsyncGenerator, Callable, Mapping, Sequence
@@ -13,7 +14,7 @@
1314
from openai.types.responses import (
1415
ToolChoiceFunction,
1516
)
16-
from pydantic import ConfigDict, TypeAdapter
17+
from pydantic import ConfigDict, TypeAdapter, ValidationError
1718
from starlette.datastructures import Headers
1819

1920
import vllm.envs as envs
@@ -1125,17 +1126,19 @@ def _parse_tool_calls_from_content(
11251126
)
11261127
content = None # Clear content since tool is called.
11271128
elif request.tool_choice == "required":
1128-
assert content is not None
1129-
tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(content)
1130-
function_calls.extend(
1131-
[
1129+
tool_calls = []
1130+
with contextlib.suppress(ValidationError):
1131+
content = content or ""
1132+
tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(
1133+
content
1134+
)
1135+
for tool_call in tool_calls:
1136+
function_calls.append(
11321137
FunctionCall(
11331138
name=tool_call.name,
11341139
arguments=json.dumps(tool_call.parameters, ensure_ascii=False),
11351140
)
1136-
for tool_call in tool_calls
1137-
]
1138-
)
1141+
)
11391142
content = None # Clear content since tool is called.
11401143
elif (
11411144
tool_parser_cls

0 commit comments

Comments (0)