Skip to content

Commit b56cb3b

Browse files
authored
Merge branch 'main' into address_burstiness
2 parents 6d2e37c + 0b99f5d commit b56cb3b

File tree

30 files changed

+1117
-435
lines changed

30 files changed

+1117
-435
lines changed

.github/CODEOWNERS

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@
55
/vllm/attention @LucasWilkinson
66
/vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
77
/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @22quinn
8-
/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @22quinn
98
/vllm/model_executor/layers/fused_moe @mgoin
10-
/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @NickLucche
119
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256
1210
/vllm/model_executor/layers/mamba @tdoublep
1311
/vllm/model_executor/model_loader @22quinn
@@ -26,7 +24,6 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
2624
/vllm/config/cache.py @simon-mo @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor @yewentao256 @ProExpertProg @heheda12345
2725

2826
# vLLM V1
29-
/vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat
3027
/vllm/v1/attention @LucasWilkinson
3128
/vllm/v1/attention/backends/flashinfer.py @mgoin
3229
/vllm/v1/attention/backends/triton_attn.py @tdoublep

csrc/layernorm_kernels.cu

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "dispatch_utils.h"
33
#include "cub_helpers.h"
44
#include "core/batch_invariant.hpp"
5+
#include "quantization/vectorization_utils.cuh"
56

67
#include <torch/cuda.h>
78
#include <c10/cuda/CUDAGuard.h>
@@ -18,11 +19,22 @@ __global__ void rms_norm_kernel(
1819
const float epsilon, const int num_tokens, const int hidden_size) {
1920
__shared__ float s_variance;
2021
float variance = 0.0f;
22+
const scalar_t* input_row = input + blockIdx.x * input_stride;
2123

22-
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
23-
const float x = (float)input[blockIdx.x * input_stride + idx];
24+
constexpr int VEC_SIZE = 8;
25+
auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
26+
#pragma unroll
27+
for (int i = 0; i < VEC_SIZE; ++i) {
28+
float x = static_cast<float>(vec.val[i]);
29+
variance += x * x;
30+
}
31+
};
32+
auto scalar_op = [&variance](const scalar_t& val) {
33+
float x = static_cast<float>(val);
2434
variance += x * x;
25-
}
35+
};
36+
vllm::vectorize_read_with_alignment<VEC_SIZE>(
37+
input_row, hidden_size, threadIdx.x, blockDim.x, vec_op, scalar_op);
2638

2739
using BlockReduce = cub::BlockReduce<float, 1024>;
2840
__shared__ typename BlockReduce::TempStorage reduceStore;

csrc/layernorm_quant_kernels.cu

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "dispatch_utils.h"
1111
#include "cub_helpers.h"
1212
#include "core/batch_invariant.hpp"
13+
#include "quantization/vectorization_utils.cuh"
1314

1415
#include <torch/cuda.h>
1516
#include <c10/cuda/CUDAGuard.h>
@@ -28,10 +29,22 @@ __global__ void rms_norm_static_fp8_quant_kernel(
2829
__shared__ float s_variance;
2930
float variance = 0.0f;
3031

31-
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
32-
const float x = (float)input[blockIdx.x * input_stride + idx];
32+
const scalar_t* input_row = input + blockIdx.x * input_stride;
33+
34+
constexpr int VEC_SIZE = 8;
35+
auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
36+
#pragma unroll
37+
for (int i = 0; i < VEC_SIZE; ++i) {
38+
float x = static_cast<float>(vec.val[i]);
39+
variance += x * x;
40+
}
41+
};
42+
auto scalar_op = [&variance](const scalar_t& val) {
43+
float x = static_cast<float>(val);
3344
variance += x * x;
34-
}
45+
};
46+
vllm::vectorize_read_with_alignment<VEC_SIZE>(
47+
input_row, hidden_size, threadIdx.x, blockDim.x, vec_op, scalar_op);
3548

3649
using BlockReduce = cub::BlockReduce<float, 1024>;
3750
__shared__ typename BlockReduce::TempStorage reduceStore;

