diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 25979d5502b0..b789acc26cde 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -13,7 +13,6 @@
 import regex as re
 from fastapi import Request
 from openai_harmony import Message as OpenAIMessage
-from pydantic import TypeAdapter
 
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import (
@@ -47,8 +46,6 @@
     DeltaMessage,
     DeltaToolCall,
     ErrorResponse,
-    FunctionCall,
-    FunctionDefinition,
     PromptTokenUsageInfo,
     RequestResponseMetadata,
     ToolCall,
@@ -1394,6 +1391,16 @@ async def chat_completion_full_generator(
             auto_tools_called = False
             # if auto tools are not enabled, and a named tool choice using
             # outlines is not being used
+            tool_calls, content = self._parse_tool_calls_from_content(
+                request=request,
+                tokenizer=tokenizer,
+                content=content,
+                enable_auto_tools=self.enable_auto_tools,
+                tool_parser_cls=self.tool_parser,
+            )
+            tool_call_class = (
+                MistralToolCall if isinstance(tokenizer, MistralTokenizer) else ToolCall
+            )
             if (not self.enable_auto_tools or not self.tool_parser) and (
                 not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam)
                 and request.tool_choice != "required"
@@ -1407,63 +1414,33 @@
                 request.tool_choice
                 and type(request.tool_choice) is ChatCompletionNamedToolChoiceParam
             ):
-                tool_call_class = (
-                    MistralToolCall
-                    if isinstance(tokenizer, MistralTokenizer)
-                    else ToolCall
-                )
+                assert tool_calls is not None and len(tool_calls) > 0
                 message = ChatMessage(
                     role=role,
                     reasoning_content=reasoning_content,
                     content="",
-                    tool_calls=[
-                        tool_call_class(
-                            function=FunctionCall(
-                                name=request.tool_choice.function.name,
-                                arguments=content,
-                            )
-                        )
-                    ],
+                    tool_calls=[tool_call_class(function=tc) for tc in tool_calls],
                 )
 
             elif request.tool_choice and request.tool_choice == "required":
-                tool_call_class = (
-                    MistralToolCall
-                    if isinstance(tokenizer, MistralTokenizer)
-                    else ToolCall
-                )
-
-                # the fields of FunctionDefinition are a superset of the
-                # tool call outputs and can be used for parsing
-                assert content is not None
-                tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(
-                    content
-                )
-                tool_call_ids = []
+                tool_call_class_items = []
+                assert tool_calls is not None and len(tool_calls) > 0
                 for tool_call in tool_calls:
-                    tool_call_ids.append(
-                        make_tool_call_id(
-                            id_type=self.tool_call_id_type,
-                            func_name=tool_call.name,
-                            idx=history_tool_call_cnt,
+                    tool_call_class_items.append(
+                        tool_call_class(
+                            id=make_tool_call_id(
+                                id_type=self.tool_call_id_type,
+                                func_name=tool_call.name,
+                                idx=history_tool_call_cnt,
+                            ),
+                            function=tool_call,
                         )
                     )
                     history_tool_call_cnt += 1
                 message = ChatMessage(
                     role=role,
                     content="",
-                    tool_calls=[
-                        tool_call_class(
-                            id=tool_call_ids[i],
-                            function=FunctionCall(
-                                name=tool_call.name,
-                                arguments=json.dumps(
-                                    tool_call.parameters, ensure_ascii=False
-                                ),
-                            ),
-                        )
-                        for i, tool_call in enumerate(tool_calls)
-                    ],
+                    tool_calls=tool_call_class_items,
                     reasoning_content=reasoning_content,
                 )
 
@@ -1481,25 +1458,22 @@
                 and self.enable_auto_tools
                 and self.tool_parser
             ):
-                try:
-                    tool_parser = self.tool_parser(tokenizer)
-                except RuntimeError as e:
-                    logger.exception("Error in tool parser creation.")
-                    return self.create_error_response(str(e))
-
-                tool_call_info = tool_parser.extract_tool_calls(
-                    content if content is not None else "", request=request
-                )
                 # In the OpenAI API the finish_reason is "tools_called"
                 # if the tool choice is auto and the model produced a tool
                 # call. The same is not true for named function calls
-                auto_tools_called = tool_call_info.tools_called
-                if tool_call_info.tools_called:
+                auto_tools_called = tool_calls is not None and len(tool_calls) > 0
+                if tool_calls:
                     message = ChatMessage(
                         role=role,
                         reasoning_content=reasoning_content,
-                        content=tool_call_info.content,
-                        tool_calls=tool_call_info.tool_calls,
+                        content=content,
+                        tool_calls=[
+                            ToolCall(
+                                function=tc,
+                                type="function",
+                            )
+                            for tc in tool_calls
+                        ],
                     )
 
                 else:
@@ -1509,8 +1483,8 @@ async def chat_completion_full_generator(
 
                     # try to use content return from tool parser first,
                     # tool parser may do some modify for the content.
-                    if tool_call_info.content and len(tool_call_info.content) > 0:
-                        ret_content = tool_call_info.content
+                    if content and len(content) > 0:
+                        ret_content = content
                     message = ChatMessage(
                         role=role,
                         reasoning_content=reasoning_content,
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 46e79edbde61..bafc0e2c372f 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -12,7 +12,7 @@
 
 import torch
 from fastapi import Request
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
 from starlette.datastructures import Headers
 from typing_extensions import TypeIs
 
@@ -21,6 +21,10 @@
 else:
     from typing_extensions import TypedDict
 
+from openai.types.responses import (
+    ToolChoiceFunction,
+)
+
 import vllm.envs as envs
 from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
 from vllm.engine.protocol import EngineClient
@@ -36,6 +40,7 @@
 from vllm.entrypoints.context import ConversationContext
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (
+    ChatCompletionNamedToolChoiceParam,
     ChatCompletionRequest,
     ChatCompletionResponse,
     ClassificationRequest,
@@ -49,6 +54,8 @@
     EmbeddingResponse,
     ErrorInfo,
     ErrorResponse,
+    FunctionCall,
+    FunctionDefinition,
     IOProcessorRequest,
     PoolingResponse,
     RerankRequest,
@@ -1305,6 +1312,75 @@ def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
         except ValueError:
             return None
 
+    @staticmethod
+    def _parse_tool_calls_from_content(
+        request: ResponsesRequest | ChatCompletionRequest,
+        tokenizer: AnyTokenizer,
+        enable_auto_tools: bool,
+        tool_parser_cls: Callable[[AnyTokenizer], ToolParser] | None,
+        content: str | None = None,
+    ) -> tuple[list[FunctionCall] | None, str | None]:
+        function_calls = list[FunctionCall]()
+        if request.tool_choice and isinstance(request.tool_choice, ToolChoiceFunction):
+            assert content is not None
+            # Forced Function Call
+            function_calls.append(
+                FunctionCall(name=request.tool_choice.name, arguments=content)
+            )
+            content = None  # Clear content since tool is called.
+        elif request.tool_choice and isinstance(
+            request.tool_choice, ChatCompletionNamedToolChoiceParam
+        ):
+            assert content is not None
+            # Forced Function Call
+            function_calls.append(
+                FunctionCall(name=request.tool_choice.function.name, arguments=content)
+            )
+            content = None  # Clear content since tool is called.
+        elif request.tool_choice == "required":
+            assert content is not None
+            tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(content)
+            function_calls.extend(
+                [
+                    FunctionCall(
+                        name=tool_call.name,
+                        arguments=json.dumps(tool_call.parameters, ensure_ascii=False),
+                    )
+                    for tool_call in tool_calls
+                ]
+            )
+            content = None  # Clear content since tool is called.
+        elif (
+            tool_parser_cls
+            and enable_auto_tools
+            and (request.tool_choice == "auto" or request.tool_choice is None)
+        ):
+            # Automatic Tool Call Parsing
+            try:
+                tool_parser = tool_parser_cls(tokenizer)
+            except RuntimeError as e:
+                logger.exception("Error in tool parser creation.")
+                raise e
+            tool_call_info = tool_parser.extract_tool_calls(
+                content if content is not None else "",
+                request=request,  # type: ignore
+            )
+            if tool_call_info is not None and tool_call_info.tools_called:
+                # extract_tool_calls() returns a list of tool calls.
+                function_calls.extend(
+                    FunctionCall(
+                        name=tool_call.function.name,
+                        arguments=tool_call.function.arguments,
+                    )
+                    for tool_call in tool_call_info.tool_calls
+                )
+                content = tool_call_info.content
+            else:
+                # No tool calls.
+                return None, content
+
+        return function_calls, content
+
     @staticmethod
     def _get_decoded_token(
         logprob: Logprob,
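Reviewer note (not part of the patch): the sketch below illustrates in isolation what the new `_parse_tool_calls_from_content` helper does on the `tool_choice == "required"` path. The `FunctionDefinition` and `FunctionCall` models and the `parse_required_tool_calls` wrapper are simplified stand-ins invented for illustration, not vLLM's actual protocol classes; only the `TypeAdapter(list[...]).validate_json(...)` round-trip mirrors the real code.

```python
import json

from pydantic import BaseModel, TypeAdapter


# Stand-ins for vllm.entrypoints.openai.protocol models (assumed,
# simplified shapes -- the real classes carry more fields).
class FunctionDefinition(BaseModel):
    name: str
    parameters: dict | None = None


class FunctionCall(BaseModel):
    name: str
    arguments: str


def parse_required_tool_calls(content: str) -> tuple[list[FunctionCall], None]:
    """Mirror of the helper's tool_choice == "required" branch: the model
    emits a JSON list of function definitions, which is validated and
    re-serialized into FunctionCall objects; content is cleared (None)
    because the output was consumed as tool calls."""
    tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(content)
    function_calls = [
        FunctionCall(
            name=tool_call.name,
            arguments=json.dumps(tool_call.parameters, ensure_ascii=False),
        )
        for tool_call in tool_calls
    ]
    return function_calls, None


if __name__ == "__main__":
    raw = '[{"name": "get_weather", "parameters": {"city": "Paris"}}]'
    calls, content = parse_required_tool_calls(raw)
    assert content is None and calls[0].name == "get_weather"
    print(calls[0].arguments)  # {"city": "Paris"}
```

Centralizing this logic in serving_engine.py lets the chat endpoint (and, per the type hints, the responses endpoint) share one parsing path for all four `tool_choice` modes instead of re-implementing each branch inline.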