@@ -29,10 +29,7 @@
                                                     OpenAIServing,
                                                     PromptAdapterPath,
                                                     TextTokensPrompt)
-from vllm.entrypoints.openai.tool_parsers import (Hermes2ProToolParser,
-                                                  Llama3JsonToolParser,
-                                                  MistralToolParser,
-                                                  ToolParser)
+from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
 from vllm.inputs import TokensPrompt
 from vllm.logger import init_logger
 from vllm.outputs import CompletionOutput, RequestOutput
@@ -82,15 +79,13 @@ def __init__(self,
 
         self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
         if self.enable_auto_tools:
-            if tool_parser == "mistral":
-                self.tool_parser = MistralToolParser
-            elif tool_parser == "hermes":
-                self.tool_parser = Hermes2ProToolParser
-            elif tool_parser == "llama3_json":
-                self.tool_parser = Llama3JsonToolParser
-            else:
+            try:
+                self.tool_parser = ToolParserManager.get_tool_parser(
+                    tool_parser)
+            except Exception as e:
                 raise TypeError("Error: --enable-auto-tool-choice requires "
-                                "--tool-call-parser")
+                                f"tool_parser:'{tool_parser}' which has not "
+                                "been registered") from e
 
     async def create_chat_completion(
         self,
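
The hardcoded name-to-class mapping gives way to a registry lookup, so out-of-tree parsers can be plugged in by name. A minimal sketch of how a parser would enter that registry, assuming the register_module decorator the in-tree parsers use; the "example" name and ExampleToolParser class are hypothetical:

    from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager

    # Hypothetical out-of-tree parser; "example" is an illustrative name.
    @ToolParserManager.register_module("example")
    class ExampleToolParser(ToolParser):
        pass

    # __init__ above resolves the class by name at server start-up; an
    # unknown name raises, which the try/except converts into a TypeError.
    parser_cls = ToolParserManager.get_tool_parser("example")
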
@@ -187,6 +182,10 @@ async def create_chat_completion(
         raw_request.state.request_metadata = request_metadata
 
         try:
+            if self.enable_auto_tools and self.tool_parser:
+                request = self.tool_parser(tokenizer).adjust_request(
+                    request=request)
+
             if isinstance(prompt, str):
                 prompt_inputs = self._tokenize_prompt_input(
                     request,
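
The new adjust_request hook runs before tokenization, letting the chosen parser tweak the request first. A hedged sketch of an override, assuming the base method is a pass-through that returns the request; the skip_special_tokens tweak is illustrative of a parser whose tool-call markers are special tokens, not any in-tree parser's exact behavior:

    from vllm.entrypoints.openai.protocol import ChatCompletionRequest
    from vllm.entrypoints.openai.tool_parsers import ToolParser

    class ExampleToolParser(ToolParser):

        def adjust_request(
                self, request: ChatCompletionRequest) -> ChatCompletionRequest:
            if request.tools and request.tool_choice != "none":
                # Keep special tokens so the tool-call markers survive
                # detokenization and the parser can still see them.
                request.skip_special_tokens = False
            return request
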
@@ -282,11 +281,12 @@ async def chat_completion_stream_generator(
         num_choices = 1 if request.n is None else request.n
         previous_num_tokens = [0] * num_choices
         finish_reason_sent = [False] * num_choices
-
         num_prompt_tokens = 0
 
-        tool_parser: Optional[ToolParser] = self.tool_parser(
-            tokenizer) if self.tool_parser else None
+        tool_parsers: List[Optional[ToolParser]] = [
+            self.tool_parser(tokenizer) if self.tool_parser else None
+            for _ in range(num_choices)
+        ]
 
         if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
             tool_choice_function_name = request.tool_choice.function.name
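
Streaming parsers are stateful (each tracks the partially emitted tool call for its choice), so each choice needs its own instance. The comprehension matters: list multiplication would alias one instance across every choice, as this plain-Python check shows:

    # List multiplication copies the reference, not the object:
    shared = [object()] * 3
    assert shared[0] is shared[1]            # one instance, shared state

    # A comprehension builds an independent instance per element:
    separate = [object() for _ in range(3)]
    assert separate[0] is not separate[1]    # per-choice state stays isolated
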
@@ -324,7 +323,7 @@ async def chat_completion_stream_generator(
                 # NOTE num_choices defaults to 1 so this usually executes
                 # once per request
                 for i in range(num_choices):
-
+                    tool_parser = tool_parsers[i]
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=i,
                         delta=DeltaMessage(
@@ -399,6 +398,7 @@ async def chat_completion_stream_generator(
 
             for output in res.outputs:
                 i = output.index
+                tool_parser = tool_parsers[i]
 
                 if finish_reason_sent[i]:
                     continue
@@ -446,7 +446,8 @@ async def chat_completion_stream_generator(
                             delta_text=delta_text,
                             previous_token_ids=previous_token_ids,
                             current_token_ids=current_token_ids,
-                            delta_token_ids=output.token_ids))
+                            delta_token_ids=output.token_ids,
+                            request=request))
 
                     # update the previous values for the next iteration
                     previous_texts[i] = current_text
@@ -685,7 +686,8 @@ async def chat_completion_full_generator(
                 and self.tool_parser:
 
             tool_parser = self.tool_parser(tokenizer)
-            tool_call_info = tool_parser.extract_tool_calls(output.text)
+            tool_call_info = tool_parser.extract_tool_calls(
+                output.text, request=request)
             tools_called = tool_call_info.tools_called
             if tool_call_info.tools_called:
                 message = ChatMessage(role=role,
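
Taken together, the two call sites imply the parser interface below: both extraction hooks now receive the originating request. A hedged skeleton under that assumption; the class, registry name, and trivial bodies are illustrative, not an in-tree parser:

    from typing import Optional, Sequence

    from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                                  DeltaMessage,
                                                  ExtractedToolCallInformation)
    from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager


    @ToolParserManager.register_module("example")   # hypothetical name
    class ExampleToolParser(ToolParser):

        def extract_tool_calls(
                self, model_output: str,
                request: ChatCompletionRequest) -> ExtractedToolCallInformation:
            # Non-streaming path: parse the finished completion text.
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        def extract_tool_calls_streaming(
                self, previous_text: str, current_text: str, delta_text: str,
                previous_token_ids: Sequence[int],
                current_token_ids: Sequence[int],
                delta_token_ids: Sequence[int],
                request: ChatCompletionRequest) -> Optional[DeltaMessage]:
            # Streaming path: emit deltas as tool-call fragments arrive.
            return DeltaMessage(content=delta_text)
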
|