docs/features/tool_calling.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,16 @@ Supported models:
352352

353353
Flags: `--tool-call-parser qwen3_xml`
354354

355+
### Olmo 3 Models (`olmo3`)
356+
357+
Olmo 3 models output tool calls in a format that is very similar to the one expected by the `pythonic` parser (see below), with a few differences. Each tool call is a pythonic string, but parallel tool calls are newline-delimited, and the calls are wrapped in XML tags as `<function_calls>..</function_calls>`. The parser also accepts the JSON boolean and null literals (`true`, `false`, and `null`) in addition to the pythonic ones (`True`, `False`, and `None`).
358+
359+
Supported models:
360+
361+
* TODO (will be updated after Olmo 3 release)
362+
363+
Flags: `--tool-call-parser olmo3`
364+
355365
### Models with Pythonic Tool Calls (`pythonic`)
356366

357367
A growing number of models output a Python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models.
Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
from unittest.mock import MagicMock, patch
5+
6+
import pytest
7+
8+
from tests.entrypoints.openai.tool_parsers.utils import (
9+
run_tool_extraction,
10+
run_tool_extraction_streaming,
11+
)
12+
from vllm.entrypoints.openai.protocol import FunctionCall
13+
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
14+
15+
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
16+
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
17+
SIMPLE_FUNCTION_CALL = FunctionCall(
18+
name="get_weather",
19+
arguments='{"city": "San Francisco", "metric": "celsius"}',
20+
)
21+
MORE_TYPES_FUNCTION_OUTPUT = (
22+
"register_user(name='John Doe', "
23+
"age=37, "
24+
"address={'city': 'San Francisco', 'state': 'CA'}, "
25+
"role=None, "
26+
"passed_test=True, "
27+
"aliases=['John', 'Johnny'])"
28+
)
29+
MORE_TYPES_FUNCTION_OUTPUT_JSON_LITERALS = (
30+
"register_user(name='John Doe', "
31+
"age=37, "
32+
"address={'city': 'San Francisco', 'state': 'CA'}, "
33+
"role=null, "
34+
"passed_test=true, "
35+
"aliases=['John', 'Johnny'])"
36+
)
37+
MORE_TYPES_FUNCTION_CALL = FunctionCall(
38+
name="register_user",
39+
arguments='{"name": "John Doe", '
40+
'"age": 37, '
41+
'"address": {"city": "San Francisco", "state": "CA"}, '
42+
'"role": null, '
43+
'"passed_test": true, '
44+
'"aliases": ["John", "Johnny"]}',
45+
)
46+
PARAMETERLESS_FUNCTION_OUTPUT = "get_weather()"
47+
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
48+
name="get_weather",
49+
arguments="{}",
50+
)
51+
EMPTY_DICT_FUNCTION_OUTPUT = "do_something_cool(additional_data={})"
52+
EMPTY_DICT_FUNCTION_CALL = FunctionCall(
53+
name="do_something_cool",
54+
arguments='{"additional_data": {}}',
55+
)
56+
EMPTY_LIST_FUNCTION_OUTPUT = "do_something_cool(steps=[])"
57+
EMPTY_LIST_FUNCTION_CALL = FunctionCall(
58+
name="do_something_cool",
59+
arguments='{"steps": []}',
60+
)
61+
ESCAPED_STRING_FUNCTION_OUTPUT = (
62+
r"get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')"
63+
)
64+
ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
65+
name="get_weather",
66+
arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}',
67+
)
68+
69+
70+
@pytest.mark.parametrize("streaming", [True, False])
71+
def test_no_tool_call(streaming: bool):
72+
mock_tokenizer = MagicMock()
73+
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer)
74+
model_output = "How can I help you today?"
75+
76+
content, tool_calls = run_tool_extraction(
77+
tool_parser, model_output, streaming=streaming
78+
)
79+
80+
assert content == model_output
81+
assert len(tool_calls) == 0
82+
83+
84+
TEST_CASES = [
85+
pytest.param(
86+
True,
87+
f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}</function_calls>",
88+
[SIMPLE_FUNCTION_CALL],
89+
id="simple_streaming",
90+
),
91+
pytest.param(
92+
False,
93+
f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}</function_calls>",
94+
[SIMPLE_FUNCTION_CALL],
95+
id="simple_nonstreaming",
96+
),
97+
pytest.param(
98+
True,
99+
f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>",
100+
[MORE_TYPES_FUNCTION_CALL],
101+
id="more_types_streaming",
102+
),
103+
pytest.param(
104+
False,
105+
f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>",
106+
[MORE_TYPES_FUNCTION_CALL],
107+
id="more_types_nonstreaming",
108+
),
109+
pytest.param(
110+
True,
111+
f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT_JSON_LITERALS}</function_calls>",
112+
[MORE_TYPES_FUNCTION_CALL],
113+
id="more_types_streaming_json_literals",
114+
),
115+
pytest.param(
116+
False,
117+
f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT_JSON_LITERALS}</function_calls>",
118+
[MORE_TYPES_FUNCTION_CALL],
119+
id="more_types_nonstreaming_json_literals",
120+
),
121+
pytest.param(
122+
True,
123+
f"<function_calls>{PARAMETERLESS_FUNCTION_OUTPUT}</function_calls>",
124+
[PARAMETERLESS_FUNCTION_CALL],
125+
id="parameterless_streaming",
126+
),
127+
pytest.param(
128+
False,
129+
f"<function_calls>{PARAMETERLESS_FUNCTION_OUTPUT}</function_calls>",
130+
[PARAMETERLESS_FUNCTION_CALL],
131+
id="parameterless_nonstreaming",
132+
),
133+
pytest.param(
134+
True,
135+
f"<function_calls>{EMPTY_DICT_FUNCTION_OUTPUT}</function_calls>",
136+
[EMPTY_DICT_FUNCTION_CALL],
137+
id="empty_dict_streaming",
138+
),
139+
pytest.param(
140+
False,
141+
f"<function_calls>{EMPTY_DICT_FUNCTION_OUTPUT}</function_calls>",
142+
[EMPTY_DICT_FUNCTION_CALL],
143+
id="empty_dict_nonstreaming",
144+
),
145+
pytest.param(
146+
True,
147+
f"<function_calls>{EMPTY_LIST_FUNCTION_OUTPUT}</function_calls>",
148+
[EMPTY_LIST_FUNCTION_CALL],
149+
id="empty_list_streaming",
150+
),
151+
pytest.param(
152+
False,
153+
f"<function_calls>{EMPTY_LIST_FUNCTION_OUTPUT}</function_calls>",
154+
[EMPTY_LIST_FUNCTION_CALL],
155+
id="empty_list_nonstreaming",
156+
),
157+
pytest.param(
158+
True,
159+
f"<function_calls>{ESCAPED_STRING_FUNCTION_OUTPUT}</function_calls>",
160+
[ESCAPED_STRING_FUNCTION_CALL],
161+
id="escaped_string_streaming",
162+
),
163+
pytest.param(
164+
False,
165+
f"<function_calls>{ESCAPED_STRING_FUNCTION_OUTPUT}</function_calls>",
166+
[ESCAPED_STRING_FUNCTION_CALL],
167+
id="escaped_string_nonstreaming",
168+
),
169+
pytest.param(
170+
True,
171+
f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}\n{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>",
172+
[SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
173+
id="parallel_calls_streaming",
174+
),
175+
pytest.param(
176+
False,
177+
f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}\n{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>",
178+
[SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
179+
id="parallel_calls_nonstreaming",
180+
),
181+
]
182+
183+
184+
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES)
185+
def test_tool_call(
186+
streaming: bool, model_output: str, expected_tool_calls: list[FunctionCall]
187+
):
188+
mock_tokenizer = MagicMock()
189+
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer)
190+
191+
content, tool_calls = run_tool_extraction(
192+
tool_parser, model_output, streaming=streaming
193+
)
194+
195+
assert content is None
196+
assert len(tool_calls) == len(expected_tool_calls)
197+
for actual, expected in zip(tool_calls, expected_tool_calls):
198+
assert actual.type == "function"
199+
assert actual.function == expected
200+
201+
202+
def test_streaming_tool_call_with_large_steps():
203+
mock_tokenizer = MagicMock()
204+
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer)
205+
model_output_deltas = [
206+
"<function_calls>get_weather(city='San",
207+
" Francisco', metric='celsius')\n"
208+
f"{PARAMETERLESS_FUNCTION_OUTPUT}\n"
209+
f"{EMPTY_LIST_FUNCTION_OUTPUT}</function_calls>",
210+
]
211+
212+
reconstructor = run_tool_extraction_streaming(
213+
tool_parser, model_output_deltas, assert_one_tool_per_delta=False
214+
)
215+
216+
assert reconstructor.other_content == ""
217+
assert len(reconstructor.tool_calls) == 3
218+
assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL
219+
assert reconstructor.tool_calls[1].function == PARAMETERLESS_FUNCTION_CALL
220+
assert reconstructor.tool_calls[2].function == EMPTY_LIST_FUNCTION_CALL
221+
222+
223+
@pytest.mark.parametrize("streaming", [False])
224+
def test_regex_timeout_handling(streaming: bool):
225+
"""test regex timeout is handled gracefully"""
226+
mock_tokenizer = MagicMock()
227+
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer)
228+
229+
fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2
230+
231+
# create a mock regex that raises TimeoutError
232+
mock_regex = MagicMock()
233+
mock_regex.match.side_effect = TimeoutError("Regex timeout")
234+
235+
with patch.object(tool_parser, "TOOL_CALL_REGEX", mock_regex):
236+
content, tool_calls = run_tool_extraction(
237+
tool_parser, fake_problematic_input, streaming=streaming
238+
)
239+
240+
# should treat as regular text when regex times out
241+
assert content == fake_problematic_input
242+
assert len(tool_calls) == 0
243+
mock_regex.match.assert_called_once()

