
Commit dac4bb3

sroy745 authored and Robert Shaw committed
[Speculative Decoding 2/2 ] Integrate typical acceptance sampler into Spec Decode Worker (vllm-project#5348)
1 parent 08dedd5 commit dac4bb3

14 files changed: +482 / -213 lines

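The headline change is that the SpecDecodeWorker can now be driven either by the RejectionSampler or by the new TypicalAcceptanceSampler, selected through a spec_decoding_acceptance_method option (see the new end-to-end test below). As a rough usage sketch, and assuming these options are forwarded to the engine the same way the test harness passes them as LLM kwargs, enabling the typical acceptance sampler might look like this:

```python
# Hedged sketch, not part of this diff: enable speculative decoding with the
# typical acceptance sampler. The kwargs mirror the values used by the new
# end-to-end test; the exact plumbing of these arguments is assumed.
from vllm import LLM, SamplingParams

llm = LLM(
    model="JackFram/llama-160m",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=3,
    # New in this change: choose the draft-token acceptance strategy.
    spec_decoding_acceptance_method="typical_acceptance_sampler",
    use_v2_block_manager=True,   # required for spec decode
    enforce_eager=True,          # skip CUDA graph capture, as in the tests
)

outputs = llm.generate(["The future of AI is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```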
tests/samplers/test_typical_acceptance_sampler.py

Lines changed: 64 additions & 32 deletions
@@ -57,6 +57,19 @@ def get_draft_token_ids(batch_size: int, k: int, vocab_size: int,
     return draft_token_ids
 
 
+def get_acceptance_sampler(
+    posterior_threshold: float = 0.03,
+    posterior_alpha: float = 0.9,
+    disable_bonus_tokens: bool = False,
+    strict_mode: bool = False,
+) -> TypicalAcceptanceSampler:
+    """
+    Initializes and returns a TypicalAcceptanceSampler.
+    """
+    return TypicalAcceptanceSampler(posterior_threshold, posterior_alpha,
+                                    disable_bonus_tokens, strict_mode)
+
+
 @pytest.mark.parametrize("k", list(range(1, 6)))
 @pytest.mark.parametrize("vocab_size", [30_000, 50_000])
 @pytest.mark.parametrize("batch_size", list(range(1, 32)))
@@ -69,7 +82,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
     different combinations of k, vocab_size, batch_size and num devices.
     """
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler()
+    typical_acceptance_sampler = get_acceptance_sampler()
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
     bonus_token_ids = torch.randint(low=0,
@@ -81,7 +94,10 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
                                     size=(batch_size, k),
                                     dtype=torch.int64)
     # Verify that sampling succeeds for all cases.
-    typical_acceptance_sampler(target_probs, bonus_token_ids, draft_token_ids)
+    typical_acceptance_sampler(target_probs,
+                               bonus_token_ids,
+                               draft_probs=None,
+                               draft_token_ids=draft_token_ids)
 
 
 @pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@@ -99,7 +115,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
     batch_size = 5
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(strict_mode=True)
+    typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
     bonus_token_ids = torch.randint(low=0,
@@ -130,8 +146,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
     oob_token_ids[0][0] = rogue_token_id
 
     with pytest.raises(AssertionError):
-        typical_acceptance_sampler(target_probs, bonus_token_ids,
-                                   draft_token_ids)
+        typical_acceptance_sampler(target_probs,
+                                   bonus_token_ids,
+                                   draft_probs=None,
+                                   draft_token_ids=draft_token_ids)
 
 
 @pytest.mark.parametrize("seed", list(range(10)))
@@ -156,7 +174,7 @@ def test_uniform_target_distribution_accepts_all_tokens(
     batch_size = 5
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
@@ -168,9 +186,11 @@ def test_uniform_target_distribution_accepts_all_tokens(
                                     high=vocab_size,
                                     size=(batch_size, 1),
                                     dtype=torch.int64)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     # We are using a uniform target probability distribution.
     # For a uniform distribution the entropy is very high and it
     # should lead to all draft tokens being accepted. Verify that.
@@ -208,7 +228,7 @@ def test_temperature_zero_target_distribution(seed: int,
     vocab_size = 30_000
     torch.set_default_device(device)
 
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     # Simulate temperature 0 probability distribution for target probabilities
@@ -229,9 +249,11 @@ def test_temperature_zero_target_distribution(seed: int,
     # 1.0 tokens in the target distribution we will reject all of them and
     # fallback to the greedy sampling for selecting 1 token for each sequence.
     # Verify the same.
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, -1] == -1)
@@ -266,7 +288,7 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
     batch_size = 4
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     # For sequences 0 and 2 set the distribution to a temperature
@@ -282,9 +304,11 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
                                     high=vocab_size,
                                     size=(batch_size, 1),
                                     dtype=torch.int64)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     # verify the shape of output_token_ids
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
@@ -331,7 +355,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
     batch_size = 1
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     # Create a temperature zero target probability distribution and ensure
@@ -344,9 +368,11 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
                                     high=vocab_size,
                                     size=(batch_size, 1),
                                     dtype=torch.int64)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
@@ -362,9 +388,11 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
         batch_size, k, vocab_size, zero_temperature_token_ids)
     draft_token_ids = torch.cat(
         (draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2])
@@ -389,7 +417,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
     batch_size = 1
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     # Simulate temperature 0 probability distribution for target
@@ -407,9 +435,11 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
                                     high=vocab_size,
                                     size=(batch_size, 1),
                                     dtype=torch.int64)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, 1:-1] == -1)
@@ -423,9 +453,11 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
         posterior_threshold=0.0,
         posterior_alpha=0.0)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
@@ -456,7 +488,7 @@ def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
     batch_size = 5
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
        strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)

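A change visible throughout these tests: the sampler is now invoked with draft_probs as an explicit keyword argument (passed as None, since typical acceptance only needs the draft token ids). A minimal standalone sketch of the new calling convention, assuming the module path mirrors the rejection sampler's and a CUDA device is available, might look like:

```python
# Hedged sketch, not part of this diff: exercise the updated __call__
# signature in which draft_probs is an explicit keyword argument.
# The import path below is assumed to mirror rejection_sampler's location.
import torch

from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)

torch.set_default_device("cuda")  # the tests keep all tensors on GPU

sampler = TypicalAcceptanceSampler(posterior_threshold=0.03,
                                   posterior_alpha=0.9,
                                   disable_bonus_tokens=False,
                                   strict_mode=False)
sampler.init_gpu_tensors(rank=0)

batch_size, k, vocab_size = 2, 3, 30_000
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
bonus_token_ids = torch.randint(0, vocab_size, size=(batch_size, 1))
draft_token_ids = torch.randint(0, vocab_size, size=(batch_size, k))

# draft_probs is unused by typical acceptance, hence None in all the tests.
output_token_ids = sampler(target_probs,
                           bonus_token_ids,
                           draft_probs=None,
                           draft_token_ids=draft_token_ids)
print(output_token_ids.shape)  # [batch_size, k + 1]; rejected slots are -1
```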
tests/spec_decode/e2e/test_multistep_correctness.py

Lines changed: 53 additions & 1 deletion
@@ -11,9 +11,15 @@
 numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
 equality. This gives us good coverage of temp=0.
 
+At temp=0, the TypicalAcceptanceSampler ensures that only the tokens with the
+highest probability in the target distribution are accepted. Therefore, we can
+expect greedy equality for the TypicalAcceptanceSampler at temp=0.
+
 For temp>0, we rely on unit tests on the rejection sampler to verify that the
 output distribution is the same with spec decode vs. no spec decode (this would
-be prohibitively expensive to run with a real model).
+be prohibitively expensive to run with a real model). Similarly, for the
+TypicalAcceptance sampler also, we rely on unit tests to validate temp>0
+test cases.
 
 NOTE: Speculative decoding's distribution equality requires that the measured
 distributions of the target model and proposal model be deterministic given the
@@ -616,3 +622,49 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
                                          batch_size,
                                          max_output_len=output_len,
                                          force_output_len=True)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model": "JackFram/llama-160m",
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize(
+    "test_llm_kwargs",
+    [
+        {
+            "speculative_model": "JackFram/llama-68m",
+            "num_speculative_tokens": k,
+            "spec_decoding_acceptance_method": "typical_acceptance_sampler"
+        }
+        # Try a range of common k.
+        for k in [1, 2, 3]
+    ])
+@pytest.mark.parametrize("batch_size", [1, 32])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use smaller output len for fast test.
+        32,
+    ])
+@pytest.mark.parametrize("seed", [1])
+def test_typical_acceptance_sampling(baseline_llm_generator,
+                                     test_llm_generator, batch_size: int,
+                                     output_len: int):
+    """Verify that speculative decoding produces exact equality to without spec
+    decode with TypicalAcceptanceSampler as the draft token acceptance
+    sampling method.
+    """
+    run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len=output_len,
+                                         force_output_len=True)

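For intuition on why greedy equality holds at temp=0 (and why the unit tests above expect a uniform target distribution to accept every draft token): typical acceptance, as used in Medusa-style decoding, keeps a draft token when its target probability clears an entropy-dependent threshold. The following is a rough, illustrative restatement of that rule using the posterior_threshold and posterior_alpha knobs configured above; it is a sketch, not the vLLM source, and may differ from the implementation in detail:

```python
# Illustrative typical-acceptance rule (sketch only, not the vLLM source).
import torch


def accept_mask(target_probs: torch.Tensor,
                draft_token_ids: torch.Tensor,
                posterior_threshold: float = 0.03,
                posterior_alpha: float = 0.9) -> torch.Tensor:
    """target_probs: [batch, k, vocab]; draft_token_ids: [batch, k]."""
    # Entropy of the target distribution at each speculative position.
    entropy = -torch.sum(target_probs * torch.log(target_probs + 1e-5), dim=-1)
    # Peaked (temp~0) distributions keep the bar at posterior_threshold, so
    # only the near-argmax token survives; near-uniform distributions shrink
    # the bar toward zero, so essentially all draft tokens are accepted.
    threshold = torch.minimum(
        torch.full_like(entropy, posterior_threshold),
        posterior_alpha * torch.exp(-entropy))
    token_probs = target_probs.gather(
        -1, draft_token_ids.unsqueeze(-1)).squeeze(-1)
    return token_probs > threshold
```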
tests/spec_decode/test_dynamic_spec_decode.py

Lines changed: 7 additions & 8 deletions
@@ -4,13 +4,13 @@
 import torch
 
 from tests.nm_utils.utils_skip import should_skip_test_group
-from vllm.model_executor.layers.rejection_sampler import RejectionSampler
 from vllm.sequence import ExecuteModelRequest
 from vllm.spec_decode.metrics import AsyncMetricsCollector
 from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
 from vllm.spec_decode.top1_proposer import Top1Proposer
 
+from .test_utils import mock_spec_decode_sampler
 from .utils import create_batch, mock_worker
 
 if should_skip_test_group(group_name="TEST_SPEC_DECODE"):
@@ -21,23 +21,22 @@
 @pytest.mark.parametrize('queue_size', [4])
 @pytest.mark.parametrize('batch_size', [1])
 @pytest.mark.parametrize('k', [1])
-@pytest.mark.parametrize('queue_size', [4])
-@pytest.mark.parametrize('batch_size', [1])
-@pytest.mark.parametrize('k', [1])
+@pytest.mark.parametrize("acceptance_sampler_method",
+                         ["rejection_sampler", "typical_acceptance_sampler"])
 @torch.inference_mode()
-def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
+def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int,
+                             acceptance_sampler_method: str):
     """Verify that speculative tokens are disabled when the batch size
     exceeds the threshold.
     """
     disable_by_batch_size = 3
-
     draft_worker = mock_worker(cls=MultiStepWorker)
     target_worker = mock_worker()
-    rejection_sampler = MagicMock(spec=RejectionSampler)
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)
     worker = SpecDecodeWorker(proposer_worker=draft_worker,
                               scorer_worker=target_worker,
-                              rejection_sampler=rejection_sampler,
+                              spec_decode_sampler=mock_spec_decode_sampler(
+                                  acceptance_sampler_method),
                               metrics_collector=metrics_collector,
                               disable_by_batch_size=disable_by_batch_size)

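The mock_spec_decode_sampler helper imported here comes from tests/spec_decode/test_utils.py, one of the changed files not shown in this excerpt. A hypothetical reconstruction of what such a helper plausibly does (not the actual file contents; the typical_acceptance_sampler import path is assumed) is:

```python
# Hypothetical sketch of the test helper, for illustration only.
from unittest.mock import MagicMock

from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)


def mock_spec_decode_sampler(acceptance_sampler_method: str) -> MagicMock:
    """Return a mocked sampler of the requested acceptance method."""
    if acceptance_sampler_method == "rejection_sampler":
        return MagicMock(spec=RejectionSampler)
    if acceptance_sampler_method == "typical_acceptance_sampler":
        return MagicMock(spec=TypicalAcceptanceSampler)
    raise ValueError(f"Unknown sampler method {acceptance_sampler_method}")
```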