
Commit 58142aa

robertgshaw2-redhat, njhill, Varun Sundar Rabindranath, and tlrmchlsmth authored and committed
[V1] AsyncLLM Implementation (vllm-project#9826)
Signed-off-by: Nick Hill <[email protected]>
Signed-off-by: [email protected] <[email protected]>
Signed-off-by: Nick Hill <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Sumit Dubey <[email protected]>
1 parent f7487d4 commit 58142aa

29 files changed: +2,412 −727 lines

.buildkite/test-pipeline.yaml

Lines changed: 8 additions & 0 deletions
@@ -165,6 +165,14 @@ steps:
   # OOM in the CI unless we run this separately
   - pytest -v -s tokenization
 
+- label: V1 Test
+  #mirror_hardwares: [amd]
+  source_file_dependencies:
+  - vllm/
+  - tests/v1
+  commands:
+  - pytest -v -s v1
+
 - label: Examples Test # 15min
   working_dir: "/vllm-workspace/examples"
   #mirror_hardwares: [amd]
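
Outside of Buildkite, the same suite can be run from a source checkout. Below is a minimal sketch using pytest's Python entry point, mirroring the CI command above; the tests/v1 path and the CUDA requirement are assumptions about the local environment, not part of this diff.

# Sketch: run the new V1 test suite locally via pytest's Python API.
# Assumes a source checkout containing the tests/v1 directory added in this
# commit; the V1 tests skip themselves on non-CUDA platforms.
import sys

import pytest

if __name__ == "__main__":
    sys.exit(pytest.main(["-v", "-s", "tests/v1"]))
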
Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
"""
This file tests the accuracy of the vLLM server via LMEval.
It uses local-completions, which interacts with vLLM
through the OAI API with N concurrent connections.
This simulates real-world usage of the API and makes
sure that the zmq frontend mp RPC message passing and
AsyncLLMEngine are working correctly.
"""

import lm_eval
import pytest

from vllm.platforms import current_platform

MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"
NUM_CONCURRENT = 500
TASK = "gsm8k"
FILTER = "exact_match,strict-match"
RTOL = 0.03
EXPECTED_VALUE = 0.58


def run_test():
    """Run the end to end accuracy test."""

    model_args = f"pretrained={MODEL_NAME},max_model_len=2048"

    results = lm_eval.simple_evaluate(
        model="vllm",
        model_args=model_args,
        tasks="gsm8k",
        batch_size="auto",
    )

    measured_value = results["results"][TASK][FILTER]
    assert (measured_value - RTOL < EXPECTED_VALUE
            and measured_value + RTOL > EXPECTED_VALUE
            ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"


@pytest.mark.skipif(not current_platform.is_cuda(),
                    reason="V1 is currently only supported on CUDA.")
def test_lm_eval_accuracy_v1_engine(monkeypatch):
    """Run with the V1 Engine."""

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        run_test()


def test_lm_eval_accuracy_v0_engine(monkeypatch):
    """Run with the V0 Engine."""

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "0")
        run_test()
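
The module docstring above refers to exercising the server through the OpenAI-compatible API via lm_eval's local-completions backend, while run_test() evaluates the in-process "vllm" backend. For reference, a hedged sketch of the server-based variant is below; the endpoint URL and the model_args keys (base_url, num_concurrent, tokenized_requests) are assumptions about lm_eval's HTTP backend, not part of this diff.

# Sketch only: evaluate an already-running vLLM OpenAI-compatible server with
# lm_eval's "local-completions" backend. Reuses MODEL_NAME, NUM_CONCURRENT,
# TASK, and FILTER from the module above; the URL is illustrative.
def run_server_eval(base_url: str = "http://localhost:8000/v1/completions"):
    model_args = (f"model={MODEL_NAME},"
                  f"base_url={base_url},"
                  f"num_concurrent={NUM_CONCURRENT},"
                  f"tokenized_requests=False")
    results = lm_eval.simple_evaluate(
        model="local-completions",
        model_args=model_args,
        tasks=TASK,
        batch_size="auto",
    )
    return results["results"][TASK][FILTER]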

tests/entrypoints/openai/test_accuracy.py

Lines changed: 22 additions & 3 deletions
@@ -37,11 +37,11 @@
 MAX_WAIT_SECONDS = 600
 
 
-@pytest.mark.parametrize("more_args", MORE_ARGS_LIST)
-def test_lm_eval_accuracy(more_args):
+def run_test(more_args):
+    """Run the end to end accuracy test."""
+
     args = list(DEFAULT_ARGS)
     args.extend(more_args)
-
     print(f"Running with: {args}")
 
     with RemoteOpenAIServer(
@@ -64,3 +64,22 @@ def test_lm_eval_accuracy(more_args):
         assert (measured_value - RTOL < EXPECTED_VALUE
                 and measured_value + RTOL > EXPECTED_VALUE
                 ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
+
+
+@pytest.mark.skipif(not current_platform.is_cuda(),
+                    reason="V1 currently only supported on CUDA")
+def test_lm_eval_accuracy_v1_engine(monkeypatch):
+    """Run with the V1 Engine."""
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        run_test([])
+
+
+@pytest.mark.parametrize("more_args", MORE_ARGS_LIST)
+def test_lm_eval_accuracy_v0_engine(monkeypatch, more_args):
+    """Run with the V0 Engine."""
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "0")
+        run_test(more_args)
File renamed without changes.

tests/v1/engine/__init__.py

Whitespace-only changes.

tests/v1/engine/test_async_llm.py

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
import asyncio
from typing import Tuple

import pytest

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.platforms import current_platform
from vllm.v1.engine.async_llm import AsyncLLM

if not current_platform.is_cuda():
    pytest.skip(reason="V1 currently only supported on CUDA.",
                allow_module_level=True)

ENGINE_ARGS = AsyncEngineArgs(model="meta-llama/Llama-3.2-1B",
                              disable_log_requests=True)


async def generate(engine: AsyncLLM, request_id: str,
                   max_tokens: int) -> Tuple[int, str]:
    count = 0
    async for _ in engine.generate(request_id=request_id,
                                   prompt="Hello my name is Robert and",
                                   sampling_params=SamplingParams(
                                       max_tokens=max_tokens, temperature=0)):

        count += 1
        await asyncio.sleep(0.)

    return count, request_id


@pytest.mark.asyncio
async def test_load(monkeypatch):
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        engine = AsyncLLM.from_engine_args(ENGINE_ARGS)

        NUM_REQUESTS = 10000
        NUM_EXPECTED_TOKENS = 10

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks = []
        for request_id in request_ids:
            tasks.append(
                asyncio.create_task(
                    generate(engine, request_id, NUM_EXPECTED_TOKENS)))

        # Confirm that we got all the EXPECTED tokens from the requests.
        failed_request_id = None
        tokens = None
        for task in tasks:
            num_generated_tokens, request_id = await task
            if (num_generated_tokens != NUM_EXPECTED_TOKENS
                    and failed_request_id is None):
                failed_request_id = request_id
                tokens = num_generated_tokens

        assert failed_request_id is None, (
            f"{failed_request_id} generated {tokens} but "
            f"expected {NUM_EXPECTED_TOKENS}")

        engine.shutdown()
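
test_load fires all 10,000 requests at once, which is the point of a load test. When driving AsyncLLM from application code, client-side concurrency is often bounded; the sketch below uses asyncio.Semaphore together with the generate() helper defined above. The limit of 512 is an arbitrary illustrative choice, not something this diff prescribes.

# Sketch: bound client-side concurrency when issuing many requests against an
# AsyncLLM engine. Reuses the module-level generate() helper from the test.
async def bounded_generate(sem: asyncio.Semaphore, engine: AsyncLLM,
                           request_id: str,
                           max_tokens: int) -> Tuple[int, str]:
    async with sem:
        return await generate(engine, request_id, max_tokens)


async def run_bounded(engine: AsyncLLM,
                      num_requests: int = 1000,
                      limit: int = 512):
    sem = asyncio.Semaphore(limit)
    tasks = [
        asyncio.create_task(
            bounded_generate(sem, engine, f"request-{i}", max_tokens=10))
        for i in range(num_requests)
    ]
    # Each result is a (num_generated_tokens, request_id) tuple.
    return await asyncio.gather(*tasks)
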
Lines changed: 205 additions & 0 deletions
@@ -0,0 +1,205 @@
from typing import List

import pytest
from transformers import AutoTokenizer

from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine import EngineCoreOutput
from vllm.v1.engine.detokenizer import Detokenizer, DetokenizerRequest

TOKENIZER_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

FULL_STRINGS = [
    "My name is Robert from Neural Magic and I love working on vLLM so much!",
    "Red Hat is the best open source company by far across Linux, K8s, and AI.",
    "Nick is the name of my brother in addition to my colleague from Red Hat.",
]

STOP_STRINGS = ["I love working on", "company by far", "brother in"]

FULL_TOKENS = [tokenizer(text).input_ids for text in FULL_STRINGS]
PROMPT_LEN = 5
PROMPT_TOKENS = [
    tokenizer(text).input_ids[:PROMPT_LEN] for text in FULL_STRINGS
]
GENERATION_TOKENS = [
    tokenizer(text).input_ids[PROMPT_LEN:] for text in FULL_STRINGS
]
PROMPT_STRINGS = [
    tokenizer.decode(prompt_tokens, skip_special_tokens=True)
    for prompt_tokens in PROMPT_TOKENS
]
PROMPT_STRINGS_LEN = [len(prompt_string) for prompt_string in PROMPT_STRINGS]
GENERATION_STRINGS = [
    text[prompt_len:]
    for text, prompt_len in zip(FULL_STRINGS, PROMPT_STRINGS_LEN)
]


class MockEngineCore:
    """Mock outputs from premade token lists."""

    def __init__(self, tokens_list: List[List[int]]):
        self.tokens_list = tokens_list
        self.current_idx = 0

    def get_outputs(self) -> List[EngineCoreOutput]:
        token_idx = self.current_idx
        self.current_idx += 1

        outputs = []
        for req_idx, token_ids in enumerate(self.tokens_list):
            if len(token_ids) > token_idx:
                output = EngineCoreOutput(request_id=f"request-{req_idx}",
                                          new_token_ids=[token_ids[token_idx]],
                                          finished=False)
                if token_idx == len(token_ids) - 1:
                    output.finished = True
                    output.finish_reason = "stopped"
                outputs.append(output)

        return outputs


@pytest.mark.parametrize(
    "request_output_kind",
    [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
def test_incremental_detokenization(request_output_kind: RequestOutputKind):
    detokenizer = Detokenizer(TOKENIZER_NAME)
    engine_core = MockEngineCore(GENERATION_TOKENS)

    # Make N requests.
    requests = [
        DetokenizerRequest(
            request_id=f"request-{idx}",
            prompt=prompt,
            prompt_token_ids=prompt_tokens,
            skip_special_tokens=False,
            spaces_between_special_tokens=False,
            output_kind=request_output_kind,
            stop=[],
            include_stop_str_in_output=False,
        ) for idx, (
            prompt,
            prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS))
    ]

    # Add requests to the detokenizer.
    for request in requests:
        detokenizer.add_request(request)

    gen_strings = {}
    gen_tokens = {}
    while True:
        # Mock output from the EngineCore.
        outputs = engine_core.get_outputs()
        if len(outputs) == 0:
            break

        # Step the Detokenizer.
        request_outputs, requests_to_abort = detokenizer.step(outputs)
        assert len(requests_to_abort) == 0

        # Update tracking.
        for request_output in request_outputs:
            request_id = request_output.request_id
            new_text = request_output.outputs[0].text
            new_tokens = request_output.outputs[0].token_ids
            if request_id not in gen_strings:
                gen_strings[request_id] = new_text
                gen_tokens[request_id] = new_tokens
            else:
                gen_strings[request_id] += new_text
                gen_tokens[request_id].extend(new_tokens)

    # Confirm the tracked values match what we expect.
    for idx, (ref_gen_str, ref_gen_toks) in enumerate(
            zip(GENERATION_STRINGS, GENERATION_TOKENS)):
        gen_str = gen_strings[f"request-{idx}"]
        gen_toks = gen_tokens[f"request-{idx}"]

        assert gen_str == ref_gen_str, f"{gen_str=}, {ref_gen_str=}"
        assert gen_toks == ref_gen_toks, f"{gen_toks=}, {ref_gen_toks=}"

    assert detokenizer.get_num_unfinished_requests() == 0
    assert not detokenizer.has_unfinished_requests()


@pytest.mark.parametrize("include_stop_str_in_output", [True, False])
def test_stop_string(include_stop_str_in_output: bool):
    detokenizer = Detokenizer(TOKENIZER_NAME)
    engine_core = MockEngineCore(GENERATION_TOKENS)

    # Make N requests.
    requests = [
        DetokenizerRequest(
            request_id=f"request-{idx}",
            prompt=prompt,
            prompt_token_ids=prompt_tokens,
            skip_special_tokens=False,
            spaces_between_special_tokens=False,
            output_kind=RequestOutputKind.DELTA,
            stop=STOP_STRINGS,
            include_stop_str_in_output=include_stop_str_in_output,
        ) for idx, (
            prompt,
            prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS))
    ]

    # Add requests to the detokenizer.
    for request in requests:
        detokenizer.add_request(request)

    gen_strings = {}
    aborted = []
    while True:
        # Mock output from the EngineCore.
        outputs = engine_core.get_outputs()
        if len(outputs) == 0:
            break

        # Step the Detokenizer.
        request_outputs, requests_to_abort = detokenizer.step(outputs)
        for request_output in request_outputs:
            # If aborted, we should not get a request output.
            assert request_output.request_id not in aborted
        aborted.extend(requests_to_abort)

        # Update tracking.
        for request_output in request_outputs:
            if request_output.finished:
                assert request_output.outputs[0].finish_reason == "stop"

            request_id = request_output.request_id
            new_text = request_output.outputs[0].text
            if request_id not in gen_strings:
                gen_strings[request_id] = new_text
            else:
                gen_strings[request_id] += new_text

    # Confirm the tracked values match what we expect.
    for idx, (ref_gen_str,
              stop_str) in enumerate(zip(GENERATION_STRINGS, STOP_STRINGS)):

        # Request should be aborted.
        request_id = f"request-{idx}"
        assert request_id in aborted

        # Collected values that were generated.
        gen_str = gen_strings[request_id]

        # Construct reference strings.
        stop_str_idx = ref_gen_str.find(stop_str)
        ref_str_exc_stop = ref_gen_str[:stop_str_idx]
        ref_str_inc_stop = ref_gen_str[:stop_str_idx] + stop_str

        if include_stop_str_in_output:
            assert gen_str == ref_str_inc_stop, (
                f"{gen_str=}, {ref_str_inc_stop=}")
        else:
            assert gen_str == ref_str_exc_stop, (
                f"{gen_str=}, {ref_str_exc_stop=}")

    assert detokenizer.get_num_unfinished_requests() == 0
    assert not detokenizer.has_unfinished_requests()
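
For intuition about what the Detokenizer under test produces in DELTA mode, here is a deliberately naive sketch of incremental detokenization built only on the Hugging Face tokenizer loaded above: re-decode the accumulated token ids at each step and emit just the newly appended text. This is not vLLM's implementation; a real detokenizer must also handle characters split across tokens and stop-string truncation, which this sketch ignores.

# Sketch: naive incremental detokenization using the module-level `tokenizer`.
def iter_deltas(token_ids):
    prev_text = ""
    seen = []
    for tok in token_ids:
        seen.append(tok)
        text = tokenizer.decode(seen, skip_special_tokens=True)
        # Only emit once the decoded text has actually grown.
        if len(text) > len(prev_text):
            yield text[len(prev_text):]
            prev_text = text


# For the simple ASCII strings above, joining the deltas reproduces the full
# decode of the generation tokens.
print("".join(iter_deltas(GENERATION_TOKENS[0])))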