tests/lora/test_add_lora.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from vllm.inputs import TextPrompt
1313
from vllm.lora.request import LoRARequest
1414
from vllm.sampling_params import SamplingParams
15-
from vllm.utils import merge_async_iterators
15+
from vllm.utils.async_utils import merge_async_iterators
1616

1717
MODEL_PATH = "zai-org/chatglm3-6b"
1818
LORA_RANK = 64

tests/utils_/test_async_utils.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
import asyncio
4+
from collections.abc import AsyncIterator
5+
6+
import pytest
7+
8+
from vllm.utils.async_utils import merge_async_iterators
9+
10+
11+
async def _mock_async_iterator(idx: int):
12+
try:
13+
while True:
14+
yield f"item from iterator {idx}"
15+
await asyncio.sleep(0.1)
16+
except asyncio.CancelledError:
17+
print(f"iterator {idx} cancelled")
18+
19+
20+
@pytest.mark.asyncio
21+
async def test_merge_async_iterators():
22+
iterators = [_mock_async_iterator(i) for i in range(3)]
23+
merged_iterator = merge_async_iterators(*iterators)
24+
25+
async def stream_output(generator: AsyncIterator[tuple[int, str]]):
26+
async for idx, output in generator:
27+
print(f"idx: {idx}, output: {output}")
28+
29+
task = asyncio.create_task(stream_output(merged_iterator))
30+
await asyncio.sleep(0.5)
31+
task.cancel()
32+
with pytest.raises(asyncio.CancelledError):
33+
await task
34+
35+
for iterator in iterators:
36+
try:
37+
await asyncio.wait_for(anext(iterator), 1)
38+
except StopAsyncIteration:
39+
# All iterators should be cancelled and print this message.
40+
print("Iterator was cancelled normally")
41+
except (Exception, asyncio.CancelledError) as e:
42+
raise AssertionError() from e

0 commit comments

Comments
 (0)