From 87b63520aba48bae9d55663edf0d8f65d3d43ac2 Mon Sep 17 00:00:00 2001 From: sydnash Date: Thu, 12 Sep 2024 16:07:05 +0800 Subject: [PATCH 01/19] [add] add tools call for internlm2 --- .../tool_chat_template_internlm2_tool.jinja | 60 +++++++++ tests/tool_use/test_parallel_tool_calls.py | 12 +- tests/tool_use/test_tool_calls.py | 6 +- tests/tool_use/utils.py | 23 +++- vllm/entrypoints/openai/api_server.py | 4 + vllm/entrypoints/openai/cli_args.py | 10 +- vllm/entrypoints/openai/serving_chat.py | 32 ++--- .../openai/tool_parsers/__init__.py | 8 +- .../tool_parsers/abstract_tool_parser.py | 118 +++++++++++++++++- .../openai/tool_parsers/hermes_tool_parser.py | 14 ++- .../tool_parsers/internlm2_tool_parser.py | 118 ++++++++++++++++++ .../tool_parsers/mistral_tool_parser.py | 14 ++- 12 files changed, 382 insertions(+), 37 deletions(-) create mode 100644 examples/tool_chat_template_internlm2_tool.jinja create mode 100644 vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py diff --git a/examples/tool_chat_template_internlm2_tool.jinja b/examples/tool_chat_template_internlm2_tool.jinja new file mode 100644 index 000000000000..744cd7027f42 --- /dev/null +++ b/examples/tool_chat_template_internlm2_tool.jinja @@ -0,0 +1,60 @@ +{%- if messages[0]["role"] == "system" %} + {%- set system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} + +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} + +{{- bos_token }} +{%- if system_message is defined %} +{{- "<|im_start|>system\n" + system_message + "<|im_end|>\n" }} +{%- endif %} + +{%- if tools is not none %} + {{- "<|im_start|>system name=<|plugin|>\n[" }} + {%- for tool in tools %} + {{- tool.function|tojson }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" }} + {%- endif %} + {%- endfor %} + {{- "<|im_end|>\n" }} +{%- endif %} + +{%- for message in loop_messages %} + {%- if message["role"] == "user" %} + {{- "<|im_start|>user\n" + message["content"] + "<|im_end|>\n"}} + {%- elif message.tool_calls is defined and message.tool_calls is not none %} + {%- set content = message["content"] if message["content"] else "" %} + {{- "<|im_start|>assistant\n" + content }} + {%- for tool_call in message.tool_calls %} + {%- set function=tool_call.function %} + {{- "<|action_start|><|plugin|>\n" }} + {{- '{"name": "' + function.name + '", '}} + {{- '"arguments": ' + function.arguments|string + '}' }} + {{- "<|action_end|>" }} + {%- endfor %} + {{- "<|im_end|>\n" }} + {%- elif message["role"] == "assistant" %} + {{- "<|im_start|>assistant\n" + message["content"] + "<|im_end|>\n"}} + {%- elif message["role"] == "tool_results" or message["role"] == "tool" or message["role"] == "function" %} + {%- if message.content is defined and message.content.content is defined %} + {%- set content = message.content.content %} + {%- else %} + {%- set content = message.content %} + {%- endif %} + {{- "<|im_start|>environment name=<|plugin|>\n" + content|string + "<|im_end|>\n" }} + {%- else %} + {{- raise_exception("Only user and assistant and tool_results and tool and function roles are supported, with the exception of an initial optional system message!") }} + {%- endif %} +{%- endfor %} + +{%- if add_generation_prompt %} +{{- '<|im_start|>assistant\n' }} +{%- endif %} \ No newline at end of file diff --git a/tests/tool_use/test_parallel_tool_calls.py b/tests/tool_use/test_parallel_tool_calls.py index b03b5a2075a6..d334bfcaecea 100644 --- 
a/tests/tool_use/test_parallel_tool_calls.py +++ b/tests/tool_use/test_parallel_tool_calls.py @@ -6,7 +6,7 @@ from .utils import (MESSAGES_ASKING_FOR_PARALLEL_TOOLS, MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, SEARCH_TOOL, - WEATHER_TOOL) + WEATHER_TOOL, ServerConfig) # test: getting the model to generate parallel tool calls (streaming/not) @@ -14,9 +14,12 @@ # may be added in the future. e.g. llama 3.1 models are not designed to support # parallel tool calls. @pytest.mark.asyncio -async def test_parallel_tool_calls(client: openai.AsyncOpenAI): +async def test_parallel_tool_calls(client: openai.AsyncOpenAI, + server_config: ServerConfig): models = await client.models.list() model_name: str = models.data[0].id + if server_config.get("skip_parallel", False): + pytest.skip(f"skip parallel test for {model_name}") chat_completion = await client.chat.completions.create( messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS, temperature=0, @@ -136,9 +139,12 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI): # test: providing parallel tool calls back to the model to get a response # (streaming/not) @pytest.mark.asyncio -async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI): +async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI, + server_config: ServerConfig): models = await client.models.list() model_name: str = models.data[0].id + if server_config.get("skip_parallel", False): + pytest.skip(f"skip parallel test for {model_name}") chat_completion = await client.chat.completions.create( messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, temperature=0, diff --git a/tests/tool_use/test_tool_calls.py b/tests/tool_use/test_tool_calls.py index c3abe9e1f506..89303b12e1da 100644 --- a/tests/tool_use/test_tool_calls.py +++ b/tests/tool_use/test_tool_calls.py @@ -46,7 +46,8 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI): assert isinstance(parsed_arguments.get("city"), str) assert isinstance(parsed_arguments.get("state"), str) assert parsed_arguments.get("city") == "Dallas" - assert parsed_arguments.get("state") == "TX" + assert parsed_arguments.get("state") == "TX" or parsed_arguments.get( + "state") == "Texas" assert stop_reason == "tool_calls" @@ -119,7 +120,8 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI): assert isinstance(streamed_args.get("city"), str) assert isinstance(streamed_args.get("state"), str) assert streamed_args.get("city") == "Dallas" - assert streamed_args.get("state") == "TX" + assert parsed_arguments.get("state") == "TX" or parsed_arguments.get( + "state") == "Texas" # make sure everything matches non-streaming except for ID assert function_name == tool_calls[0].function.name diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py index e447469e3341..dea7775bef50 100644 --- a/tests/tool_use/utils.py +++ b/tests/tool_use/utils.py @@ -1,4 +1,4 @@ -from typing import Dict, List +from typing import Dict, List, Optional from openai.types.chat import (ChatCompletionMessageParam, ChatCompletionToolParam) @@ -10,6 +10,7 @@ class ServerConfig(TypedDict): model: str arguments: List[str] + skip_parallel: Optional[bool] # universal args for all models go here. 
also good if you need to test locally @@ -23,7 +24,9 @@ class ServerConfig(TypedDict): "arguments": [ "--tool-call-parser", "hermes", "--chat-template", str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja") - ] + ], + "skip_parallel": + False }, "mistral": { "model": @@ -32,7 +35,21 @@ class ServerConfig(TypedDict): "--tool-call-parser", "mistral", "--chat-template", str(VLLM_PATH / "examples/tool_chat_template_mistral.jinja"), "--ignore-patterns=\"consolidated.safetensors\"" - ] + ], + "skip_parallel": + False + }, + "internlm2_5": { + "model": + "internlm/internlm2_5-7b-chat", + "arguments": [ + "--tool-call-parser", "internlm2_5", "--chat-template", + str(VLLM_PATH / + "examples/tool_chat_template_internlm2_tool.jinja"), + "--trust_remote_code" + ], + "skip_parallel": + True } } diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index d8704d5e2496..bb5112dda645 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -47,6 +47,7 @@ from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) +from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path @@ -446,6 +447,9 @@ async def init_app( else: request_logger = RequestLogger(max_log_len=args.max_log_len) + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + global openai_serving_chat global openai_serving_completion global openai_serving_embedding diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 7ccee0b6b55b..b37d40359907 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -174,13 +174,21 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument( "--tool-call-parser", type=str, - choices=["mistral", "hermes"], default=None, help= "Select the tool call parser depending on the model that you're using." " This is used to parse the model-generated tool call into OpenAI API " "format. 
Required for --enable-auto-tool-choice.") + parser.add_argument( + "--tool-parser-plugin", + type=str, + default="", + help= + "Specify the tool parser plugin used to parse the model-generated tool" + " calls into OpenAI API format; the names registered in this plugin can " + "be used in --tool-call-parser.") + parser = AsyncEngineArgs.add_cli_args(parser) parser.add_argument('--max-log-len', diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 8ed81e9c88cb..65bd2e7cc692 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -26,9 +26,7 @@ OpenAIServing, PromptAdapterPath, TextTokensPrompt) -from vllm.entrypoints.openai.tool_parsers import (Hermes2ProToolParser, - MistralToolParser, - ToolParser) +from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager from vllm.inputs import TokensPrompt from vllm.logger import init_logger from vllm.outputs import CompletionOutput, RequestOutput @@ -78,13 +76,13 @@ def __init__(self, self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None if self.enable_auto_tools: - if tool_parser == "mistral": - self.tool_parser = MistralToolParser - elif tool_parser == "hermes": - self.tool_parser = Hermes2ProToolParser - else: + try: + self.tool_parser = ToolParserManager.get_tool_parser( + tool_parser) + except Exception as e: raise TypeError("Error: --enable-auto-tool-choice requires " - "--tool-call-parser") + f"tool_parser:'{tool_parser}' which has not " + "been registered") from e async def create_chat_completion( self, @@ -156,6 +154,9 @@ async def create_chat_completion( request_id = f"chat-{random_uuid()}" try: + if self.enable_auto_tools and self.tool_parser: + request = self.tool_parser(tokenizer).adjust_request( + request=request) guided_decode_logits_processor = ( await self._guided_decode_logits_processor(request, tokenizer)) @@ -255,8 +256,9 @@ async def chat_completion_stream_generator( previous_num_tokens = [0] * num_choices finish_reason_sent = [False] * num_choices - tool_parser: Optional[ToolParser] = self.tool_parser( - tokenizer) if self.tool_parser else None + tool_parsers: List[Optional[ToolParser]] = [ + self.tool_parser(tokenizer) if self.tool_parser else None + ] * num_choices try: async for res in result_generator: @@ -271,7 +273,7 @@ async def chat_completion_stream_generator( # NOTE num_choices defaults to 1 so this usually executes # once per request for i in range(num_choices): - + tool_parser = tool_parsers[i] choice_data = ChatCompletionResponseStreamChoice( index=i, delta=DeltaMessage( @@ -396,7 +398,8 @@ async def chat_completion_stream_generator( :-1 * len(delta_token_ids) ], current_token_ids=output.token_ids, - delta_token_ids=delta_token_ids + delta_token_ids=delta_token_ids, + request=request, ) ) @@ -628,7 +631,8 @@ async def chat_completion_full_generator( and self.tool_parser: tool_parser = self.tool_parser(tokenizer) - tool_call_info = tool_parser.extract_tool_calls(output.text) + tool_call_info = tool_parser.extract_tool_calls( + output.text, request=request) tools_called = tool_call_info.tools_called if tool_call_info.tools_called: message = ChatMessage(role=role, diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index 5d5d53784fed..cc86b45a2c4a 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -1,5 +1,9 @@ -from .abstract_tool_parser import ToolParser +from .abstract_tool_parser import ToolParser, ToolParserManager from .hermes_tool_parser import Hermes2ProToolParser +from .internlm2_tool_parser import Internlm2ToolParser from .mistral_tool_parser import MistralToolParser -__all__ = ["ToolParser", "Hermes2ProToolParser", "MistralToolParser"] \ No newline at end of file +__all__ = [ + "ToolParser", "ToolParserManager", "Hermes2ProToolParser", + "MistralToolParser", "Internlm2ToolParser" +] diff --git a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py index 873f615d4325..60dddcdaa20f 100644 --- a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py @@ -1,6 +1,11 @@ -from typing import Dict, List, Sequence, Union +import importlib +import importlib.util +import os +from collections import abc +from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union -from vllm.entrypoints.openai.protocol import (DeltaMessage, +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage, ExtractedToolCallInformation) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -24,8 +29,16 @@ def __init__(self, tokenizer: AnyTokenizer): self.model_tokenizer = tokenizer - def extract_tool_calls(self, - model_output: str) -> ExtractedToolCallInformation: + def adjust_request( + self, request: ChatCompletionRequest) -> ChatCompletionRequest: + """ + Instance method that can be overridden to adjust the request parameters. + """ + return request + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: """ Static method that should be implemented for extracting tool calls from a complete model-generated string. @@ -44,6 +57,7 @@ def extract_tool_calls_streaming( previous_token_ids: Sequence[int], current_token_ids: Sequence[int], delta_token_ids: Sequence[int], + request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: """ Instance method that should be implemented for extracting tool calls @@ -55,3 +69,99 @@ def extract_tool_calls_streaming( raise NotImplementedError( "AbstractToolParser.extract_tool_calls_streaming has not been " "implemented!") + + +def is_seq_of(seq: Any, + expected_type: Union[Type, tuple], + seq_type: Union[Type, None] = None) -> bool: + if seq_type is None: + exp_seq_type = abc.Sequence + else: + assert isinstance(seq_type, type) + exp_seq_type = seq_type + if not isinstance(seq, exp_seq_type): + return False + + return all(isinstance(item, expected_type) for item in seq) + + +class ToolParserManager: + tool_parsers: Dict[str, Type] = {} + + @classmethod + def get_tool_parser(cls, name) -> Type: + """ + Get tool parser by name which is registered by `register_module`. + + Raise a KeyError exception if the name is not registered. + """ + if name in cls.tool_parsers: + return cls.tool_parsers[name] + + raise KeyError(f"tool parser: '{name}' not found in tool_parsers") + + @classmethod + def _register_module(cls, + module: Type, + module_name: Optional[Union[str, List[str]]] = None, + force: bool = True) -> None: + if not issubclass(module, ToolParser): + raise TypeError( + f'module must be a subclass of ToolParser, but got {type(module)}' + ) + if module_name is None: + module_name = module.__name__ + if isinstance(module_name, str): + module_name = [module_name] + for name in module_name: + if not force and name in cls.tool_parsers: + existed_module = cls.tool_parsers[name] + raise KeyError(f'{name} is already registered ' + f'at {existed_module.__module__}') + cls.tool_parsers[name] = module + + @classmethod + def register_module( + cls, + name: Optional[Union[str, List[str]]] = None, + force: bool = True, + module: Union[Type, None] = None) -> Union[type, Callable]: + """ + Register module with the given name or name list. It can be used as a + decorator (with module as None) or as a normal function (with module + not None). + """ + if not isinstance(force, bool): + raise TypeError(f'force must be a boolean, but got {type(force)}') + + # raise the error ahead of time + if not (name is None or isinstance(name, str) or is_seq_of(name, str)): + raise TypeError( + 'name must be None, an instance of str, or a sequence of str, ' + f'but got {type(name)}') + + # use it as a normal method: x.register_module(module=SomeClass) + if module is not None: + cls._register_module(module=module, module_name=name, force=force) + return module + + # use it as a decorator: @x.register_module() + def _register(module): + cls._register_module(module=module, module_name=name, force=force) + return module + + return _register + + @classmethod + def import_tool_parser(cls, plugin_path: str) -> None: + """ + Import a user defined tool parser from the path of the file in which + the tool parser is defined. 
+ """ + module_name = os.path.splitext(os.path.basename(plugin_path))[0] + spec = importlib.util.spec_from_file_location(module_name, plugin_path) + if spec is None or spec.loader is None: + logger.error("load %s from %s failed.", module_name, plugin_path) + return + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index bde9b47ce60d..2e50b5ab7169 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -5,12 +5,13 @@ import partial_json_parser from partial_json_parser.core.options import Allow -from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage, +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, DeltaToolCall, ExtractedToolCallInformation, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser) + ToolParser, ToolParserManager) from vllm.entrypoints.openai.tool_parsers.utils import ( extract_intermediate_diff) from vllm.logger import init_logger @@ -20,6 +21,7 @@ logger = init_logger(__name__) +@ToolParserManager.register_module("hermes") class Hermes2ProToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): @@ -58,8 +60,11 @@ def __init__(self, tokenizer: AnyTokenizer): "Hermes 2 Pro Tool parser could not locate tool call start/end " "tokens in the tokenizer!") - def extract_tool_calls(self, - model_output: str) -> ExtractedToolCallInformation: + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: # sanity check; avoid unnecessary processing if self.tool_call_start_token not in model_output: @@ -115,6 +120,7 @@ def extract_tool_calls_streaming( previous_token_ids: Sequence[int], current_token_ids: Sequence[int], delta_token_ids: Sequence[int], + request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: logger.debug("delta_text: %s", delta_text) diff --git a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py new file mode 100644 index 000000000000..11379d4ffb93 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py @@ -0,0 +1,118 @@ +import json +from typing import Sequence, Union + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +@ToolParserManager.register_module(["internlm2", "internlm2_5"]) +class Internlm2ToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + self.position = 0 + self.current_tool_id = 0 + + def adjust_request( + self, request: ChatCompletionRequest) -> ChatCompletionRequest: + if request.tools and request.tool_choice != 'none': + request.skip_special_tokens = False + return request + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: 
Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + if '<|action_start|>' not in current_text: + self.position = len(current_text) + return DeltaMessage(content=delta_text) + + if self.current_tool_id > 0: + return DeltaMessage(content='') + + last_pos = self.position + if '<|action_start|><|plugin|>' not in current_text[last_pos:]: + return None + + new_delta = current_text[last_pos:] + text, action = new_delta.split('<|action_start|><|plugin|>') + if '<|action_end|>' not in action: + self.position = last_pos + len(text) + return None if len(text) == 0 else DeltaMessage(content=text) + + action = action.split('<|action_end|>'.strip())[0] + + action = action[action.find('{'):] + action_dict = json.loads(action) + name, parameters = action_dict['name'], json.dumps( + action_dict.get('parameters', action_dict.get('arguments', {}))) + + last_pos = current_text[last_pos:].find("<|action_end|>") + len( + '<|action_end|>') + self.position = last_pos + if not request.tools or name not in [ + t.function.name for t in request.tools + ]: + return None if len(text) == 0 else DeltaMessage(content=text) + + delta = DeltaMessage(content=text, + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + id=f"chatcmpl-tool-{random_uuid()}", + function=DeltaFunctionCall( + name=name, arguments=parameters)), + ]) + self.current_tool_id = self.current_tool_id + 1 + self.prev_tool_call_arr = [action_dict] + return delta + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + text = model_output + tools = request.tools + if '<|action_start|><|plugin|>' in text: + text, action = text.split('<|action_start|><|plugin|>') + action = action.split('<|action_end|>'.strip())[0] + action = action[action.find('{'):] + action_dict = json.loads(action) + name, parameters = action_dict['name'], json.dumps( + action_dict.get('parameters', action_dict.get('arguments', + {}))) + + if not tools or name not in [t.function.name for t in tools]: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=text) + + tool_calls = [ + ToolCall( + function=FunctionCall(name=name, arguments=parameters)) + ] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=text if len(text) > 0 else None) + + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=text) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index 4b0e1c91df97..c577c26ffeee 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -5,12 +5,13 @@ import partial_json_parser from partial_json_parser.core.options import Allow -from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage, +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, DeltaToolCall, ExtractedToolCallInformation, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser) + ToolParser, ToolParserManager) from vllm.entrypoints.openai.tool_parsers.utils import ( extract_intermediate_diff) from vllm.logger import init_logger @@ -20,6 +21,7 @@ logger = init_logger(__name__) +@ToolParserManager.register_module("mistral") class MistralToolParser(ToolParser): """ Tool call parser for Mistral 7B Instruct 
v0.3, intended for use with the @@ -48,8 +50,11 @@ def __init__(self, tokenizer: AnyTokenizer): self.bot_token_id = self.model_tokenizer.vocab[self.bot_token] self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL) - def extract_tool_calls(self, - model_output: str) -> ExtractedToolCallInformation: + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: """ Extract the tool calls from a complete model response. Requires find-and-replacing single quotes with double quotes for JSON parsing, @@ -103,6 +108,7 @@ def extract_tool_calls_streaming( previous_token_ids: Sequence[int], current_token_ids: Sequence[int], delta_token_ids: Sequence[int], + request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: # if the tool call token is not in the tokens generated so far, append From 68cd89d375fe5a3c6d64dd9f71b21a9ee8ba6931 Mon Sep 17 00:00:00 2001 From: sydnash Date: Thu, 12 Sep 2024 16:44:45 +0800 Subject: [PATCH 02/19] [add] add some comments --- .../entrypoints/openai/tool_parsers/internlm2_tool_parser.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py index 11379d4ffb93..adf0c9f0ea20 100644 --- a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py @@ -26,6 +26,8 @@ def __init__(self, tokenizer: AnyTokenizer): def adjust_request( self, request: ChatCompletionRequest) -> ChatCompletionRequest: if request.tools and request.tool_choice != 'none': + # do not skip special tokens because internlm use the special + # tokens to indicated the start and end of the tool calls infomation. request.skip_special_tokens = False return request @@ -42,7 +44,8 @@ def extract_tool_calls_streaming( if '<|action_start|>' not in current_text: self.position = len(current_text) return DeltaMessage(content=delta_text) - + # if the tool call is sended, return a emptry delta message + # to make sure the finish_reason will be send correctly. if self.current_tool_id > 0: return DeltaMessage(content='') From d17f006d3b6c5074baebb5d537460e3cea6c2405 Mon Sep 17 00:00:00 2001 From: sydnash Date: Thu, 12 Sep 2024 16:59:26 +0800 Subject: [PATCH 03/19] [add] add some comments --- .../entrypoints/openai/tool_parsers/internlm2_tool_parser.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py index adf0c9f0ea20..8b685ab00146 100644 --- a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py @@ -27,7 +27,8 @@ def adjust_request( self, request: ChatCompletionRequest) -> ChatCompletionRequest: if request.tools and request.tool_choice != 'none': # do not skip special tokens because internlm use the special - # tokens to indicated the start and end of the tool calls infomation. + # tokens to indicated the start and end of the tool calls + # information. 
request.skip_special_tokens = False return request @@ -44,7 +45,7 @@ def extract_tool_calls_streaming( if '<|action_start|>' not in current_text: self.position = len(current_text) return DeltaMessage(content=delta_text) - # if the tool call is sended, return a emptry delta message + # if the tool call is sended, return a empty delta message # to make sure the finish_reason will be send correctly. if self.current_tool_id > 0: return DeltaMessage(content='') From 2d7d9d45dd6d9ab93a7c83e4b6106396ea077877 Mon Sep 17 00:00:00 2001 From: sydnash Date: Fri, 13 Sep 2024 10:23:56 +0800 Subject: [PATCH 04/19] [fix] fix internlm2 tool chat template, fix the internlm2 'Texas' tool call, fix the lack of type='function' in the delta tool call message --- examples/tool_chat_template_internlm2_tool.jinja | 2 +- tests/tool_use/test_tool_calls.py | 6 ++---- tests/tool_use/utils.py | 2 +- .../openai/tool_parsers/internlm2_tool_parser.py | 1 + 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/examples/tool_chat_template_internlm2_tool.jinja b/examples/tool_chat_template_internlm2_tool.jinja index 744cd7027f42..ac99666e93bc 100644 --- a/examples/tool_chat_template_internlm2_tool.jinja +++ b/examples/tool_chat_template_internlm2_tool.jinja @@ -37,7 +37,7 @@ {%- set function=tool_call.function %} {{- "<|action_start|><|plugin|>\n" }} {{- '{"name": "' + function.name + '", '}} - {{- '"arguments": ' + function.arguments|string + '}' }} + {{- '"arguments": ' + function.arguments|tojson + '}' }} {{- "<|action_end|>" }} {%- endfor %} {{- "<|im_end|>\n" }} diff --git a/tests/tool_use/test_tool_calls.py b/tests/tool_use/test_tool_calls.py index 89303b12e1da..c2c73e8bca98 100644 --- a/tests/tool_use/test_tool_calls.py +++ b/tests/tool_use/test_tool_calls.py @@ -46,8 +46,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI): assert isinstance(parsed_arguments.get("city"), str) assert isinstance(parsed_arguments.get("state"), str) assert parsed_arguments.get("city") == "Dallas" - assert parsed_arguments.get("state") == "TX" or parsed_arguments.get( - "state") == "Texas" + assert parsed_arguments.get("state") == "TX" assert stop_reason == "tool_calls" @@ -119,8 +118,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI): assert isinstance(streamed_args.get("city"), str) assert isinstance(streamed_args.get("state"), str) assert streamed_args.get("city") == "Dallas" - assert parsed_arguments.get("state") == "TX" or parsed_arguments.get( - "state") == "Texas" + assert parsed_arguments.get("state") == "TX" # make sure everything matches non-streaming except for ID assert function_name == tool_calls[0].function.name diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py index dea7775bef50..44baf6ba99d3 100644 --- a/tests/tool_use/utils.py +++ b/tests/tool_use/utils.py @@ -72,7 +72,7 @@ class ServerConfig(TypedDict): "type": "string", "description": - "the two-letter abbreviation for the state " + "must be the two-letter abbreviation for the state " "that the city is in, e.g. 'CA' which would " "mean 'California'" }, diff --git a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py index 8b685ab00146..6e5cc5e63b3f 100644 --- a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py @@ -80,6 +80,7 @@ def extract_tool_calls_streaming( DeltaToolCall( index=self.current_tool_id, id=f"chatcmpl-tool-{random_uuid()}", + type="function", function=DeltaFunctionCall( name=name, arguments=parameters)), ]) From 12352e7d001f6d282fb81fb3d29fc64635b391e7 Mon Sep 17 00:00:00 2001 From: sydnash Date: Fri, 13 Sep 2024 11:06:44 +0800 Subject: [PATCH 05/19] [add] add tool parser plugin doc --- .../serving/openai_compatible_server.md | 73 ++++++++++++++++++- 1 file changed, 71 insertions(+), 2 deletions(-) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index eb4ea0fb5655..d003813869bc 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -157,8 +157,9 @@ vLLM will use guided decoding to ensure the response matches the tool parameter To enable this feature, you should set the following flags: * `--enable-auto-tool-choice` -- **mandatory** Auto tool choice. tells vLLM that you want to enable the model to generate its own tool calls when it deems appropriate. -* `--tool-call-parser` -- select the tool parser to use - currently either `hermes` or `mistral`. Additional tool parsers will continue to be added in the future. +* `--tool-call-parser` -- select the tool parser to use - currently either `hermes` or `mistral` or `internlm2_5` or `internlm2`. Additional tool parsers +will continue to be added in the future, and you can also register your own tool parsers with `--tool-parser-plugin`. +* `--tool-parser-plugin` -- **optional** a tool parser plugin used to register user-defined tool parsers into vLLM; the names registered in the plugin can be specified in `--tool-call-parser`. * `--chat-template` -- **optional** for auto tool choice. the path to the chat template which handles `tool`-role messages and `assistant`-role messages that contain previously generated tool calls. Hermes and Mistral models have tool-compatible chat templates in their `tokenizer_config.json` files, but you can specify a custom template. This argument can be set to `tool_use` if your model has a tool use-specific chat @@ -197,3 +198,71 @@ when tools are provided, that results in much better reliability when working wi Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja` + +#### Internlm Models +Supported models: +* `internlm/internlm2_5-7b-chat` (confirmed) +* Additional internlm2.5 function-calling models are compatible as well + +Known issues: +* Although this implementation also supports Internlm2, the tool call results are not ideal when testing with the `internlm/internlm2-chat-7b` model. + + +### How to write a tool parser plugin + +A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py. + +Here is a summary of a plugin file: + +```python + +# import the required packages + +# define a tool parser and register it to vllm +# the name list in register_module can be used +# in --tool-call-parser. 
you can define as many +# tool parsers as you want here. +@ToolParserManager.register_module(["example"]) +class ExampleToolParser(ToolParser): + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + # adjust request. e.g.: set skip special tokens + # to False for tool call output. + def adjust_request( + self, request: ChatCompletionRequest) -> ChatCompletionRequest: + return request + + # implement the tool call parse for stream call + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + return delta + + # implement the tool parse for non-stream call + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=text) + + +``` + +and then you can the flags to specify the plugins: +``` +--enable-auto-tool-choice \ +--tool-parser-plugin +--tool-call-parser example \ +--chat-template \ +``` \ No newline at end of file From 11bed0d8499bb3cc8e6a45a2be748e57ec46bacb Mon Sep 17 00:00:00 2001 From: sydnash Date: Fri, 13 Sep 2024 11:10:16 +0800 Subject: [PATCH 06/19] [add] add tool parser plugin doc --- docs/source/serving/openai_compatible_server.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index d003813869bc..9573525df0d9 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -259,10 +259,10 @@ class ExampleToolParser(ToolParser): ``` -and then you can the flags to specify the plugins: +Then you can use this plugin in the command line like this. 
``` ---enable-auto-tool-choice \ ---tool-parser-plugin ---tool-call-parser example \ ---chat-template \ + --enable-auto-tool-choice \ + --tool-parser-plugin <absolute path of the plugin file> + --tool-call-parser example \ + --chat-template <your chat template> \ ``` \ No newline at end of file From 8a8b840148696982515be8ab16dde4abf1e5bc8f Mon Sep 17 00:00:00 2001 From: sydnash Date: Fri, 13 Sep 2024 15:11:56 +0800 Subject: [PATCH 07/19] [fix] fix the stream tool call for internlm2 --- .../tool_parsers/internlm2_tool_parser.py | 137 ++++++++++++++---- 1 file changed, 107 insertions(+), 30 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py index 6e5cc5e63b3f..70ff80c59696 100644 --- a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py @@ -1,5 +1,8 @@ import json -from typing import Sequence, Union +from typing import Dict, Sequence, Union + +import partial_json_parser +from partial_json_parser.core.options import Allow from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, @@ -8,6 +11,8 @@ FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, ToolParserManager) +from vllm.entrypoints.openai.tool_parsers.utils import ( + extract_intermediate_diff) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import random_uuid @@ -21,7 +26,6 @@ class Internlm2ToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) self.position = 0 - self.current_tool_id = 0 def adjust_request( self, request: ChatCompletionRequest) -> ChatCompletionRequest: @@ -56,37 +60,110 @@ def extract_tool_calls_streaming( new_delta = current_text[last_pos:] text, action = new_delta.split('<|action_start|><|plugin|>') - if '<|action_end|>' not in action: - self.position = last_pos + len(text) - return None if len(text) == 0 else DeltaMessage(content=text) + if len(text) > 0: + self.position = self.position + len(text) + return DeltaMessage(content=text) + + action = action.strip() action = action.split('<|action_end|>'.strip())[0] - action = action[action.find('{'):] - action_dict = json.loads(action) - name, parameters = action_dict['name'], json.dumps( - action_dict.get('parameters', action_dict.get('arguments', {}))) - - last_pos = current_text[last_pos:].find("<|action_end|>") + len( - '<|action_end|>') - self.position = last_pos - if not request.tools or name not in [ - t.function.name for t in request.tools - ]: - return None if len(text) == 0 else DeltaMessage(content=text) - - delta = DeltaMessage(content=text, - tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - id=f"chatcmpl-tool-{random_uuid()}", - type="function", - function=DeltaFunctionCall( - name=name, arguments=parameters)), - ]) - self.current_tool_id = self.current_tool_id + 1 - self.prev_tool_call_arr = [action_dict] - return delta + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. 
+        flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + + try: + parsable_arr = action + + # tool calls are generated in an object in internlm2 + try: + tool_call_arr: Dict = partial_json_parser.loads( + parsable_arr, flags) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + + # if the current tool name hasn't been sent, send if available + # - otherwise send nothing + if not self.current_tool_name_sent: + function_name = tool_call_arr.get("name") + if function_name: + self.current_tool_id = self.current_tool_id + 1 + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + self.current_tool_name_sent = True + self.streamed_args_for_tool.append("") + else: + delta = None + # now we know we're on the same tool call and we're streaming + # arguments + else: + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id].get("parameters") + cur_arguments = tool_call_arr.get("parameters") + + # no arguments generated + if not cur_arguments and not prev_arguments: + delta = None + # will never happen + elif not cur_arguments and prev_arguments: + logger.error( + "INVARIANT - impossible to have arguments reset " + "mid-arguments") + delta = None + # first time to get parameters + elif cur_arguments and not prev_arguments: + cur_arguments_json = json.dumps(cur_arguments) + + arguments_delta = cur_arguments_json[:cur_arguments_json. + index(delta_text) + + len(delta_text)] + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += arguments_delta + # both prev and cur parameters exist, send the incremental diff + elif cur_arguments and prev_arguments: + cur_args_json = json.dumps(cur_arguments) + prev_args_json = json.dumps(prev_arguments) + + argument_diff = extract_intermediate_diff( + cur_args_json, prev_args_json) + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + + # check to see if the name is defined and has been sent. if so, + # stream the name - otherwise keep waiting + # finish by setting old and returning None as base case + tool_call_arr["arguments"] = tool_call_arr.get("parameters") + self.prev_tool_call_arr = [tool_call_arr] + return delta + except Exception as e: + logger.error("Error trying to handle streaming tool call: %s", e) + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None def extract_tool_calls( self, model_output: str, From 00c5da26a4089ec98a41a1b5145b46df21bcb0fd Mon Sep 17 00:00:00 2001 From: sydnash Date: Fri, 13 Sep 2024 15:17:47 +0800 Subject: [PATCH 08/19] [fix] comment --- vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py index 70ff80c59696..ea6cdbf0a989 100644 --- a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py @@ -79,6 +79,7 @@ def extract_tool_calls_streaming( parsable_arr = action # tool calls are generated in an object in internlm2 + # it does not support parallel tool calls try: tool_call_arr: Dict = partial_json_parser.loads( parsable_arr, flags) From 12b10359a26332c882d2ec6bed77bb1ae5cee832 Mon Sep 17 00:00:00 2001 From: sydnash Date: Sat, 14 Sep 2024 10:56:53 +0800 Subject: [PATCH 09/19] [fix] use metavar to display the help info for --tool-call-parser, add quick check for the validation of --tool-call-parser --- vllm/entrypoints/openai/api_server.py | 12 +++++++++--- vllm/entrypoints/openai/cli_args.py | 3 +++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index bb5112dda645..aac4505623e4 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -447,9 +447,6 @@ async def init_app( else: request_logger = RequestLogger(max_log_len=args.max_log_len) - if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: - ToolParserManager.import_tool_parser(args.tool_parser_plugin) - global openai_serving_chat global openai_serving_completion global openai_serving_embedding @@ -499,6 +496,15 @@ async def run_server(args, **uvicorn_kwargs) -> None: logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("args: %s", args) + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + valid_tool_parsers = ToolParserManager.tool_parsers.keys() + if args.enable_auto_tool_choice \ + and args.tool_call_parser not in valid_tool_parsers: + raise KeyError(f"invalid tool call parser: {args.tool_call_parser} " + f"(choose from {{ {','.join(valid_tool_parsers)} }})") + async with build_async_engine_client(args) as async_engine_client: # If None, creation of the client failed and we exit. if async_engine_client is None: diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index b37d40359907..ac60bca48539 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -171,9 +171,12 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "Enable auto tool choice for supported models. Use --tool-call-parser" "to specify which parser to use") + valid_tool_parsers = ["mistral", "hermes", "internlm2", "internlm2_5"] parser.add_argument( "--tool-call-parser", type=str, + metavar="{" + ",".join(valid_tool_parsers) + "} or name registered in " + "--tool-parser-plugin", default=None, help= "Select the tool call parser depending on the model that you're using." " This is used to parse the model-generated tool call into OpenAI API " "format. Required for --enable-auto-tool-choice.") From ed5b3fd4acfa170af5e8ca38a18c30146f276e62 Mon Sep 17 00:00:00 2001 From: sydnash Date: Sun, 15 Sep 2024 07:19:26 +0800 Subject: [PATCH 10/19] [add] got valid tool parsers from ToolParserManager --- vllm/entrypoints/openai/cli_args.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index ac60bca48539..b25cbc149f63 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -12,6 +12,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, PromptAdapterPath) +from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.utils import FlexibleArgumentParser @@ -171,7 +172,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "Enable auto tool choice for supported models. Use --tool-call-parser" "to specify which parser to use") - valid_tool_parsers = ["mistral", "hermes", "internlm2", "internlm2_5"] + valid_tool_parsers = ToolParserManager.tool_parsers.keys() parser.add_argument( "--tool-call-parser", type=str, From ea2c0898a6881a6bd66af4f2855a336c18e5883a Mon Sep 17 00:00:00 2001 From: sydnash Date: Sun, 15 Sep 2024 07:41:23 +0800 Subject: [PATCH 11/19] [fix] fix build for docs --- docs/requirements-docs.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index 6687929c0beb..80037dda2001 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -12,4 +12,5 @@ torch py-cpuinfo transformers mistral_common >= 1.3.4 -openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args \ No newline at end of file +openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args +partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args \ No newline at end of file From 36ad5d02150677b4e20cc152f2318865551b2ef1 Mon Sep 17 00:00:00 2001 From: sydnash Date: Sun, 15 Sep 2024 08:48:11 +0800 Subject: [PATCH 12/19] [fix] internlm's tool call output may use arguments or parameters --- .../openai/tool_parsers/internlm2_tool_parser.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py index ea6cdbf0a989..44e98c705400 100644 --- a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py @@ -36,6 +36,13 @@ def adjust_request( self, request: ChatCompletionRequest) -> ChatCompletionRequest: request.skip_special_tokens = False return request + def get_arguments(self, obj): + if "parameters" in obj: + return obj.get("parameters") + elif "arguments" in obj: + return obj.get("arguments") + return None + def extract_tool_calls_streaming( self, previous_text: str, @@ -108,9 +115,9 @@ def extract_tool_calls_streaming( # now we know we're on the same tool call and we're streaming # arguments else: - prev_arguments = self.prev_tool_call_arr[ - self.current_tool_id].get("parameters") - cur_arguments = tool_call_arr.get("parameters") + prev_arguments = self.get_arguments( + self.prev_tool_call_arr[self.current_tool_id]) + cur_arguments = self.get_arguments(tool_call_arr) # no arguments generated if not cur_arguments and not prev_arguments: @@ -156,7 +163,7 @@ def extract_tool_calls_streaming( # check to see if the name is defined and has been sent. if so, # stream the name - otherwise keep waiting # finish by setting old and returning None as base case - tool_call_arr["arguments"] = tool_call_arr.get("parameters") + tool_call_arr["arguments"] = self.get_arguments(tool_call_arr) self.prev_tool_call_arr = [tool_call_arr] return delta From 647db0dcf7d3945f6f2691d51c103109705f0ea8 Mon Sep 17 00:00:00 2001 From: sydnash Date: Thu, 26 Sep 2024 09:30:30 +0800 Subject: [PATCH 13/19] rename the tool parser to internlm, fix the streamed_args test case --- docs/source/serving/openai_compatible_server.md | 6 ++++-- tests/tool_use/test_tool_calls.py | 2 +- tests/tool_use/utils.py | 4 ++-- .../openai/tool_parsers/internlm2_tool_parser.py | 2 +- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 9573525df0d9..8d8b87ee8dac 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -157,7 +157,7 @@ vLLM will use guided decoding to ensure the response matches the tool parameter To enable this feature, you should set the following flags: * `--enable-auto-tool-choice` -- **mandatory** Auto tool choice. tells vLLM that you want to enable the model to generate its own tool calls when it deems appropriate. -* `--tool-call-parser` -- select the tool parser to use - currently either `hermes` or `mistral` or `internlm2_5` or `internlm2`. Additional tool parsers +* `--tool-call-parser` -- select the tool parser to use - currently either `hermes` or `mistral` or `internlm`. Additional tool parsers will continue to be added in the future, and you can also register your own tool parsers with `--tool-parser-plugin`. * `--tool-parser-plugin` -- **optional** a tool parser plugin used to register user-defined tool parsers into vLLM; the names registered in the plugin can be specified in `--tool-call-parser`. * `--chat-template` -- **optional** for auto tool choice. the path to the chat template which handles `tool`-role messages and `assistant`-role messages @@ -205,7 +205,9 @@ Supported models: * Additional internlm2.5 function-calling models are compatible as well Known issues: -* Although this implementation also supports Internlm2, the tool call results are not ideal when testing with the `internlm/internlm2-chat-7b` model. +* Although this implementation also supports Internlm2, the tool call results are not stable when testing with the `internlm/internlm2-chat-7b` model. 
+ +Recommended flags: `--tool-call-parser internlm --chat-template examples/tool_chat_template_internlm2_tool.jinja` ### How to write a tool parser plugin diff --git a/tests/tool_use/test_tool_calls.py b/tests/tool_use/test_tool_calls.py index c2c73e8bca98..c3abe9e1f506 100644 --- a/tests/tool_use/test_tool_calls.py +++ b/tests/tool_use/test_tool_calls.py @@ -119,7 +119,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI): assert isinstance(streamed_args.get("city"), str) assert isinstance(streamed_args.get("state"), str) assert streamed_args.get("city") == "Dallas" - assert parsed_arguments.get("state") == "TX" + assert streamed_args.get("state") == "TX" # make sure everything matches non-streaming except for ID assert function_name == tool_calls[0].function.name diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py index 44baf6ba99d3..fac8a94e9760 100644 --- a/tests/tool_use/utils.py +++ b/tests/tool_use/utils.py @@ -39,11 +39,11 @@ class ServerConfig(TypedDict): "skip_parallel": False }, - "internlm2_5": { + "internlm": { "model": "internlm/internlm2_5-7b-chat", "arguments": [ - "--tool-call-parser", "internlm2_5", "--chat-template", + "--tool-call-parser", "internlm", "--chat-template", str(VLLM_PATH / "examples/tool_chat_template_internlm2_tool.jinja"), "--trust_remote_code" diff --git a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py index 44e98c705400..905ab7db3d04 100644 --- a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py @@ -20,7 +20,7 @@ logger = init_logger(__name__) -@ToolParserManager.register_module(["internlm2", "internlm2_5"]) +@ToolParserManager.register_module(["internlm"]) class Internlm2ToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): From 106909c737b1ebeddb6e232e7b5d63ddec33012c Mon Sep 17 00:00:00 2001 From: sydnash Date: Sat, 28 Sep 2024 09:42:35 +0800 Subject: [PATCH 14/19] [fix] fix internlm parallel test, remove vllm/version.py --- tests/tool_use/utils.py | 2 +- .../tool_parsers/abstract_tool_parser.py | 18 ++---------------- vllm/version.py | 11 ----------- 3 files changed, 3 insertions(+), 28 deletions(-) delete mode 100644 vllm/version.py diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py index 26cc031fc0e9..44076197b8c2 100644 --- a/tests/tool_use/utils.py +++ b/tests/tool_use/utils.py @@ -92,7 +92,7 @@ def ensure_system_prompt(messages: List[Dict[str, Any]], "--trust_remote_code" ], "supports_parallel": - True, + False, "system_prompt": "You are a helpful assistant with access to tools. 
If a tool" " that you have would be helpful to answer a user query, " diff --git a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py index 60dddcdaa20f..b7c8b560fce2 100644 --- a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py @@ -8,6 +8,7 @@ DeltaMessage, ExtractedToolCallInformation) from vllm.logger import init_logger +from vllm.utils import is_list_of from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) @@ -70,21 +71,6 @@ def extract_tool_calls_streaming( "AbstractToolParser.extract_tool_calls_streaming has not been " "implemented!") - -def is_seq_of(seq: Any, - expected_type: Union[Type, tuple], - seq_type: Union[Type, None] = None) -> bool: - if seq_type is None: - exp_seq_type = abc.Sequence - else: - assert isinstance(seq_type, type) - exp_seq_type = seq_type - if not isinstance(seq, exp_seq_type): - return False - - return all(isinstance(item, expected_type) for item in seq) - - class ToolParserManager: tool_parsers: Dict[str, Type] = {} @@ -135,7 +121,7 @@ def register_module( raise TypeError(f'force must be a boolean, but got {type(force)}') # raise the error ahead of time - if not (name is None or isinstance(name, str) or is_seq_of(name, str)): + if not (name is None or isinstance(name, str) or is_list_of(name, str)): raise TypeError( 'name must be None, an instance of str, or a sequence of str, ' f'but got {type(name)}') diff --git a/vllm/version.py b/vllm/version.py deleted file mode 100644 index 66e189dcedf7..000000000000 --- a/vllm/version.py +++ /dev/null @@ -1,11 +0,0 @@ -try: - from ._version import __version__, __version_tuple__ -except Exception as e: - import warnings - - warnings.warn(f"Failed to read commit hash:\n{e}", - RuntimeWarning, - stacklevel=2) - - __version__ = "dev" - __version_tuple__ = (0, 0, __version__) From e24250169d34b0d8e5dbc1e8637f0d573fc609e9 Mon Sep 17 00:00:00 2001 From: sydnash Date: Sun, 29 Sep 2024 08:03:38 +0800 Subject: [PATCH 15/19] [format] --- .../entrypoints/openai/tool_parsers/abstract_tool_parser.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py index b7c8b560fce2..594024185d98 100644 --- a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py @@ -2,7 +2,7 @@ import importlib.util import os from collections import abc -from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union +from typing import Callable, Dict, List, Optional, Sequence, Type, Union from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaMessage, @@ -71,6 +71,7 @@ def extract_tool_calls_streaming( "AbstractToolParser.extract_tool_calls_streaming has not been " "implemented!") + class ToolParserManager: tool_parsers: Dict[str, Type] = {} @@ -121,7 +122,8 @@ def register_module( raise TypeError(f'force must be a boolean, but got {type(force)}') # raise the error ahead of time - if not (name is None or isinstance(name, str) or is_list_of(name, str)): + if not (name is None or isinstance(name, str) + or is_list_of(name, str)): raise TypeError( 'name must be None, an instance of str, or a sequence of str, ' f'but got {type(name)}') From 0a5ddf473d01677d06bda66c504260cf96dd1aba Mon Sep 17 00:00:00 2001 From: sydnash Date: Sun, 29 
From 0a5ddf473d01677d06bda66c504260cf96dd1aba Mon Sep 17 00:00:00 2001
From: sydnash
Date: Sun, 29 Sep 2024 08:11:02 +0800
Subject: [PATCH 16/19] [format]

---
 vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py
index 594024185d98..7e55532bc729 100644
--- a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py
@@ -1,15 +1,14 @@
 import importlib
 import importlib.util
 import os
-from collections import abc
 from typing import Callable, Dict, List, Optional, Sequence, Type, Union
 
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaMessage,
                                               ExtractedToolCallInformation)
 from vllm.logger import init_logger
-from vllm.utils import is_list_of
 from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import is_list_of
 
 logger = init_logger(__name__)
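
Both formatting patches touch `ToolParserManager`, the registry that maps `--tool-call-parser` names to parser classes. A condensed sketch of that decorator-based registration pattern is shown below; the `_register` helper and the exact signatures are illustrative assumptions, not the real vLLM implementation:

from typing import Callable, Dict, List, Type, Union


class ToolParserManager:
    tool_parsers: Dict[str, Type] = {}

    @classmethod
    def _register(cls, names: List[str], module: Type, force: bool) -> None:
        for name in names:
            if not force and name in cls.tool_parsers:
                raise KeyError(f"{name} is already registered")
            cls.tool_parsers[name] = module

    @classmethod
    def register_module(
            cls,
            name: Union[str, List[str], None] = None,
            force: bool = False) -> Callable[[Type], Type]:
        def decorator(module: Type) -> Type:
            # Fall back to the class name when no name is given.
            names = ([module.__name__] if name is None else
                     [name] if isinstance(name, str) else list(name))
            cls._register(names, module, force)
            return module

        return decorator


@ToolParserManager.register_module(name="internlm")
class FakeInternLM2ToolParser:  # stand-in for the real parser class
    pass


assert ToolParserManager.tool_parsers["internlm"] is FakeInternLM2ToolParser

Registering through a decorator keeps the lookup table populated as a side effect of importing each parser module, which is why the name check above has to run eagerly.
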
From 1db530d908bcd24664e1ba3baa2fe77f45c4ffe1 Mon Sep 17 00:00:00 2001
From: sydnash
Date: Sun, 29 Sep 2024 10:13:08 +0800
Subject: [PATCH 17/19] [fix] fix the Mistral tool call error; restore
 vllm/version.py and delete vllm/commit_id.py

---
 tests/tool_use/utils.py                            |  3 ++-
 vllm/commit_id.py                                  |  1 -
 .../openai/tool_parsers/mistral_tool_parser.py     |  4 +---
 vllm/transformers_utils/tokenizers/mistral.py      |  4 ++++
 vllm/version.py                                    | 11 +++++++++++
 5 files changed, 18 insertions(+), 5 deletions(-)
 delete mode 100644 vllm/commit_id.py
 create mode 100644 vllm/version.py

diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py
index 44076197b8c2..4cb44420c55c 100644
--- a/tests/tool_use/utils.py
+++ b/tests/tool_use/utils.py
@@ -79,7 +79,8 @@ def ensure_system_prompt(messages: List[Dict[str, Any]],
         "arguments": [
             "--tool-call-parser", "mistral", "--chat-template",
             str(VLLM_PATH / "examples/tool_chat_template_mistral.jinja"),
-            "--ignore-patterns=\"consolidated.safetensors\""
+            "--ignore-patterns=\"consolidated.safetensors\"",
+            "--tokenizer-mode", "mistral"
         ],
     },
     "internlm": {

diff --git a/vllm/commit_id.py b/vllm/commit_id.py
deleted file mode 100644
index 22ff10fffaf6..000000000000
--- a/vllm/commit_id.py
+++ /dev/null
@@ -1 +0,0 @@
-__commit__ = "6cd5e5b07e4415d064d93b8a66331a097bd9287e"

diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
index c577c26ffeee..eed705762376 100644
--- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
@@ -33,9 +33,7 @@ class MistralToolParser(ToolParser):
 
     def __init__(self, tokenizer: AnyTokenizer):
         super().__init__(tokenizer)
 
-        if isinstance(self.model_tokenizer, MistralTokenizer):
-            self.model_tokenizer = self.model_tokenizer.tokenizer
-        else:
+        if not isinstance(self.model_tokenizer, MistralTokenizer):
             logger.info("Non-Mistral tokenizer detected when using a Mistral "
                         "model...")

diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py
index 788133059f12..b5b1e06d4894 100644
--- a/vllm/transformers_utils/tokenizers/mistral.py
+++ b/vllm/transformers_utils/tokenizers/mistral.py
@@ -130,6 +130,10 @@ def is_fast(self) -> bool:
     def vocab_size(self) -> int:
         return len(self._vocab)
 
+    @property
+    def vocab(self) -> Dict[str, int]:
+        return self._vocab
+
     def __len__(self) -> int:
         return self.vocab_size

diff --git a/vllm/version.py b/vllm/version.py
new file mode 100644
index 000000000000..66e189dcedf7
--- /dev/null
+++ b/vllm/version.py
@@ -0,0 +1,11 @@
+try:
+    from ._version import __version__, __version_tuple__
+except Exception as e:
+    import warnings
+
+    warnings.warn(f"Failed to read commit hash:\n{e}",
+                  RuntimeWarning,
+                  stacklevel=2)
+
+    __version__ = "dev"
+    __version_tuple__ = (0, 0, __version__)

From dc94a22cdf4fa424408010790fd71dc8559c6482 Mon Sep 17 00:00:00 2001
From: sydnash
Date: Sun, 29 Sep 2024 10:45:43 +0800
Subject: [PATCH 18/19] [fix] change vocab property to get_vocab method in
 mistral_tool_parser.py

---
 vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py | 2 +-
 vllm/transformers_utils/tokenizers/mistral.py               | 4 ----
 2 files changed, 1 insertion(+), 5 deletions(-)

diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
index eed705762376..44a46f86b05b 100644
--- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
@@ -45,7 +45,7 @@ def __init__(self, tokenizer: AnyTokenizer):
         self.streamed_args_for_tool: List[str] = [
         ]  # map what has been streamed for each tool so far to a list
         self.bot_token = "[TOOL_CALLS]"
-        self.bot_token_id = self.model_tokenizer.vocab[self.bot_token]
+        self.bot_token_id = self.model_tokenizer.get_vocab()[self.bot_token]
         self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
 
     def extract_tool_calls(

diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py
index b5b1e06d4894..788133059f12 100644
--- a/vllm/transformers_utils/tokenizers/mistral.py
+++ b/vllm/transformers_utils/tokenizers/mistral.py
@@ -130,10 +130,6 @@ def is_fast(self) -> bool:
     def vocab_size(self) -> int:
         return len(self._vocab)
 
-    @property
-    def vocab(self) -> Dict[str, int]:
-        return self._vocab
-
     def __len__(self) -> int:
         return self.vocab_size
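
Routing the bot-token lookup through `get_vocab()`, which Hugging Face tokenizers already expose, keeps `MistralToolParser` working against either tokenizer type instead of relying on the bespoke `vocab` property that PATCH 17 added to `MistralTokenizer` (and PATCH 18 removes again). A toy sketch of the resulting lookup; `ToyTokenizer` is a hypothetical stand-in, not vLLM code:

from typing import Dict


class ToyTokenizer:
    # Hypothetical tokenizer exposing the Hugging Face-style get_vocab().
    def get_vocab(self) -> Dict[str, int]:
        return {"[TOOL_CALLS]": 5, "hello": 6}  # toy vocabulary


tokenizer = ToyTokenizer()
bot_token = "[TOOL_CALLS]"
# Same shape as the lookup in MistralToolParser.__init__ after this patch.
bot_token_id = tokenizer.get_vocab()[bot_token]
assert bot_token_id == 5
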
From a2f938f20350fcab9b487ba6e893f03d2304e44a Mon Sep 17 00:00:00 2001
From: sydnash
Date: Thu, 3 Oct 2024 09:30:03 +0800
Subject: [PATCH 19/19] [fix] remove --tokenizer-mode mistral from the mistral
 test; move the system prompt to the mistral config

---
 tests/tool_use/test_parallel_tool_calls.py |  2 ++
 tests/tool_use/utils.py                    | 15 +++++++--------
 2 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/tests/tool_use/test_parallel_tool_calls.py b/tests/tool_use/test_parallel_tool_calls.py
index 1b031b2b6075..ed7ac8afe1b4 100644
--- a/tests/tool_use/test_parallel_tool_calls.py
+++ b/tests/tool_use/test_parallel_tool_calls.py
@@ -16,6 +16,7 @@
 @pytest.mark.asyncio
 async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
                                    server_config: ServerConfig):
+
     if not server_config.get("supports_parallel", True):
         pytest.skip("The {} model doesn't support parallel tool calls".format(
             server_config["model"]))
@@ -143,6 +144,7 @@ async def test_parallel_tool_calls(
 @pytest.mark.asyncio
 async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI,
                                                 server_config: ServerConfig):
+
     if not server_config.get("supports_parallel", True):
         pytest.skip("The {} model doesn't support parallel tool calls".format(
             server_config["model"]))

diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py
index 4cb44420c55c..ce36515a2381 100644
--- a/tests/tool_use/utils.py
+++ b/tests/tool_use/utils.py
@@ -79,9 +79,14 @@ def ensure_system_prompt(messages: List[Dict[str, Any]],
         "arguments": [
             "--tool-call-parser", "mistral", "--chat-template",
             str(VLLM_PATH / "examples/tool_chat_template_mistral.jinja"),
-            "--ignore-patterns=\"consolidated.safetensors\"",
-            "--tokenizer-mode", "mistral"
+            "--ignore-patterns=\"consolidated.safetensors\""
         ],
+        "system_prompt":
+        "You are a helpful assistant with access to tools. If a tool"
+        " that you have would be helpful to answer a user query, "
+        "call the tool. Otherwise, answer the user's query directly "
+        "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
+        "to the user's question - just respond to it normally."
     },
     "internlm": {
         "model":
@@ -94,12 +99,6 @@ def ensure_system_prompt(messages: List[Dict[str, Any]],
         ],
         "supports_parallel":
         False,
-        "system_prompt":
-        "You are a helpful assistant with access to tools. If a tool"
-        " that you have would be helpful to answer a user query, "
-        "call the tool. Otherwise, answer the user's query directly "
-        "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
-        "to the user's question - just respond to it normally."
     }
 }
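
With the series complete, each entry in the tool-use test configs carries the server arguments plus the optional `system_prompt` and `supports_parallel` fields. The sketch below shows how the parallel-call tests consume such a config; the field names follow the diffs above, while the model string and `total=False` are illustrative assumptions:

from typing import List, TypedDict


class ServerConfig(TypedDict, total=False):
    model: str
    arguments: List[str]
    system_prompt: str
    supports_parallel: bool


config: ServerConfig = {
    "model": "internlm-chat-stub",  # hypothetical model id
    "arguments": ["--tool-call-parser", "internlm"],
    "supports_parallel": False,
}

# Mirrors the guard at the top of both parallel tool-call tests: models
# that cannot emit parallel calls are skipped rather than failed.
if not config.get("supports_parallel", True):
    print(f"skipping parallel tests for {config['model']}")

Treating "supports parallel calls" as an opt-out capability flag (default True via .get) lets existing configs stay unchanged while models like InternLM2 are excluded from just the tests they cannot pass.
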