Commit e95cd87

[Speculative decoding 6/9] Integrate speculative decoding with LLMEngine (#3894)
1 parent 69e1d2f commit e95cd87

31 files changed: 1347 additions & 407 deletions

tests/core/block/e2e/test_correctness.py

Lines changed: 70 additions & 0 deletions
@@ -230,6 +230,76 @@ def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator,
     assert baseline_token_ids == test_token_ids
 
 
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [
+        {
+            # Use a small model for a fast test.
+            "model": "facebook/opt-125m",
+
+            # skip cuda graph creation for fast test.
+            "enforce_eager": True,
+            "enable_chunked_prefill": True,
+            "max_num_batched_tokens": 2,
+            "max_num_seqs": 2,
+        },
+    ])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [
+    {
+        "use_v2_block_manager": False,
+    },
+])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "use_v2_block_manager": True,
+        "num_lookahead_slots": 0,
+    },
+    {
+        "use_v2_block_manager": True,
+        "num_lookahead_slots": 5,
+    },
+])
+@pytest.mark.parametrize("batch_size", [4])
+@pytest.mark.parametrize("seed", [1])
+def test_chunked_prefill_block_manager_v2(baseline_llm_generator,
+                                          test_llm_generator, batch_size):
+    """Verify that chunked prefill works with BlockManagerV2, with and without
+    lookahead scheduling.
+    """
+    output_len = 32
+    temperature = 0.0
+
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
+
+    sampling_params = SamplingParams(
+        max_tokens=output_len,
+        ignore_eos=True,
+        temperature=temperature,
+    )
+
+    print('Getting token ids with BlockManagerV1')
+    baseline_token_ids = get_token_ids_from_llm_generator(
+        baseline_llm_generator, prompts, sampling_params)
+
+    print('Getting token ids with BlockManagerV2')
+    test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
+                                                      prompts, sampling_params)
+
+    for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
+                                                    test_token_ids):
+        assert expected_token_ids == actual_token_ids
+
+    assert baseline_token_ids == test_token_ids
+
+
 def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
     for llm in llm_generator:
         outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
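
Note on the zip(cycle(...)) idiom the new test uses to replicate its prompts up to batch_size entries; a small standalone sketch with illustrative values (not part of the diff):

from itertools import cycle

prompts = ["a", "b", "c", "d"]
batch_size = 6
# zip() stops after batch_size pairs, so the prompt list repeats cyclically.
replicated = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
assert replicated == ["a", "b", "c", "d", "a", "b"]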

tests/core/utils.py

Lines changed: 10 additions & 7 deletions
@@ -1,5 +1,5 @@
 import time
-from typing import Optional, Tuple
+from typing import Iterable, Optional, Tuple
 
 from vllm import SamplingParams
 from vllm.lora.request import LoRARequest
@@ -31,14 +31,17 @@ def create_dummy_prompt(
 
 
 def create_seq_group(
-        seq_prompt_len=1024,
-        seq_output_lens=(128, ),
-        request_id='0',
-        seq_id_start=0,
-) -> SequenceGroup:
+        seq_prompt_len: int = 1024,
+        seq_output_lens: Iterable[int] = (128, ),
+        request_id: str = '0',
+        seq_id_start: int = 0,
+        sampling_params: Optional[SamplingParams] = None) -> SequenceGroup:
 
     assert len(seq_output_lens) > 0
 
+    if sampling_params is None:
+        sampling_params = SamplingParams()
+
     prompt_token_ids = [0] * seq_prompt_len
 
     seqs = []
@@ -60,7 +63,7 @@ def create_seq_group(
     seq_group = SequenceGroup(
         request_id=request_id,
         seqs=seqs,
-        sampling_params=SamplingParams(),
+        sampling_params=sampling_params,
         arrival_time=time.time(),
     )
 
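
The change above makes the sampling parameters of the dummy sequence group caller-controlled. A minimal sketch of the new call pattern, assuming only names that appear in this diff (argument values are illustrative):

from vllm import SamplingParams
from tests.core.utils import create_seq_group

# Previously the helper always used SamplingParams(); now a test can inject
# its own, e.g. a max_tokens cap with ignore_eos.
seq_group = create_seq_group(
    seq_prompt_len=16,
    seq_output_lens=[4],
    sampling_params=SamplingParams(max_tokens=8, ignore_eos=True),
)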
New test file: Lines changed: 270 additions & 0 deletions
@@ -0,0 +1,270 @@
+import random
+from unittest.mock import MagicMock
+
+import pytest
+from transformers import PreTrainedTokenizer
+
+from tests.core.utils import create_seq_group
+from vllm.core.scheduler import Scheduler
+from vllm.engine.output_processor.multi_step import MultiStepOutputProcessor
+from vllm.engine.output_processor.stop_checker import StopChecker
+from vllm.sampling_params import SamplingParams
+from vllm.sequence import (Logprob, SequenceGroupOutput, SequenceOutput,
+                           SequenceStatus)
+from vllm.transformers_utils.detokenizer import Detokenizer
+from vllm.utils import Counter
+
+
+@pytest.mark.parametrize("seq_output_len", [128])
+@pytest.mark.parametrize("num_new_tokens", [1, 12])
+@pytest.mark.skip_global_cleanup
+def test_appends_token_ids(num_new_tokens: int, seq_output_len: int):
+    """Verify multi-step decoding appends token ids correctly.
+
+    We append token ids and verify all the token ids were appended correctly.
+    Note that ignore_eos=True.
+    """
+    detokenizer = MagicMock(spec=Detokenizer)
+    scheduler = MagicMock(spec=Scheduler)
+    stop_checker = MagicMock(spec=StopChecker)
+    seq_counter = Counter()
+
+    output_processor = MultiStepOutputProcessor(
+        detokenizer=detokenizer,
+        scheduler=scheduler,
+        seq_counter=seq_counter,
+        get_tokenizer_for_seq=lambda _: mock_tokenizer(),
+        stop_checker=stop_checker,
+    )
+
+    seq_group = create_seq_group(
+        seq_prompt_len=1024,
+        seq_output_lens=[seq_output_len],
+        sampling_params=SamplingParams(max_tokens=seq_output_len +
+                                       num_new_tokens,
+                                       ignore_eos=True),
+    )
+
+    seq = seq_group.get_seqs()[0]
+    seq.status = SequenceStatus.RUNNING
+
+    new_token_ids = list(range(num_new_tokens))
+
+    outputs = [
+        SequenceGroupOutput(
+            samples=[
+                SequenceOutput(
+                    parent_seq_id=seq.seq_id,
+                    output_token=output_token,
+                    logprobs={output_token: Logprob(0.0)},
+                )
+            ],
+            prompt_logprobs=None,
+        ) for output_token in new_token_ids
+    ]
+
+    assert seq.get_token_ids()[-len(new_token_ids):] != new_token_ids
+    output_processor.process_outputs(seq_group, outputs)
+    assert seq.get_token_ids()[-len(new_token_ids):] == new_token_ids
+
+
+@pytest.mark.parametrize("seq_prompt_len", [1024])
+@pytest.mark.parametrize("seq_output_len", [128])
+@pytest.mark.parametrize("num_new_tokens", [5, 6, 7, 8])
+@pytest.mark.parametrize("max_tokens", [128 + 3])
+@pytest.mark.skip_global_cleanup
+def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int,
+                             seq_output_len: int, max_tokens: int):
+    """Verify tokens after max_tokens are dropped and not appended to the
+    sequence.
+    """
+    detokenizer = MagicMock(spec=Detokenizer)
+    scheduler = MagicMock(spec=Scheduler)
+    stop_checker = MagicMock(spec=StopChecker)
+    seq_counter = Counter()
+
+    output_processor = MultiStepOutputProcessor(
+        detokenizer=detokenizer,
+        scheduler=scheduler,
+        seq_counter=seq_counter,
+        get_tokenizer_for_seq=lambda _: mock_tokenizer(),
+        stop_checker=stop_checker,
+    )
+
+    seq_group = create_seq_group(
+        seq_prompt_len=seq_prompt_len,
+        seq_output_lens=[seq_output_len],
+        sampling_params=SamplingParams(max_tokens=max_tokens, ),
+    )
+
+    seq = seq_group.get_seqs()[0]
+    seq.status = SequenceStatus.RUNNING
+
+    new_token_ids = list(range(num_new_tokens))
+
+    outputs = [
+        SequenceGroupOutput(
+            samples=[
+                SequenceOutput(
+                    parent_seq_id=seq.seq_id,
+                    output_token=output_token,
+                    logprobs={output_token: Logprob(0.0)},
+                )
+            ],
+            prompt_logprobs=None,
+        ) for output_token in new_token_ids
+    ]
+
+    assert seq.get_len() == seq_prompt_len + seq_output_len
+    output_processor.process_outputs(seq_group, outputs)
+
+    # Expect the processed sequence to not go over max tokens in len.
+    assert seq.get_len() == seq_prompt_len + max_tokens
+
+    # Expect the correct tokens were appended.
+    expected_appended_tokens = new_token_ids[:max_tokens - seq_output_len]
+    assert seq.get_token_ids(
+    )[-len(expected_appended_tokens):] == expected_appended_tokens
+
+
+@pytest.mark.parametrize("seq_prompt_len", [1024])
+@pytest.mark.parametrize("seq_output_len", [128])
+@pytest.mark.parametrize("num_new_tokens", [12])
+@pytest.mark.parametrize("seed", list(range(6)))
+@pytest.mark.skip_global_cleanup
+def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
+                               seq_output_len: int, seed: int):
+    """Verify the eos token id is included in the sequence, but subsequent
+    tokens are dropped (not appended to sequence).
+    """
+    random.seed(seed)
+    detokenizer = MagicMock(spec=Detokenizer)
+    scheduler = MagicMock(spec=Scheduler)
+    stop_checker = MagicMock(spec=StopChecker)
+    seq_counter = Counter()
+
+    eos_token_id = 100
+
+    output_processor = MultiStepOutputProcessor(
+        detokenizer=detokenizer,
+        scheduler=scheduler,
+        seq_counter=seq_counter,
+        get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
+        stop_checker=stop_checker,
+    )
+
+    seq_group = create_seq_group(
+        seq_prompt_len=seq_prompt_len,
+        seq_output_lens=[seq_output_len],
+        sampling_params=SamplingParams(
+            # Ensure enough space.
+            max_tokens=seq_output_len + num_new_tokens, ),
+    )
+
+    seq = seq_group.get_seqs()[0]
+    seq.status = SequenceStatus.RUNNING
+
+    new_token_ids = list(range(num_new_tokens))
+    assert eos_token_id not in new_token_ids
+    eos_index = random.randint(0, len(new_token_ids) - 1)
+    new_token_ids[eos_index] = eos_token_id
+
+    outputs = [
+        SequenceGroupOutput(
+            samples=[
+                SequenceOutput(
+                    parent_seq_id=seq.seq_id,
+                    output_token=output_token,
+                    logprobs={output_token: Logprob(0.0)},
+                )
+            ],
+            prompt_logprobs=None,
+        ) for output_token in new_token_ids
+    ]
+
+    assert seq.get_len() == seq_prompt_len + seq_output_len
+    output_processor.process_outputs(seq_group, outputs)
+
+    # Expect the processed sequence to not go beyond provided eos.
+    assert seq.get_len() == seq_prompt_len + seq_output_len + (eos_index + 1)
+
+    # Expect the correct tokens were appended.
+    expected_appended_tokens = new_token_ids[:eos_index + 1]
+    assert seq.get_token_ids(
+    )[-len(expected_appended_tokens):] == expected_appended_tokens
+
+
+@pytest.mark.parametrize("seq_prompt_len", [1024])
+@pytest.mark.parametrize("seq_output_len", [128])
+@pytest.mark.parametrize("num_new_tokens", [12])
+@pytest.mark.parametrize("seed", list(range(6)))
+@pytest.mark.skip_global_cleanup
+def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
+                              seq_output_len: int, seed: int):
+    """When sampling parameters dictate that we should ignore the eos token id,
+    ensure all token ids are appended even if the eos token id is emitted.
+    """
+    random.seed(seed)
+    detokenizer = MagicMock(spec=Detokenizer)
+    scheduler = MagicMock(spec=Scheduler)
+    stop_checker = MagicMock(spec=StopChecker)
+    seq_counter = Counter()
+
+    eos_token_id = 100
+
+    output_processor = MultiStepOutputProcessor(
+        detokenizer=detokenizer,
+        scheduler=scheduler,
+        seq_counter=seq_counter,
+        get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
+        stop_checker=stop_checker,
+    )
+
+    seq_group = create_seq_group(
+        seq_prompt_len=seq_prompt_len,
+        seq_output_lens=[seq_output_len],
+        sampling_params=SamplingParams(
+            # Ensure enough space.
+            max_tokens=seq_output_len + num_new_tokens,
+            ignore_eos=True,
+        ),
+    )
+
+    seq = seq_group.get_seqs()[0]
+    seq.status = SequenceStatus.RUNNING
+
+    new_token_ids = list(range(num_new_tokens))
+    assert eos_token_id not in new_token_ids
+    eos_index = random.randint(0, len(new_token_ids) - 1)
+    new_token_ids[eos_index] = eos_token_id
+
+    outputs = [
+        SequenceGroupOutput(
+            samples=[
+                SequenceOutput(
+                    parent_seq_id=seq.seq_id,
+                    output_token=output_token,
+                    logprobs={output_token: Logprob(0.0)},
+                )
+            ],
+            prompt_logprobs=None,
+        ) for output_token in new_token_ids
+    ]
+
+    assert seq.get_len() == seq_prompt_len + seq_output_len
+    output_processor.process_outputs(seq_group, outputs)
+
+    # Expect the processed sequence to go beyond eos.
+    assert seq.get_len() == seq_prompt_len + seq_output_len + num_new_tokens
+
+    # Expect the correct tokens were appended.
+    expected_appended_tokens = new_token_ids[:seq_output_len + num_new_tokens -
+                                             seq_output_len]
+    assert seq.get_token_ids(
+    )[-len(expected_appended_tokens):] == expected_appended_tokens
+
+
+def mock_tokenizer(eos_token_id=1000):
+    tokenizer = MagicMock(spec=PreTrainedTokenizer)
+    tokenizer.eos_token_id = eos_token_id
+    return tokenizer
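
To make the truncation arithmetic in test_respects_max_tokens concrete, a worked example using that test's own parametrization (seq_output_len=128, max_tokens=128 + 3, num_new_tokens=5):

seq_output_len = 128             # tokens already generated before this step
max_tokens = 128 + 3             # cap from SamplingParams
new_token_ids = [0, 1, 2, 3, 4]  # num_new_tokens = 5 proposed in one step
# Only the tokens that still fit under max_tokens are kept by the processor.
expected_appended_tokens = new_token_ids[:max_tokens - seq_output_len]
assert expected_appended_tokens == [0, 1, 2]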
