
Commit d2880ea

Rework, including spec decoding cases
1 parent 3af88df commit d2880ea

15 files changed: +104 −88 lines changed

tests/spec_decode/e2e/conftest.py

Lines changed: 1 addition & 0 deletions

@@ -157,6 +157,7 @@ def test_llm_generator(request, common_llm_kwargs, per_test_common_llm_kwargs,
 def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
                          per_test_common_llm_kwargs, distinct_llm_kwargs,
                          seed):
+    print("CREATE LLM GENERATOR")
     kwargs = {
         **common_llm_kwargs,
         **per_test_common_llm_kwargs,

tests/spec_decode/e2e/test_mlp_correctness.py

Lines changed: 22 additions & 13 deletions

@@ -21,7 +21,8 @@
 
 import pytest
 
-from .conftest import run_greedy_equality_correctness_test, run_equality_correctness_test
+from .conftest import (run_equality_correctness_test,
+                       run_greedy_equality_correctness_test)
 
 # main model
 MAIN_MODEL = "JackFram/llama-160m"
@@ -94,31 +95,39 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
 
         # Main model
         "model": MAIN_MODEL,
+
+        # Speculative model
+        "speculative_model": SPEC_MODEL,
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
-@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs", [
-    {
-        "speculative_model": SPEC_MODEL,
-    },
-])
-@pytest.mark.parametrize("output_len", [
-    128,
-])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
+@pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}])
+@pytest.mark.parametrize("output_len", [64])
 @pytest.mark.parametrize("batch_size", [1, 32])
 @pytest.mark.parametrize("temperature", [0.1, 1.0])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("seed", [None])
 def test_mlp_e2e_seeded_correctness(baseline_llm_generator, test_llm_generator,
-                                    batch_size: int, output_len: int, temperature: float):
+                                    batch_size: int, output_len: int,
+                                    temperature: float):
     """Verify seeded runs produce the same output."""
     run_equality_correctness_test(baseline_llm_generator,
-                                  baseline_llm_generator,
+                                  test_llm_generator,
                                   batch_size,
                                   max_output_len=output_len,
                                   temperature=temperature,
                                   seeded=True,
                                   force_output_len=True)
 
+    # Ensure this same test does fail if we _don't_ include per-request seeds
+    with pytest.raises(AssertionError):
+        run_equality_correctness_test(baseline_llm_generator,
+                                      test_llm_generator,
+                                      batch_size,
+                                      max_output_len=output_len,
+                                      temperature=temperature,
+                                      seeded=False,
+                                      force_output_len=True)
+
 
 @pytest.mark.parametrize(
     "common_llm_kwargs",

tests/spec_decode/e2e/test_seed.py

Lines changed: 1 addition & 1 deletion

@@ -29,7 +29,7 @@
     "output_len",
     [
         # Use smaller output len for fast test.
-        10,
+        20,
     ])
 @pytest.mark.parametrize("seed", [None])
 def test_seeded_consistency(baseline_llm_generator, test_llm_generator,

vllm/model_executor/layers/rejection_sampler.py

Lines changed: 39 additions & 56 deletions

@@ -1,5 +1,5 @@
 from functools import cached_property
-from typing import List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import torch
 import torch.jit
@@ -36,7 +36,7 @@ def forward(
         bonus_token_ids: torch.Tensor,
         draft_probs: torch.Tensor,
         draft_token_ids: torch.Tensor,
-        generators: List[Optional[torch.Generator]],
+        seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
     ) -> torch.Tensor:
         """Sample token ids using rejection sampling. This accepts or rejects
         tokens proposed by the draft model using the probability of each token
@@ -66,6 +66,9 @@ def forward(
                 probabilities.
                 shape = [batch_size, num_speculative_tokens]
 
+            seeded_seqs: Dict of batch row index to torch generator, for
+                sequences using seeded generation.
+
         Returns:
             output_token_ids: The token ids sampled via rejection sampling,
                 or -1 if unable to sample a token because the previous token
@@ -83,7 +86,7 @@ def forward(
                 target_probs,
                 draft_probs,
                 draft_token_ids,
-                generators,
+                seeded_seqs,
             ))
 
         output_token_ids = self._create_output(
@@ -100,7 +103,7 @@ def _batch_modified_rejection_sampling(
         target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
         draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
         draft_token_ids: torch.Tensor,  # [batch_size, k]
-        generators: List[Optional[torch.Generator]],
+        seeded_seqs: Optional[Dict[int, torch.Generator]],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Perform modified rejection sampling on each sequence.
 
@@ -117,23 +120,17 @@ def _batch_modified_rejection_sampling(
 
         # shape [batch_size, k]
         accepted = self._get_accepted(target_probs, draft_probs,
-                                      draft_token_ids, generators)
+                                      draft_token_ids, seeded_seqs)
 
         recovered_probs = self._get_recovered_probs(
             target_probs, draft_probs).reshape(batch_size * k, vocab_size)
 
-        seed_indices, non_seed_indices = self._split_batch_by_seeded(
-            generators, k=k)
-
         # NOTE: the recovered_probs are overwritten by this method.
         recovered_token_ids = _multinomial(
             recovered_probs,
             num_samples=1,
             k=k,
-            generators=generators,
-            seed_indices=seed_indices,
-            # this arg is unused when None but torch.jit requires a list
-            non_seed_indices=non_seed_indices or [],
+            seeded_seqs=seeded_seqs or {},
         ).reshape(batch_size, k)
 
         return accepted, recovered_token_ids
@@ -143,7 +140,7 @@ def _get_accepted(
         target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
         draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
         draft_token_ids: torch.Tensor,  # [batch_size, k]
-        generators: List[Optional[torch.Generator]],
+        seeded_seqs: Optional[Dict[int, torch.Generator]],
     ) -> torch.Tensor:
         r"""Create bool matrix over the proposed draft tokens. If
         True, then a token can be accepted, else it should be
@@ -178,24 +175,26 @@ def _get_accepted(
         selected_target_probs = target_probs[batch_indices, probs_indicies,
                                              draft_token_ids]
 
-        seed_indices, non_seed_indices = self._split_batch_by_seeded(
-            generators)
-
-        if len(seed_indices) == 0:
+        if not seeded_seqs:
             uniform_rand = torch.rand_like(selected_target_probs)
         else:
             uniform_rand = torch.empty_like(selected_target_probs)
 
-            for idx in seed_indices:
-                uniform_rand[idx, :] = torch.rand(1,
-                                                  k,
-                                                  dtype=self.probs_dtype,
-                                                  device=target_probs.device,
-                                                  generator=generators[idx])
-
-            if non_seed_indices:
-                uniform_rand[non_seed_indices, :] = torch.rand(
-                    len(non_seed_indices),
+            non_seeded_indices = []
+            for idx in range(batch_size):
+                generator = seeded_seqs.get(idx)
+                if generator is None:
+                    non_seeded_indices.append(idx)
+                else:
+                    uniform_rand[idx, :] = torch.rand(
+                        1,
+                        k,
+                        dtype=self.probs_dtype,
+                        device=target_probs.device,
+                        generator=generator)
+            if non_seeded_indices:
+                uniform_rand[non_seeded_indices, :] = torch.rand(
+                    len(non_seeded_indices),
                     k,
                     dtype=self.probs_dtype,
                     device=target_probs.device)
@@ -272,27 +271,6 @@ def _smallest_positive_value(self) -> float:
         """
         return torch.finfo(self.probs_dtype).tiny
 
-    # partition batch into indices for which a generator is provided
-    # and indicies for which no generator is provided
-    @staticmethod
-    def _split_batch_by_seeded(
-        generators: List[Optional[torch.Generator]],
-        k: int = 1,
-    ) -> Tuple[List[int], Optional[List[int]]]:
-
-        if all(generator is None for generator in generators):
-            seed_indices: List[int] = []
-            non_seed_indices: Optional[List[int]] = None
-        else:
-            seed_indices, non_seed_indices = [], []
-            for i, generator in enumerate(generators):
-                if generator is None:
-                    non_seed_indices.extend(range(k * i, k * (i + 1)))
-                else:
-                    seed_indices.extend(range(k * i, k * (i + 1)))
-
-        return seed_indices, non_seed_indices
-
 
 # torch.multinomial forces a GPU<->CPU sync.
 # Therefore, we use an optimized implementation instead that skips the sync.
@@ -304,9 +282,7 @@ def _multinomial(
     probs: torch.Tensor,
     num_samples: int,
    k: int,
-    generators: List[Optional[torch.Generator]],
-    seed_indices: List[int],
-    non_seed_indices: List[int],
+    seeded_seqs: Dict[int, torch.Generator],
 ) -> torch.Tensor:
 
     if num_samples > 1:
@@ -315,13 +291,20 @@ def _multinomial(
         probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                          probs.shape[1]).contiguous().view(
                                              -1, probs.shape[1])
-
     q = torch.empty_like(probs)
-    if len(seed_indices) == 0:
+    if not seeded_seqs:
         q.exponential_(1.0)
     else:
-        q[non_seed_indices].exponential_(1.0)
-        for idx in seed_indices:
-            q[idx].exponential_(1.0, generator=generators[idx // k])
+        non_seeded_indices: List[int] = []
+        start = 0
+        for idx in range(len(q) // k):
+            end = start + k
+            generator = seeded_seqs.get(idx)
+            if generator is None:
+                non_seeded_indices.extend(list(range(start, end)))
+            else:
+                q[start:end].exponential_(1.0, generator=generator)
+            start = end
+        q[non_seeded_indices].exponential_(1.0)
 
     return probs.div_(q).argmax(dim=1).view(-1, num_samples)
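As a standalone illustration of the per-row seeding pattern used in _get_accepted and _multinomial above (a sketch assuming only that seeded_seqs maps a batch row index to a torch.Generator; this is not the vLLM class itself):

from typing import Dict, List

import torch


def seeded_uniform(batch_size: int, k: int,
                   seeded_seqs: Dict[int, torch.Generator]) -> torch.Tensor:
    """Draw uniform randoms row by row; seeded rows use their own generator."""
    out = torch.empty(batch_size, k)
    non_seeded: List[int] = []
    for idx in range(batch_size):
        generator = seeded_seqs.get(idx)
        if generator is None:
            non_seeded.append(idx)
        else:
            out[idx, :] = torch.rand(1, k, generator=generator)
    if non_seeded:
        # Unseeded rows are filled in one batched call from the global RNG.
        out[non_seeded, :] = torch.rand(len(non_seeded), k)
    return out


# A seeded row reproduces exactly across calls; unseeded rows need not.
a = seeded_uniform(4, 3, {1: torch.Generator().manual_seed(42)})
b = seeded_uniform(4, 3, {1: torch.Generator().manual_seed(42)})
assert torch.equal(a[1], b[1])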

vllm/model_executor/layers/spec_decode_base_sampler.py

Lines changed: 2 additions & 2 deletions

@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import List, Optional
+from typing import Dict, Optional
 
 import torch
 import torch.jit
@@ -237,6 +237,6 @@ def forward(
         bonus_token_ids: torch.Tensor,
         draft_probs: torch.Tensor,
         draft_token_ids: torch.Tensor,
-        generators: List[Optional[torch.Generator]],
+        seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
     ) -> torch.Tensor:
         raise NotImplementedError

vllm/spec_decode/batch_expansion.py

Lines changed: 17 additions & 2 deletions

@@ -3,6 +3,7 @@
 
 import torch
 
+from vllm import SamplingParams
 from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
                            SequenceGroupMetadata, get_all_seq_ids)
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
@@ -15,6 +16,8 @@
 TargetSeqId = int
 TokenId = int
 
+DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams()
+
 
 class BatchExpansionTop1Scorer(SpeculativeScorer):
     """Implements a speculative scorer that uses batch expansion to get
@@ -246,14 +249,25 @@ def _create_target_seq_group_metadata(
         token_ids_to_score = self._get_token_ids_to_score(
             proposal_token_ids[batch_index])
 
+        # Use simpler sampling parameters apart from for final token
+        # (in particular don't do seeded sampling) since those sampled tokens
+        # aren't used
+        sampling_params = input_seq_group_metadata.sampling_params
+        non_bonus_sampling_params = DEFAULT_SIMPLE_SAMPLING_PARAMS \
+            if sampling_params.temperature else sampling_params
+
         target_seq_group_metadata_list: List[SequenceGroupMetadata] = []
-        for token_ids in token_ids_to_score:
+        last_index = len(token_ids_to_score) - 1
+        for i, token_ids in enumerate(token_ids_to_score):
+            target_sampling_params = sampling_params if i == last_index \
+                else non_bonus_sampling_params
             target_seq_group_metadata_list.append(
                 self._create_single_target_seq_group_metadata(
                     input_seq_group_metadata,
                     input_seq_id,
                     next(target_seq_ids_iter),
                     token_ids,
+                    sampling_params=target_sampling_params,
                 ))
 
         return target_seq_group_metadata_list
@@ -264,6 +278,7 @@ def _create_single_target_seq_group_metadata(
         seq_id: SeqId,
         target_seq_id: TargetSeqId,
         token_ids: List[TokenId],
+        sampling_params: SamplingParams,
     ) -> SequenceGroupMetadata:
         """Create a single target SequenceGroupMetadata.
 
@@ -296,7 +311,7 @@ def _create_single_target_seq_group_metadata(
             request_id=seq_group_metadata.request_id,
             is_prompt=seq_group_metadata.is_prompt,
             seq_data=new_seq_data_dict,
-            sampling_params=seq_group_metadata.sampling_params,
+            sampling_params=sampling_params,
             block_tables={
                 target_seq_id: seq_group_metadata.block_tables[seq_id],
             },
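The selection logic added above can be summarised in isolation like this (a sketch; params_for_position is a hypothetical helper, only SamplingParams is the real vLLM class): every expanded position except the last falls back to plain default parameters when the request samples with temperature, because the tokens sampled at those positions are discarded anyway.

from vllm import SamplingParams

DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams()


def params_for_position(i: int, last_index: int,
                        request_params: SamplingParams) -> SamplingParams:
    # Only the final (bonus-token) position keeps the request's full sampling
    # parameters, e.g. its per-request seed; greedy requests (temperature == 0)
    # are unaffected and keep their own params everywhere.
    if i == last_index or not request_params.temperature:
        return request_params
    return DEFAULT_SIMPLE_SAMPLING_PARAMS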

vllm/spec_decode/medusa_worker.py

Lines changed: 3 additions & 1 deletion

@@ -57,9 +57,11 @@ def sampler_output(
         seq_lens, query_lens = self._prepare_input_tensors(
             seq_group_metadata_list)
 
+        generators = self.model_runner.get_generators(
+            execute_model_req.finished_requests_ids)
         sampling_metadata = SamplingMetadata.prepare(
             seq_group_metadata_list, seq_lens, query_lens, self.device,
-            self.model_runner.pin_memory)
+            self.model_runner.pin_memory, generators)
 
         model_outputs = self.model_runner.model.generate_proposals(
             previous_hidden_states=execute_model_req.previous_hidden_states.

vllm/spec_decode/mlp_speculator_worker.py

Lines changed: 3 additions & 1 deletion

@@ -38,9 +38,11 @@ def sampler_output(
         (input_tokens, seq_lens,
          query_lens) = self._prepare_input_tensors(seq_group_metadata_list)
 
+        generators = self.model_runner.get_generators(
+            execute_model_req.finished_requests_ids)
         sampling_metadata = SamplingMetadata.prepare(
             seq_group_metadata_list, seq_lens, query_lens, self.device,
-            self.model_runner.pin_memory)
+            self.model_runner.pin_memory, generators)
 
         model_outputs = self.model_runner.model.generate_proposals(
             input_ids=input_tokens,
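For background, the generators these workers now thread into SamplingMetadata.prepare are per-request torch.Generator objects. A hedged sketch of that kind of bookkeeping (GeneratorRegistry is a hypothetical helper, not vLLM's actual get_generators implementation): one long-lived generator per seeded request, dropped once the request finishes.

from typing import Dict, List, Optional

import torch


class GeneratorRegistry:
    """Illustrative per-request generator cache (hypothetical)."""

    def __init__(self) -> None:
        self._generators: Dict[str, torch.Generator] = {}

    def get_or_create(self, request_id: str,
                      seed: Optional[int]) -> Optional[torch.Generator]:
        # Unseeded requests sample from the global RNG and get no generator.
        if seed is None:
            return None
        if request_id not in self._generators:
            self._generators[request_id] = torch.Generator().manual_seed(seed)
        return self._generators[request_id]

    def get_generators(
        self,
        finished_request_ids: Optional[List[str]] = None,
    ) -> Dict[str, torch.Generator]:
        # Drop state for finished requests, then hand back the live ones.
        for rid in finished_request_ids or []:
            self._generators.pop(rid, None)
        return self._generators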

vllm/spec_decode/ngram_worker.py

Lines changed: 1 addition & 2 deletions

@@ -7,10 +7,9 @@
 from vllm.spec_decode.interfaces import SpeculativeProposals
 from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
 from vllm.spec_decode.top1_proposer import Top1Proposer
-from vllm.worker.worker_base import LoraNotSupportedWorkerBase
 
 
-class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
+class NGramWorker(NonLLMProposerWorkerBase):
     """NGramWorker provides a light drafter without need for model.
 
     Current NGramWorker only implements prompt lookup decoding,
