Commit 3e5b882: "squashed all"
1 parent 0d02747

14 files changed: +472 additions, -119 deletions

tests/spec_decode/e2e/test_multistep_correctness.py

Lines changed: 102 additions & 3 deletions
@@ -62,6 +62,16 @@
     {
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": 5,
+        "enable_chunked_prefill": False,
+    },
+    {
+        # Chunked prefill enabled with small value
+        # to make sure we get mixed batches.
+        "speculative_model": "JackFram/llama-68m",
+        "num_speculative_tokens": 5,
+        "enable_chunked_prefill": True,
+        "max_num_batched_tokens": 4,
+        "max_num_seqs": 4
     },
     {
         # Verify the detokenizer assertions in the test work when spec
@@ -141,6 +151,14 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
     {
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": 5,
+        "enable_chunked_prefill": False,
+    },
+    {
+        "speculative_model": "JackFram/llama-68m",
+        "num_speculative_tokens": 5,
+        "enable_chunked_prefill": True,
+        "max_num_batched_tokens": 4,
+        "max_num_seqs": 4,
     },
 ])
 @pytest.mark.parametrize(
@@ -204,6 +222,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
     {
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": 5,
+        "enable_chunked_prefill": False,
+    },
+    {
+        "speculative_model": "JackFram/llama-68m",
+        "num_speculative_tokens": 5,
+        "enable_chunked_prefill": True,
+        "max_num_batched_tokens": 4,
+        "max_num_seqs": 4
     },
 ])
 @pytest.mark.parametrize(
@@ -255,6 +281,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
     {
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": 5,
+        "enable_chunked_prefill": False,
+    },
+    {
+        "speculative_model": "JackFram/llama-68m",
+        "num_speculative_tokens": 5,
+        "enable_chunked_prefill": True,
+        "max_num_batched_tokens": 4,
+        "max_num_seqs": 4
     },
 ])
 @pytest.mark.parametrize("max_output_len", [
@@ -300,6 +334,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
     {
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": 5,
+        "enable_chunked_prefill": False,
+    },
+    {
+        "speculative_model": "JackFram/llama-68m",
+        "num_speculative_tokens": 5,
+        "enable_chunked_prefill": True,
+        "max_num_batched_tokens": 4,
+        "max_num_seqs": 4
     },
 ])
 @pytest.mark.parametrize("batch_size", [1])
@@ -347,6 +389,14 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
     {
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": 5,
+        "enable_chunked_prefill": False,
+    },
+    {
+        "speculative_model": "JackFram/llama-68m",
+        "num_speculative_tokens": 5,
+        "enable_chunked_prefill": True,
+        "max_num_batched_tokens": 4,
+        "max_num_seqs": 4
     },
 ])
 @pytest.mark.parametrize("batch_size", [32])
@@ -397,6 +447,14 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
     {
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": 5,
+        "enable_chunked_prefill": False,
+    },
+    {
+        "speculative_model": "JackFram/llama-68m",
+        "num_speculative_tokens": 5,
+        "enable_chunked_prefill": True,
+        "max_num_batched_tokens": 4,
+        "max_num_seqs": 4
     },
 ])
 @pytest.mark.parametrize(
@@ -454,6 +512,14 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
     {
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": 5,
+        "enable_chunked_prefill": False,
+    },
+    {
+        "speculative_model": "JackFram/llama-68m",
+        "num_speculative_tokens": 5,
+        "enable_chunked_prefill": True,
+        "max_num_batched_tokens": 4,
+        "max_num_seqs": 4
     },
 ])
 @pytest.mark.parametrize("batch_size", [2])
@@ -503,6 +569,15 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs,
         # Artificially limit the draft model max model len; this forces vLLM
         # to skip speculation once the sequences grow beyond 32-k tokens.
         "speculative_max_model_len": 32,
+        "enable_chunked_prefill": False,
+    },
+    {
+        "speculative_model": "JackFram/llama-68m",
+        "num_speculative_tokens": 5,
+        "enable_chunked_prefill": True,
+        "max_num_batched_tokens": 4,
+        "max_num_seqs": 4,
+        "speculative_max_model_len": 32,
     },
 ])
 @pytest.mark.parametrize("batch_size", [8])
@@ -551,6 +626,15 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs,
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": 5,
         "speculative_disable_by_batch_size": 2,
+        "enable_chunked_prefill": False,
+    },
+    {
+        "speculative_model": "JackFram/llama-68m",
+        "num_speculative_tokens": 5,
+        "speculative_disable_by_batch_size": 2,
+        "enable_chunked_prefill": True,
+        "max_num_batched_tokens": 4,
+        "max_num_seqs": 4,
     },
 ])
 @pytest.mark.parametrize("batch_size", [8])
@@ -590,10 +674,17 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs,
     {
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": k,
+        "enable_chunked_prefill": False,
     }
     # Try a range of common k, as well as large speculation.
     for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
