110 changes: 110 additions & 0 deletions examples/tool_chat_template_llama.jinja
@@ -0,0 +1,110 @@
{{- bos_token }}
{%- if custom_tools is defined %}
{%- set tools = custom_tools %}
{%- endif %}
{%- if not tools_in_user_message is defined %}
{%- set tools_in_user_message = true %}
{%- endif %}
{%- if not date_string is defined %}
{%- set date_string = "26 Jul 2024" %}
{%- endif %}
{%- if not tools is defined %}
{%- set tools = none %}
{%- endif %}

{#- This block extracts the system message, so we can slot it into the right place. #}
{%- if messages[0]['role'] == 'system' %}
{%- set system_message = messages[0]['content']|trim %}
{%- set messages = messages[1:] %}
{%- else %}
{%- set system_message = "" %}
{%- endif %}

{#- System message + builtin tools #}
{{- "<|start_header_id|>system<|end_header_id|>\n\n" }}
{%- if builtin_tools is defined or tools is not none %}
{{- "Environment: ipython\n" }}
{%- endif %}
{%- if builtin_tools is defined %}
{{- "Tools: " + builtin_tools | reject('equalto', 'code_interpreter') | join(", ") + "\n\n"}}
{%- endif %}
{{- "Cutting Knowledge Date: December 2023\n" }}
{{- "Today Date: " + date_string + "\n\n" }}
{%- if tools is not none and not tools_in_user_message %}
{{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}
{{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
{{- "Do not use variables.\n\n" }}
{%- for t in tools %}
{{- t | tojson(indent=4) }}
{{- "\n\n" }}
{%- endfor %}
{%- endif %}
{{- system_message }}
{{- "<|eot_id|>" }}

{#- Custom tools are passed in a user message with some extra guidance #}
{%- if tools_in_user_message and not tools is none %}
{#- Extract the first user message so we can plug it in here #}
{%- if messages | length != 0 %}
{%- set first_user_message = messages[0]['content']|trim %}
{%- set messages = messages[1:] %}
{%- else %}
{{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}
{%- endif %}
{{- '<|start_header_id|>user<|end_header_id|>\n\n' -}}
{{- "Given the following functions, please respond with a JSON for a function call " }}
{{- "with its proper arguments that best answers the given prompt.\n\n" }}
{{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
{{- "Do not use variables.\n\n" }}
{%- for t in tools %}
{{- t | tojson(indent=4) }}
{{- "\n\n" }}
{%- endfor %}
{{- first_user_message + "<|eot_id|>"}}
{%- endif %}

{%- for message in messages %}
{%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}
Contributor comment: This is Meta's "recommended" system prompt from their docs, but it's very poor. Using this template results in the model frequently failing to interpret tool call results, and just trying to call a tool again and again, since it includes the instruction "Please respond with a JSON for a function call".

{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }}
{%- elif 'tool_calls' in message %}
{%- if not message.tool_calls|length == 1 %}
{{- raise_exception("This model only supports single tool-calls at once!") }}
{%- endif %}
{%- set tool_call = message.tool_calls[0].function %}
{%- if builtin_tools is defined and tool_call.name in builtin_tools %}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
{{- "<|python_tag|>" + tool_call.name + ".call(" }}
{%- for arg_name, arg_val in tool_call.arguments | items %}
{{- arg_name + '="' + arg_val + '"' }}
{%- if not loop.last %}
{{- ", " }}
{%- endif %}
{%- endfor %}
{{- ")" }}
{%- else %}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
{{- '{"name": "' + tool_call.name + '", ' }}
{{- '"parameters": ' }}
{{- tool_call.arguments | tojson }}
{{- "}" }}
{%- endif %}
{%- if builtin_tools is defined %}
{#- This means we're in ipython mode #}
{{- "<|eom_id|>" }}
{%- else %}
{{- "<|eot_id|>" }}
{%- endif %}
{%- elif message.role == "tool" or message.role == "ipython" %}
{{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }}
{%- if message.content is mapping or message.content is iterable %}
{{- message.content | tojson }}
{%- else %}
{{- message.content }}
{%- endif %}
{{- "<|eot_id|>" }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
{%- endif %}
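
For anyone who wants to inspect the prompt this template produces, the following is a minimal rendering sketch, not part of the PR. Assumptions: it is run from the vLLM repo root, Jinja2 >= 3.1 is installed (the template uses the items filter), and the conversation plus tool schema are invented; vLLM and transformers actually render chat templates through their own sandboxed Jinja environment, which this plain-Jinja sketch only approximates.

import jinja2

# Mirror the raise_exception helper that the real chat-templating environment
# provides; it is only reached on the template's error branches.
def raise_exception(message: str):
    raise ValueError(message)

env = jinja2.Environment(loader=jinja2.FileSystemLoader("examples"))
env.globals["raise_exception"] = raise_exception
template = env.get_template("tool_chat_template_llama.jinja")

# Made-up conversation and tool schema, purely for inspecting the rendered prompt.
example_tool = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}

rendered = template.render(
    bos_token="<|begin_of_text|>",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the weather in Dallas?"},
    ],
    tools=[example_tool],
    add_generation_prompt=True,
)
print(rendered)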

17 changes: 10 additions & 7 deletions tests/tool_use/test_chat_completions.py
@@ -3,18 +3,20 @@
import openai
import pytest

from .utils import MESSAGES_WITHOUT_TOOLS, WEATHER_TOOL
from .utils import (MESSAGES_WITHOUT_TOOLS, WEATHER_TOOL, ServerConfig,
adapt_prompt_to_model)


# test: make sure chat completions without tools provided work even when tools
# are enabled. This makes sure tool call chat templates work, AND that the tool
# parser stream processing doesn't change the output of the model.
@pytest.mark.asyncio
async def test_chat_completion_without_tools(client: openai.AsyncOpenAI):
async def test_chat_completion_without_tools(client: openai.AsyncOpenAI,
server_config: ServerConfig):
models = await client.models.list()
model_name: str = models.data[0].id
chat_completion = await client.chat.completions.create(
messages=MESSAGES_WITHOUT_TOOLS,
messages=adapt_prompt_to_model(MESSAGES_WITHOUT_TOOLS, server_config),
temperature=0,
max_tokens=150,
model=model_name,
@@ -34,7 +36,7 @@ async def test_chat_completion_without_tools(client: openai.AsyncOpenAI):

# make the same request, streaming
stream = await client.chat.completions.create(
messages=MESSAGES_WITHOUT_TOOLS,
messages=adapt_prompt_to_model(MESSAGES_WITHOUT_TOOLS, server_config),
temperature=0,
max_tokens=150,
model=model_name,
@@ -77,11 +79,12 @@ async def test_chat_completion_without_tools(client: openai.AsyncOpenAI):
# tools, to make sure we can still get normal chat completion responses
# and that they won't be parsed as tools
@pytest.mark.asyncio
async def test_chat_completion_with_tools(client: openai.AsyncOpenAI):
async def test_chat_completion_with_tools(client: openai.AsyncOpenAI,
server_config: ServerConfig):
models = await client.models.list()
model_name: str = models.data[0].id
chat_completion = await client.chat.completions.create(
messages=MESSAGES_WITHOUT_TOOLS,
messages=adapt_prompt_to_model(MESSAGES_WITHOUT_TOOLS, server_config),
temperature=0,
max_tokens=150,
model=model_name,
@@ -102,7 +105,7 @@ async def test_chat_completion_with_tools(client: openai.AsyncOpenAI):

# make the same request, streaming
stream = await client.chat.completions.create(
messages=MESSAGES_WITHOUT_TOOLS,
messages=adapt_prompt_to_model(MESSAGES_WITHOUT_TOOLS, server_config),
temperature=0,
max_tokens=150,
model=model_name,
30 changes: 23 additions & 7 deletions tests/tool_use/test_parallel_tool_calls.py
@@ -6,19 +6,26 @@

from .utils import (MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, SEARCH_TOOL,
WEATHER_TOOL)
WEATHER_TOOL, ServerConfig, adapt_prompt_to_model)


# test: getting the model to generate parallel tool calls (streaming/not)
# when requested. NOTE that not all models may support this, so some exclusions
# may be added in the future. e.g. llama 3.1 models are not designed to support
# parallel tool calls.
@pytest.mark.asyncio
async def test_parallel_tool_calls(client: openai.AsyncOpenAI):
async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
server_config: ServerConfig):

if not server_config.get("supports_parallel", True):
pytest.skip("The {} model doesn't support parallel tool calls".format(
server_config["model"]))

models = await client.models.list()
model_name: str = models.data[0].id
chat_completion = await client.chat.completions.create(
messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
messages=adapt_prompt_to_model(MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
server_config),
temperature=0,
max_tokens=200,
model=model_name,
@@ -55,7 +62,8 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI):
# make the same request, streaming
stream = await client.chat.completions.create(
model=model_name,
messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
messages=adapt_prompt_to_model(MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
server_config),
temperature=0,
max_tokens=200,
tools=[WEATHER_TOOL, SEARCH_TOOL],
@@ -136,11 +144,18 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI):
# test: providing parallel tool calls back to the model to get a response
# (streaming/not)
@pytest.mark.asyncio
async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI):
async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI,
server_config: ServerConfig):

if not server_config.get("supports_parallel", True):
pytest.skip("The {} model doesn't support parallel tool calls".format(
server_config["model"]))

models = await client.models.list()
model_name: str = models.data[0].id
chat_completion = await client.chat.completions.create(
messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
messages=adapt_prompt_to_model(MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
server_config),
temperature=0,
max_tokens=200,
model=model_name,
@@ -158,7 +173,8 @@ async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI):
assert "78" in choice.message.content # Orlando temp in tool response

stream = await client.chat.completions.create(
messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
messages=adapt_prompt_to_model(MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
server_config),
temperature=0,
max_tokens=200,
model=model_name,
21 changes: 14 additions & 7 deletions tests/tool_use/test_tool_calls.py
@@ -5,17 +5,20 @@
import pytest

from .utils import (MESSAGES_ASKING_FOR_TOOLS, MESSAGES_WITH_TOOL_RESPONSE,
SEARCH_TOOL, WEATHER_TOOL)
SEARCH_TOOL, WEATHER_TOOL, ServerConfig,
adapt_prompt_to_model)


# test: request a chat completion that should return tool calls, so we know they
# are parsable
@pytest.mark.asyncio
async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
async def test_tool_call_and_choice(client: openai.AsyncOpenAI,
server_config: ServerConfig):
models = await client.models.list()
model_name: str = models.data[0].id
chat_completion = await client.chat.completions.create(
messages=MESSAGES_ASKING_FOR_TOOLS,
messages=adapt_prompt_to_model(MESSAGES_ASKING_FOR_TOOLS,
server_config),
temperature=0,
max_tokens=100,
model=model_name,
@@ -59,7 +62,8 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
# make the same request, streaming
stream = await client.chat.completions.create(
model=model_name,
messages=MESSAGES_ASKING_FOR_TOOLS,
messages=adapt_prompt_to_model(MESSAGES_ASKING_FOR_TOOLS,
server_config),
temperature=0,
max_tokens=100,
tools=[WEATHER_TOOL, SEARCH_TOOL],
@@ -136,11 +140,13 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
# test: providing tools and results back to model to get a non-tool response
# (streaming/not)
@pytest.mark.asyncio
async def test_tool_call_with_results(client: openai.AsyncOpenAI):
async def test_tool_call_with_results(client: openai.AsyncOpenAI,
server_config: ServerConfig):
models = await client.models.list()
model_name: str = models.data[0].id
chat_completion = await client.chat.completions.create(
messages=MESSAGES_WITH_TOOL_RESPONSE,
messages=adapt_prompt_to_model(MESSAGES_WITH_TOOL_RESPONSE,
server_config),
temperature=0,
max_tokens=100,
model=model_name,
@@ -157,7 +163,8 @@ async def test_tool_call_with_results(client: openai.AsyncOpenAI):
assert "98" in choice.message.content # the temperature from the response

stream = await client.chat.completions.create(
messages=MESSAGES_WITH_TOOL_RESPONSE,
messages=adapt_prompt_to_model(MESSAGES_WITH_TOOL_RESPONSE,
server_config),
temperature=0,
max_tokens=100,
model=model_name,
72 changes: 70 additions & 2 deletions tests/tool_use/utils.py
@@ -1,4 +1,6 @@
from typing import Dict, List
import json
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional

from openai.types.chat import (ChatCompletionMessageParam,
ChatCompletionToolParam)
@@ -7,9 +9,56 @@
from tests.utils import VLLM_PATH


class ServerConfig(TypedDict):
class ServerConfig(TypedDict, total=False):
model: str
arguments: List[str]
system_prompt: Optional[str]
supports_parallel: Optional[bool]
format_tool_output: Optional[Callable[[str], str]]


def format_llama_tool_output(output: str) -> str:
return json.dumps({"output": output})


def format_tool_output_id(output: str) -> str:
return output


def patch_tool_output(messages: List[Dict[str, Any]],
config: ServerConfig) -> List[Dict[str, Any]]:
fmt_fun = config.get("format_tool_output")
if not fmt_fun:
return messages
new_messages = deepcopy(messages)
for message in new_messages:
if message["role"] == "tool":
message["content"] = fmt_fun(message["content"])
return new_messages

Contributor comment: What is the purpose of this, and why is it necessary?


def patch_system_prompt(messages: List[Dict[str, Any]],
system_prompt: str) -> List[Dict[str, Any]]:
new_messages = deepcopy(messages)
if new_messages[0]["role"] == "system":
new_messages[0]["content"] = system_prompt
else:
new_messages.insert(0, {"role": "system", "content": system_prompt})
return new_messages


def ensure_system_prompt(messages: List[Dict[str, Any]],
config: ServerConfig) -> List[Dict[str, Any]]:
prompt = config.get("system_prompt")
if prompt:
return patch_system_prompt(messages, prompt)
else:
return messages


def adapt_prompt_to_model(messages: List[Dict[str, Any]],
config: ServerConfig) -> List[Dict[str, Any]]:
return ensure_system_prompt(patch_tool_output(messages, config), config)


# universal args for all models go here. also good if you need to test locally
@@ -25,6 +74,25 @@ class ServerConfig(TypedDict):
str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja")
]
},
"llama": {
"model":
"meta-llama/Meta-Llama-3.1-8B-Instruct",
"arguments": [
"--tool-call-parser", "llama", "--chat-template",
str(VLLM_PATH / "examples/tool_chat_template_llama.jinja")
],
"system_prompt":
"You are a helpful assistant with tool calling capabilities. "
"Only reply with a tool call if the function exists in the "
"library provided by the user. If it doesn't exist, just "
"reply directly in natural language. When you receive a tool "
"call response, use the output to format an answer to the "
"original user question.",
"supports_parallel":
False,
"format_tool_output":
format_llama_tool_output
},
"mistral": {
"model":
"mistralai/Mistral-7B-Instruct-v0.3",
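
To illustrate how these test helpers compose, here is a hedged sketch that is not part of the diff: a made-up conversation is run through adapt_prompt_to_model with a partial llama-style config that reuses only fields defined in this file, showing that tool output gets wrapped as JSON and the configured system prompt gets injected.

# Not part of the diff: an invented example of what adapt_prompt_to_model does
# for a llama-style config. Assumes the vLLM repo root is on PYTHONPATH.
from tests.tool_use.utils import (ServerConfig, adapt_prompt_to_model,
                                  format_llama_tool_output)

llama_like_config: ServerConfig = {
    "system_prompt": "You are a helpful assistant with tool calling capabilities.",
    "format_tool_output": format_llama_tool_output,
}

messages = [
    {"role": "user", "content": "What is the weather in Dallas, TX in Fahrenheit?"},
    {"role": "tool", "content": "98"},
]

adapted = adapt_prompt_to_model(messages, llama_like_config)
# adapted[0] is the injected system prompt, and the tool message's content has
# been rewritten to '{"output": "98"}' by format_llama_tool_output; the original
# messages list is left untouched because the helpers deepcopy it.
print(adapted)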
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/cli_args.py
@@ -174,7 +174,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument(
"--tool-call-parser",
type=str,
choices=["mistral", "hermes"],
choices=["mistral", "hermes", "llama"],
default=None,
help=
"Select the tool call parser depending on the model that you're using."