-
-
Notifications
You must be signed in to change notification settings - Fork 11.7k
[Feature] Add support for Llama 3.1 and 3.2 tool use #8343
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
b85ad93
1abbe93
2d2a329
fa6ebb9
be5eeab
d8d6de4
15d39c8
37336f3
1babd5d
e7d34dc
b22d8c9
3bb941a
f66c0b1
0fabb67
a90055f
932a093
669fe67
6f01abf
871c568
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,110 @@ | ||
| {{- bos_token }} | ||
| {%- if custom_tools is defined %} | ||
| {%- set tools = custom_tools %} | ||
| {%- endif %} | ||
| {%- if not tools_in_user_message is defined %} | ||
| {%- set tools_in_user_message = true %} | ||
| {%- endif %} | ||
| {%- if not date_string is defined %} | ||
| {%- set date_string = "26 Jul 2024" %} | ||
| {%- endif %} | ||
| {%- if not tools is defined %} | ||
| {%- set tools = none %} | ||
| {%- endif %} | ||
|
|
||
| {#- This block extracts the system message, so we can slot it into the right place. #} | ||
| {%- if messages[0]['role'] == 'system' %} | ||
| {%- set system_message = messages[0]['content']|trim %} | ||
| {%- set messages = messages[1:] %} | ||
| {%- else %} | ||
| {%- set system_message = "" %} | ||
| {%- endif %} | ||
|
|
||
| {#- System message + builtin tools #} | ||
| {{- "<|start_header_id|>system<|end_header_id|>\n\n" }} | ||
| {%- if builtin_tools is defined or tools is not none %} | ||
| {{- "Environment: ipython\n" }} | ||
| {%- endif %} | ||
| {%- if builtin_tools is defined %} | ||
| {{- "Tools: " + builtin_tools | reject('equalto', 'code_interpreter') | join(", ") + "\n\n"}} | ||
| {%- endif %} | ||
maxdebayser marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| {{- "Cutting Knowledge Date: December 2023\n" }} | ||
maxdebayser marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| {{- "Today Date: " + date_string + "\n\n" }} | ||
maxdebayser marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| {%- if tools is not none and not tools_in_user_message %} | ||
| {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }} | ||
| {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} | ||
| {{- "Do not use variables.\n\n" }} | ||
| {%- for t in tools %} | ||
| {{- t | tojson(indent=4) }} | ||
| {{- "\n\n" }} | ||
| {%- endfor %} | ||
| {%- endif %} | ||
| {{- system_message }} | ||
| {{- "<|eot_id|>" }} | ||
|
|
||
| {#- Custom tools are passed in a user message with some extra guidance #} | ||
| {%- if tools_in_user_message and not tools is none %} | ||
| {#- Extract the first user message so we can plug it in here #} | ||
| {%- if messages | length != 0 %} | ||
| {%- set first_user_message = messages[0]['content']|trim %} | ||
| {%- set messages = messages[1:] %} | ||
| {%- else %} | ||
| {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} | ||
| {%- endif %} | ||
| {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}} | ||
| {{- "Given the following functions, please respond with a JSON for a function call " }} | ||
| {{- "with its proper arguments that best answers the given prompt.\n\n" }} | ||
| {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} | ||
| {{- "Do not use variables.\n\n" }} | ||
| {%- for t in tools %} | ||
| {{- t | tojson(indent=4) }} | ||
| {{- "\n\n" }} | ||
| {%- endfor %} | ||
| {{- first_user_message + "<|eot_id|>"}} | ||
| {%- endif %} | ||
|
|
||
| {%- for message in messages %} | ||
| {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is Meta's "recommended" system prompt that's in their docs, but it's very poor. Using this template results in the model frequently failing to interpret tool call results, and just trying to call a tool again and again, since it includes the instruction "Please respond with a JSON for a function call" |
||
| {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }} | ||
| {%- elif 'tool_calls' in message %} | ||
| {%- if not message.tool_calls|length == 1 %} | ||
| {{- raise_exception("This model only supports single tool-calls at once!") }} | ||
| {%- endif %} | ||
| {%- set tool_call = message.tool_calls[0].function %} | ||
| {%- if builtin_tools is defined and tool_call.name in builtin_tools %} | ||
| {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} | ||
| {{- "<|python_tag|>" + tool_call.name + ".call(" }} | ||
| {%- for arg_name, arg_val in tool_call.arguments | items %} | ||
| {{- arg_name + '="' + arg_val + '"' }} | ||
| {%- if not loop.last %} | ||
| {{- ", " }} | ||
| {%- endif %} | ||
| {%- endfor %} | ||
| {{- ")" }} | ||
| {%- else %} | ||
| {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} | ||
| {{- '{"name": "' + tool_call.name + '", ' }} | ||
| {{- '"parameters": ' }} | ||
| {{- tool_call.arguments | tojson }} | ||
| {{- "}" }} | ||
| {%- endif %} | ||
| {%- if builtin_tools is defined %} | ||
| {#- This means we're in ipython mode #} | ||
| {{- "<|eom_id|>" }} | ||
| {%- else %} | ||
| {{- "<|eot_id|>" }} | ||
| {%- endif %} | ||
| {%- elif message.role == "tool" or message.role == "ipython" %} | ||
| {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }} | ||
| {%- if message.content is mapping or message.content is iterable %} | ||
| {{- message.content | tojson }} | ||
| {%- else %} | ||
| {{- message.content }} | ||
| {%- endif %} | ||
| {{- "<|eot_id|>" }} | ||
| {%- endif %} | ||
| {%- endfor %} | ||
| {%- if add_generation_prompt %} | ||
| {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }} | ||
| {%- endif %} | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,6 @@ | ||
| from typing import Dict, List | ||
| import json | ||
| from copy import deepcopy | ||
| from typing import Any, Callable, Dict, List, Optional | ||
|
|
||
| from openai.types.chat import (ChatCompletionMessageParam, | ||
| ChatCompletionToolParam) | ||
|
|
@@ -7,9 +9,56 @@ | |
| from tests.utils import VLLM_PATH | ||
|
|
||
|
|
||
| class ServerConfig(TypedDict): | ||
| class ServerConfig(TypedDict, total=False): | ||
| model: str | ||
| arguments: List[str] | ||
| system_prompt: Optional[str] | ||
| supports_parallel: Optional[bool] | ||
| format_tool_output: Optional[Callable[[str], str]] | ||
|
|
||
|
|
||
def format_llama_tool_output(output: str) -> str:
    """Wrap a raw tool-output string in the JSON envelope Llama expects.

    The string is placed under an ``"output"`` key and serialized with
    ``json.dumps``.
    """
    envelope = {"output": output}
    return json.dumps(envelope)
|
|
||
|
|
||
def format_tool_output_id(output: str) -> str:
    """Identity formatter: return the tool output completely unchanged."""
    result = output
    return result
|
|
||
|
|
||
def patch_tool_output(messages: List[Dict[str, Any]],
                      config: ServerConfig) -> List[Dict[str, Any]]:
    """Run every tool message's content through the config's formatter.

    When the config has no ``format_tool_output`` entry, the original list
    is returned untouched. Otherwise a deep copy is returned in which each
    message with role ``"tool"`` has its ``content`` replaced by the
    formatter's result; the input list is never mutated.
    """
    formatter = config.get("format_tool_output")
    if not formatter:
        return messages
    patched = deepcopy(messages)
    for msg in patched:
        if msg["role"] == "tool":
            msg["content"] = formatter(msg["content"])
    return patched
|
|
||
|
||
|
|
||
def patch_system_prompt(messages: List[Dict[str, Any]],
                        system_prompt: str) -> List[Dict[str, Any]]:
    """Return a copy of ``messages`` whose first entry carries ``system_prompt``.

    An existing leading system message has its content overwritten;
    otherwise — including for an empty conversation, which previously raised
    IndexError — a new system message is inserted at the front. The input
    list is never mutated.
    """
    new_messages = deepcopy(messages)
    # Guard the empty-conversation case before peeking at index 0.
    if new_messages and new_messages[0]["role"] == "system":
        new_messages[0]["content"] = system_prompt
    else:
        new_messages.insert(0, {"role": "system", "content": system_prompt})
    return new_messages
|
|
||
|
|
||
def ensure_system_prompt(messages: List[Dict[str, Any]],
                         config: ServerConfig) -> List[Dict[str, Any]]:
    """Apply the config's system prompt to the conversation, if one is set.

    Delegates to ``patch_system_prompt`` when the config carries a truthy
    ``system_prompt``; otherwise hands back the messages unchanged.
    """
    prompt = config.get("system_prompt")
    return patch_system_prompt(messages, prompt) if prompt else messages
|
|
||
|
|
||
def adapt_prompt_to_model(messages: List[Dict[str, Any]],
                          config: ServerConfig) -> List[Dict[str, Any]]:
    """Adapt a conversation to a model's quirks per its server config.

    First rewrites tool outputs via ``patch_tool_output``, then enforces
    the configured system prompt via ``ensure_system_prompt``.
    """
    with_tool_outputs = patch_tool_output(messages, config)
    return ensure_system_prompt(with_tool_outputs, config)
maxdebayser marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| # universal args for all models go here. also good if you need to test locally | ||
|
|
@@ -25,6 +74,25 @@ class ServerConfig(TypedDict): | |
| str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja") | ||
| ] | ||
| }, | ||
| "llama": { | ||
| "model": | ||
| "meta-llama/Meta-Llama-3.1-8B-Instruct", | ||
| "arguments": [ | ||
| "--tool-call-parser", "llama", "--chat-template", | ||
| str(VLLM_PATH / "examples/tool_chat_template_llama.jinja") | ||
| ], | ||
| "system_prompt": | ||
| "You are a helpful assistant with tool calling capabilities. " | ||
| "Only reply with a tool call if the function exists in the " | ||
| "library provided by the user. If it doesn't exist, just " | ||
| "reply directly in natural language. When you receive a tool " | ||
| "call response, use the output to format an answer to the " | ||
| "original user question.", | ||
| "supports_parallel": | ||
| False, | ||
| "format_tool_output": | ||
| format_llama_tool_output | ||
| }, | ||
| "mistral": { | ||
| "model": | ||
| "mistralai/Mistral-7B-Instruct-v0.3", | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.