Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/tool_use/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class ServerConfig(TypedDict):
CONFIGS: Dict[str, ServerConfig] = {
"hermes": {
"model":
"NousResearch/Hermes-2-Pro-Llama-3-8B",
"NousResearch/Hermes-3-Llama-3.1-8B",
"arguments": [
"--tool-call-parser", "hermes", "--chat-template",
str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja")
Expand Down
7 changes: 0 additions & 7 deletions vllm/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,13 +713,6 @@ class DeltaToolCall(OpenAIBaseModel):
function: Optional[DeltaFunctionCall] = None


# the initial delta that gets sent once a new tool call is started;
class InitialDeltaToolCall(DeltaToolCall):
id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
type: Literal["function"] = "function"
index: int


class ExtractedToolCallInformation(BaseModel):
# indicate if tools were called
tools_called: bool
Expand Down
6 changes: 5 additions & 1 deletion vllm/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,13 @@ async def chat_completion_stream_generator(
# NOTE num_choices defaults to 1 so this usually executes
# once per request
for i in range(num_choices):

choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(role=role),
delta=DeltaMessage(
role=role,
content="",
),
logprobs=None,
finish_reason=None)
chunk = ChatCompletionStreamResponse(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def __init__(self, tokenizer: AnyTokenizer):
# the index of the tool call that is currently being parsed
self.current_tool_id: int = -1
self.current_tool_name_sent: bool = False
self.current_tool_initial_sent: bool = False
self.streamed_args_for_tool: List[str] = []

self.model_tokenizer = tokenizer
Expand Down
20 changes: 5 additions & 15 deletions vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
InitialDeltaToolCall, ToolCall)
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser)
from vllm.entrypoints.openai.tool_parsers.utils import (
extract_intermediate_diff)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid

logger = init_logger(__name__)

Expand All @@ -34,7 +34,6 @@ def __init__(self, tokenizer: AnyTokenizer):
self.prev_tool_call_arr: List[Dict] = []
self.current_tool_id: int = -1
self.current_tool_name_sent = False
self.current_tool_initial_sent: bool = False
self.streamed_args_for_tool: List[str] = [
] # map what has been streamed for each tool so far to a list

Expand Down Expand Up @@ -168,7 +167,6 @@ def extract_tool_calls_streaming(
# set cursors and state appropriately
self.current_tool_id += 1
self.current_tool_name_sent = False
self.current_tool_initial_sent = False
self.streamed_args_for_tool.append("")
logger.debug("Starting on a new tool %s", self.current_tool_id)

Expand Down Expand Up @@ -218,24 +216,16 @@ def extract_tool_calls_streaming(
logger.debug('not enough tokens to parse into JSON yet')
return None

# case - we haven't sent the initial delta with the tool call ID
# (it will be sent)
if not self.current_tool_initial_sent:
self.current_tool_initial_sent = True
return DeltaMessage(tool_calls=[
InitialDeltaToolCall(
index=self.current_tool_id).model_dump(
exclude_none=True)
])

# case - we haven't sent the tool name yet. If it's available, send
# it. otherwise, wait until it's available.
elif not self.current_tool_name_sent:
if not self.current_tool_name_sent:
function_name: Union[str, None] = current_tool_call.get("name")
if function_name:
self.current_tool_name_sent = True
return DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=f"chatcmpl-tool-{random_uuid()}",
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))
Expand Down
27 changes: 8 additions & 19 deletions vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
InitialDeltaToolCall, ToolCall)
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser)
from vllm.entrypoints.openai.tool_parsers.utils import (
extract_intermediate_diff)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid

logger = init_logger(__name__)

Expand All @@ -25,7 +25,7 @@ class MistralToolParser(ToolParser):
Tool call parser for Mistral 7B Instruct v0.3, intended for use with the
examples/tool_chat_template_mistral.jinja template.

Used when --enable-auto-tool-choice --tool-call-parser gmistral are all set
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
"""

def __init__(self, tokenizer: AnyTokenizer):
Expand All @@ -42,7 +42,6 @@ def __init__(self, tokenizer: AnyTokenizer):
self.prev_tool_call_arr: List[Dict] = []
self.current_tool_id: int = -1
self.current_tool_name_sent: bool = False
self.current_tool_initial_sent: bool = False
self.streamed_args_for_tool: List[str] = [
] # map what has been streamed for each tool so far to a list
self.bot_token = "[TOOL_CALLS]"
Expand Down Expand Up @@ -91,7 +90,6 @@ def extract_tool_calls(self,

except Exception as e:
logger.error("Error in extracting tool call from response: %s", e)
print("ERROR", e)
# return information to just treat the tool call as regular JSON
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
Expand All @@ -109,7 +107,7 @@ def extract_tool_calls_streaming(

# if the tool call token is not in the tokens generated so far, append
# output to contents since it's not a tool
if self.bot_token_id not in current_token_ids:
if self.bot_token not in current_text:
return DeltaMessage(content=delta_text)

# if the tool call token ID IS in the tokens generated so far, that
Expand All @@ -134,7 +132,7 @@ def extract_tool_calls_streaming(
# replace BOT token with empty string, and convert single quotes
# to double to allow parsing as JSON since mistral uses single
# quotes instead of double for tool calls
parsable_arr = current_text.split(self.bot_token)[1]
parsable_arr = current_text.split(self.bot_token)[-1]

# tool calls are generated in an array, so do partial JSON
# parsing on the entire array
Expand Down Expand Up @@ -186,31 +184,22 @@ def extract_tool_calls_streaming(
# re-set stuff pertaining to progress in the current tool
self.current_tool_id = len(tool_call_arr) - 1
self.current_tool_name_sent = False
self.current_tool_initial_sent = False
self.streamed_args_for_tool.append("")
logger.debug("starting on new tool %d", self.current_tool_id)
return delta

# case: update an existing tool - this is handled below

# if the current tool initial data incl. the id, type=function
# and idx not sent, send that
if not self.current_tool_initial_sent:
self.current_tool_initial_sent = True
delta = DeltaMessage(tool_calls=[
InitialDeltaToolCall(
index=self.current_tool_id).model_dump(
exclude_none=True)
])

# if the current tool name hasn't been sent, send if available
# - otherwise send nothing
elif not self.current_tool_name_sent:
if not self.current_tool_name_sent:
function_name = current_tool_call.get("name")
if function_name:

delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=f"chatcmpl-tool-{random_uuid()}",
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))
Expand Down