Skip to content

Commit 5c538c3

Browse files
authored
[V1][Bugfix][Spec Decode] Fix incorrect outputs in V1 speculative decoding due to batch indexing (#14645)
Signed-off-by: Benjamin Chislett <[email protected]>
1 parent e22ee1e commit 5c538c3

File tree

2 files changed

+50
-15
lines changed

2 files changed

+50
-15
lines changed
Lines changed: 49 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,48 @@
11
# SPDX-License-Identifier: Apache-2.0
2+
import random
3+
24
import pytest
35

46
from vllm import LLM, SamplingParams
57

68

79
@pytest.fixture
810
def test_prompts():
9-
return [
10-
"Can you repeat the sentence ten times, this is a sentence.",
11-
"Can you repeat the sentence ten times, this is a test.",
12-
]
11+
prompt_types = ["repeat", "sentence"]
12+
num_prompts = 100
13+
prompts = []
14+
15+
random.seed(0)
16+
random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
17+
18+
# Generate a mixed batch of prompts, some of which can be easily
19+
# predicted by n-gram matching and some which likely cannot.
20+
for kind in random_prompt_type_choices:
21+
word_choices = ["test", "temp", "hello", "where"]
22+
word = random.choice(word_choices)
23+
if kind == "repeat":
24+
prompt = f"""
25+
please repeat the word '{word}' 10 times.
26+
give no other output than the word at least ten times in a row,
27+
in lowercase with spaces between each word and without quotes.
28+
"""
29+
elif kind == "sentence":
30+
prompt = f"""
31+
please give a ten-word sentence that
32+
uses the word {word} at least once.
33+
give no other output than that simple sentence without quotes.
34+
"""
35+
else:
36+
raise ValueError(f"Unknown prompt type: {kind}")
37+
prompts.append([{"role": "user", "content": prompt}])
38+
39+
return prompts
1340

1441

1542
@pytest.fixture
1643
def sampling_config():
1744
# Only support greedy for now
18-
return SamplingParams(temperature=0, max_tokens=30, ignore_eos=False)
45+
return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)
1946

2047

2148
@pytest.fixture
@@ -32,18 +59,28 @@ def test_ngram_correctness(monkeypatch, test_prompts, sampling_config,
3259
with monkeypatch.context() as m:
3360
m.setenv("VLLM_USE_V1", "1")
3461

35-
ref_llm = LLM(model=model_name)
36-
ref_outputs = ref_llm.generate(test_prompts, sampling_config)
62+
ref_llm = LLM(model=model_name, max_model_len=1024)
63+
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
3764
del ref_llm
3865

3966
spec_llm = LLM(model=model_name,
4067
speculative_model='[ngram]',
4168
ngram_prompt_lookup_max=5,
4269
ngram_prompt_lookup_min=3,
43-
num_speculative_tokens=3)
44-
spec_outputs = spec_llm.generate(test_prompts, sampling_config)
70+
num_speculative_tokens=3,
71+
max_model_len=1024)
72+
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
73+
matches = 0
74+
misses = 0
4575
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
46-
assert ref_output.outputs[0].text == spec_output.outputs[0].text, \
47-
(f"ref_output: {ref_output.outputs[0].text},"
48-
f"spec_output: {spec_output.outputs[0].text}")
76+
if ref_output.outputs[0].text == spec_output.outputs[0].text:
77+
matches += 1
78+
else:
79+
misses += 1
80+
print(f"ref_output: {ref_output.outputs[0].text}")
81+
print(f"spec_output: {spec_output.outputs[0].text}")
82+
83+
# Heuristic: expect at least 70% of the prompts to match exactly
84+
# Upon failure, inspect the outputs to check for inaccuracy.
85+
assert matches > int(0.7 * len(ref_outputs))
4986
del spec_llm

vllm/v1/worker/gpu_model_runner.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,11 +1015,9 @@ def execute_model(
10151015
else:
10161016
target_probs = self.model.sampler.compute_probs(
10171017
logits, sampling_metadata)
1018-
scheduled_request_ids = scheduler_output.num_scheduled_tokens.keys(
1019-
)
10201018
draft_token_ids = [
10211019
scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
1022-
for req_id in scheduled_request_ids
1020+
for req_id in self.input_batch.req_ids
10231021
]
10241022
sampler_output = self.rejection_sampler(draft_token_ids,
10251023
target_probs,

0 commit comments

Comments
 (0)