-])
+] + [{
+    "speculative_model": "JackFram/llama-68m",
+    "num_speculative_tokens": k,
+    "enable_chunked_prefill": True,
+    "max_num_batched_tokens": 4,
+    "max_num_seqs": 4,
+} for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]])
 @pytest.mark.parametrize("batch_size", [2])
 @pytest.mark.parametrize(
     "output_len",
@@ -636,11 +727,19 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
     {
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": k,
-        "spec_decoding_acceptance_method": "typical_acceptance_sampler"
+        "spec_decoding_acceptance_method": "typical_acceptance_sampler",
+        "enable_chunked_prefill": False
     }
     # Try a range of common k.
     for k in [1, 2, 3]
-])
+] + [{
+    "speculative_model": "JackFram/llama-68m",
+    "num_speculative_tokens": k,
+    "spec_decoding_acceptance_method": "typical_acceptance_sampler",
+    "enable_chunked_prefill": True,
+    "max_num_batched_tokens": 4,
+    "max_num_seqs": 4
+} for k in [1, 2, 3]])
 @pytest.mark.parametrize("batch_size", [1, 32])
 @pytest.mark.parametrize(
     "output_len",

tests/spec_decode/e2e/test_ngram_correctness.py

Lines changed: 29 additions & 0 deletions
@@ -49,6 +49,16 @@
         "speculative_model": "[ngram]",
         "num_speculative_tokens": 5,
         "ngram_prompt_lookup_max": 3,
+        "enable_chunked_prefill": False,
+    },
+    {
+        "speculative_model": "[ngram]",
+        "num_speculative_tokens": 5,
+        "ngram_prompt_lookup_max": 3,
+        "enable_chunked_prefill": True,
+        "speculative_disable_mqa_scorer": True,
+        "max_num_batched_tokens": 4,
+        "max_num_seqs": 4
     },
 ])
 @pytest.mark.parametrize("output_len", [
@@ -151,6 +161,16 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
         "speculative_model": "[ngram]",
         "num_speculative_tokens": 5,
         "ngram_prompt_lookup_max": 3,
+        "enable_chunked_prefill": False,
+    },
+    {
+        "speculative_model": "[ngram]",
+        "num_speculative_tokens": 5,
+        "ngram_prompt_lookup_max": 3,
+        "enable_chunked_prefill": True,
+        "speculative_disable_mqa_scorer": True,
+        "max_num_batched_tokens": 4,
+        "max_num_seqs": 4
     },
 ])
 @pytest.mark.parametrize(
@@ -251,6 +271,15 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
         "num_speculative_tokens": 5,
         "ngram_prompt_lookup_max": 3,
         "speculative_disable_by_batch_size": 4
+    }, {
+        "speculative_model": "[ngram]",
+        "num_speculative_tokens": 5,
+        "ngram_prompt_lookup_max": 3,
+        "speculative_disable_by_batch_size": 4,
+        "enable_chunked_prefill": True,
+        "speculative_disable_mqa_scorer": True,
+        "max_num_batched_tokens": 4,
+        "max_num_seqs": 4
     }])
 @pytest.mark.parametrize("batch_size", [1, 5])
 @pytest.mark.parametrize(

tests/spec_decode/test_scorer.py

Lines changed: 27 additions & 4 deletions
@@ -46,12 +46,14 @@ def assert_score_equal(score1: SpeculativeScores,
 @pytest.mark.parametrize('max_propose_len', [1, 3, 5])
 @pytest.mark.parametrize('mixed_propose_len', [True])
 @pytest.mark.parametrize('device', ['cuda'])
+@pytest.mark.parametrize('prefill_chunking', [False, True])
 def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
-                mixed_propose_len: bool, device: str) -> None:
+                mixed_propose_len: bool, device: str,
+                prefill_chunking: bool) -> None:
     """
     Compare the batch expansion scorer and mqa scorer return the same score.
     We test for both queries with the same propose length and different
-    propose length.
+    propose length, as well as mixed prefill-decode batches.
     """
     seed = 0
     block_size = 32
@@ -67,16 +69,37 @@ def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
     if not mixed_propose_len:
         propose_lens = [max_propose_len] * batch_size
     else:
-        non_zero_cnt = random.randint(0, batch_size)
+        # There must be at least 1 decode request, otherwise
+        # we have nothing to score (`_run_no_spec`).
+        non_zero_cnt = random.randint(1, batch_size)
         propose_lens = [max_propose_len
                         ] * non_zero_cnt + [0] * (batch_size - non_zero_cnt)
         random.shuffle(propose_lens)
 
-    proposals = create_proposal(propose_lens, vocab_size, device)
     seq_group_metadatalist, _, _ = create_batch(batch_size,
                                                 max_propose_len,
                                                 block_size=block_size,
                                                 num_gpu_blocks=num_gpu_blocks)
