Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions vllm/entrypoints/openai/chat_completion/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,11 +310,14 @@ async def create_chat_completion(
trace_headers=trace_headers,
)
else:
reasoning_ended = (
reasoning_parser.is_reasoning_end(prompt_token_ids or [])
if reasoning_parser
else None
)
if not request.include_reasoning:
reasoning_ended = True
elif reasoning_parser:
reasoning_ended = reasoning_parser.is_reasoning_end(
prompt_token_ids or []
)
else:
reasoning_ended = None

generator = self.engine_client.generate(
engine_prompt,
Expand Down
53 changes: 27 additions & 26 deletions vllm/tokenizers/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,15 @@
from mistral_common.tokens.tokenizers.base import (
SpecialTokenPolicy,
SpecialTokens,
Tokenizer,
)
from mistral_common.tokens.tokenizers.instruct import (
InstructTokenizerBase,
InstructTokenizerV13,
)
from mistral_common.tokens.tokenizers.mistral import (
MistralTokenizer as MistralCommonTokenizer,
)
from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13
from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer,
)
Expand All @@ -26,21 +33,20 @@
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.logger import init_logger
from vllm.tokenizers.protocol import TokenizerLike

from .protocol import TokenizerLike
try:
# Transformers v5
from transformers.tokenization_mistral_common import MistralCommonBackend
except ImportError:
# Transformers v4
from transformers.tokenization_mistral_common import (
MistralCommonTokenizer as MistralCommonBackend,
)

if TYPE_CHECKING:
from transformers import BatchEncoding

try:
# Transformers v5
from transformers.tokenization_mistral_common import MistralCommonBackend
except ImportError:
# Transformers v4
from transformers.tokenization_mistral_common import (
MistralCommonTokenizer as MistralCommonBackend,
)

logger = init_logger(__name__)


Expand Down Expand Up @@ -235,15 +241,6 @@ def from_pretrained(
download_dir: str | None = None,
**kwargs,
) -> "MistralTokenizer":
try:
# Transformers v5
from transformers.tokenization_mistral_common import MistralCommonBackend
except ImportError:
# Transformers v4
from transformers.tokenization_mistral_common import (
MistralCommonTokenizer as MistralCommonBackend,
)

tokenizer = MistralCommonBackend.from_pretrained(
path_or_repo_id,
*args,
Expand All @@ -255,13 +252,13 @@ def from_pretrained(

return cls(tokenizer)

def __init__(self, tokenizer: "MistralCommonBackend") -> None:
def __init__(self, tokenizer: MistralCommonBackend) -> None:
super().__init__()

self.transformers_tokenizer = tokenizer
self.mistral = tokenizer.tokenizer
self.instruct = self.mistral.instruct_tokenizer
self.tokenizer = self.instruct.tokenizer
self.transformers_tokenizer: MistralCommonBackend = tokenizer
self.mistral: MistralCommonTokenizer = tokenizer.tokenizer
self.instruct: InstructTokenizerBase = self.mistral.instruct_tokenizer
self.tokenizer: Tokenizer = self.instruct.tokenizer

mode = self.mistral._chat_completion_request_validator._mode
if mode != ValidationMode.test:
Expand Down Expand Up @@ -483,7 +480,11 @@ def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
return self.transformers_tokenizer.convert_tokens_to_ids(tokens)

def convert_tokens_to_string(self, tokens: list[str]) -> str:
to_decode_special_tokens = {SpecialTokens.tool_calls}
to_decode_special_tokens = {
SpecialTokens.tool_calls,
SpecialTokens.begin_think,
SpecialTokens.end_think,
}
if self.is_tekken:
assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
tokens = [
Expand Down
10 changes: 7 additions & 3 deletions vllm/tool_parsers/mistral_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,10 @@ def extract_tool_calls_streaming(
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
if self.bot_token_id not in current_token_ids:
has_bot_token = (
self.bot_token_id in current_token_ids or self.bot_token in current_text
)
if not has_bot_token:
# if the tool call token is not in the tokens generated so far,
# append output to contents since it's not a tool
return DeltaMessage(content=delta_text)
Expand Down Expand Up @@ -275,7 +278,8 @@ def _extract_tool_calls_streaming(
additional_content: str = ""
if self.streaming_state == StreamingState.WAITING_FOR_TOOL_START:
# this is the first tool call
assert self.bot_token_id in delta_token_ids
if self.bot_token not in delta_text:
return DeltaMessage(content=delta_text)
if not delta_text.startswith(self.bot_token):
additional_content += delta_text.split(self.bot_token)[0]
delta_text = self.bot_token + "".join(
Expand Down Expand Up @@ -411,7 +415,7 @@ def _extract_tool_calls_streaming_pre_v11_tokenizer(
index=self.current_tool_id, type="function"
)
current_tool_call_modified = False
if self.bot_token_id in delta_token_ids:
if self.bot_token_id in delta_token_ids or self.bot_token in delta_text:
# this is the first tool call
if not delta_text.startswith(self.bot_token):
content = delta_text.split(self.bot_token)[0]
Expand Down
Loading