Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions vllm/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls
from vllm.utils import generate_valid_mistral_tool_id

logger = init_logger(__name__)

Expand Down Expand Up @@ -668,6 +669,10 @@ async def chat_completion_full_generator(
arguments=output.text))
])

if isinstance(tokenizer, MistralTokenizer):
for tool_call in message.tool_calls:
tool_call.id = generate_valid_mistral_tool_id()

# if the request doesn't use tool choice
# OR specifies to not use a tool
elif not request.tool_choice or request.tool_choice == "none":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class MistralToolCall(ToolCall):

@staticmethod
def generate_random_id():
# Mistral Tool Call Ids must be alphanumeric with a maximum length of 9.
# Mistral Tool Call Ids must be alphanumeric with a length of 9.
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
return "".join(choices(ALPHANUMERIC, k=9))

Expand Down
4 changes: 3 additions & 1 deletion vllm/transformers_utils/tokenizers/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
Tekkenizer)

from vllm.logger import init_logger
from vllm.utils import is_list_of
from vllm.utils import generate_valid_mistral_tool_id, is_list_of

if TYPE_CHECKING:
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
Expand Down Expand Up @@ -62,6 +62,8 @@ def maybe_serialize_tool_calls(request: ChatCompletionRequest):
try:
tool_call = next(tool_calls_validator) # type: ignore
validated_tool_calls.append(tool_call)
if not re.match(r"^[a-zA-Z0-9]{9}$", tool_call['id']):
tool_call['id'] = generate_valid_mistral_tool_id()
except StopIteration:
break

Expand Down
10 changes: 10 additions & 0 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
from collections.abc import Hashable, Iterable, Mapping
from dataclasses import dataclass, field
from functools import cache, lru_cache, partial, wraps
from random import choices
from string import ascii_letters, digits
from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
Dict, Generator, Generic, Iterator, List, Literal,
NamedTuple, Optional, Tuple, Type, TypeVar, Union,
Expand Down Expand Up @@ -57,6 +59,8 @@

logger = init_logger(__name__)

ALPHANUMERIC = ascii_letters + digits

# Exception strings for non-implemented encoder/decoder scenarios

# Reminder: Please update docs/source/features/compatibility_matrix.md
Expand Down Expand Up @@ -2206,3 +2210,9 @@ def run_method(obj: Any, method: Union[str, bytes, Callable], args: Tuple[Any],
else:
func = partial(method, obj) # type: ignore
return func(*args, **kwargs)


def generate_valid_mistral_tool_id():
# Mistral Tool Call Ids must be alphanumeric with a length of 9.
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
return "".join(choices(ALPHANUMERIC, k=9))