+
+    if mixed_propose_len and prefill_chunking and (n_prefills :=
+                                                   batch_size - non_zero_cnt):
+        prefill, _, _ = create_batch(n_prefills,
+                                     None,
+                                     prefill_chunk_size=4,
+                                     block_size=block_size,
+                                     num_gpu_blocks=num_gpu_blocks,
+                                     seq_ids=list(
+                                         range(batch_size,
+                                               batch_size + n_prefills)))
+        # re-order to guarantee prefill|decode order
+        target_group_metadatalist = [
+            seq_group_metadatalist[i] for i, p in enumerate(propose_lens)
+            if p > 0
+        ]
+        seq_group_metadatalist = prefill + target_group_metadatalist
+        propose_lens = [0] * n_prefills + [p for p in propose_lens if p > 0]
+
+    proposals = create_proposal(propose_lens, vocab_size, device)
     requests = ExecuteModelRequest(seq_group_metadatalist,
                                    num_lookahead_slots=max_propose_len)
 
tests/spec_decode/test_spec_decode_worker.py

Lines changed: 82 additions & 0 deletions
@@ -10,6 +10,7 @@
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import ExecuteModelRequest, SequenceOutput
+from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
 from vllm.spec_decode.interfaces import SpeculativeProposals
 from vllm.spec_decode.metrics import (AsyncMetricsCollector,
                                       SpecDecodeWorkerMetrics)
@@ -819,3 +820,84 @@ def test_handle_finished_requests():
     # and 'request-3' are removed from seq_with_bonus_token_in_last_step.
     assert worker._seq_with_bonus_token_in_last_step == \
         {4,5,10}
+
+
+@pytest.mark.parametrize('k', [3])
+@pytest.mark.parametrize('batch_size', [2, 32])
+@pytest.mark.parametrize("batch_composition",
+                         ["prefill_only", "decode_only", "mixed"])
+@torch.inference_mode()
+def test_chunked_prefill_flow(k: int, batch_size: int, batch_composition: str):
+    """
+    Verify SpecDecodeWorker calls match the expected flow.
+    """
+    vocab_size = 32_000
+    draft_worker = mock_worker(cls=MultiStepWorker)
+    target_worker = mock_worker()
+    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
+    worker = SpecDecodeWorker(draft_worker,
+                              target_worker,
+                              mock_spec_decode_sampler("rejection_sampler"),
+                              disable_logprobs=False,
+                              metrics_collector=metrics_collector)
+    exception_secret = 'artificial stop'
+    worker.scorer = mock_worker(BatchExpansionTop1Scorer)
+    worker.scorer.score_proposals.side_effect = ValueError(exception_secret)
+
+    # Create batch with combination of terminal/non-terminal prefill chunks
+    # and decodes (different seq_ids).
+    decodes, _, _ = create_batch(batch_size, k)
+    # Pre-chunking here, get 'batch_size' chunks.
+    prefill, _, _ = create_batch(batch_size,
+                                 k,
+                                 prefill_chunk_size=4,
+                                 seq_ids=list(range(batch_size,
+                                                    batch_size * 2)))
+
+    if batch_composition == "prefill_only":
+        n_prefills = batch_size
+    elif batch_composition == "decode_only":
+        n_prefills = 0
+    else:
+        n_prefills = random.randint(1, batch_size - 1)
+    n_decodes = batch_size - n_prefills
+
+    prefill = random.sample(prefill, n_prefills)
+    decodes = random.sample(decodes, n_decodes)
+    target_group_metadata_list = prefill + decodes
+    execute_model_req = ExecuteModelRequest(
+        seq_group_metadata_list=target_group_metadata_list,
+        num_lookahead_slots=k)
+
+    target_token_ids = torch.randint(low=0,
+                                     high=vocab_size,
+                                     size=(1, batch_size * (k + 1)),
+                                     dtype=torch.int64,
+                                     device='cuda')
+    target_token_probs = torch.rand(1,
+                                    batch_size * (k + 1),
+                                    vocab_size,
+                                    dtype=torch.float32,
+                                    device='cuda')
+    target_token_logprobs = torch.rand(1,
+                                       batch_size * (k + 1),
+                                       vocab_size,
+                                       dtype=torch.float32,
+                                       device='cuda')
+    target_output = create_sampler_output_list(target_token_ids,
+                                               target_token_probs,
+                                               target_token_logprobs)
+
+    target_worker.execute_model.return_value = [target_output[0]]
+
+    if not len(decodes):
+        worker.execute_model(execute_model_req=execute_model_req)
+        # no spec run (prefill only)
+        draft_worker.execute_model.assert_called_once_with(execute_model_req)
+        target_worker.execute_model.assert_called_once_with(execute_model_req)
+    else:
+        # Decode-only run OR mixed batch, scorer call fails (it's mocked)
+        with pytest.raises(ValueError, match=exception_secret):
+            worker.execute_model(execute_model_req=execute_model_req)
+        # but first draft still counted
+        assert draft_worker.get_spec_proposals.call_count == 1