61 changes: 49 additions & 12 deletions tests/v1/e2e/test_ngram_spec_decode.py
@@ -1,21 +1,48 @@
 # SPDX-License-Identifier: Apache-2.0
+import random
+
 import pytest
 
 from vllm import LLM, SamplingParams
 
 
 @pytest.fixture
 def test_prompts():
-    return [
-        "Can you repeat the sentence ten times, this is a sentence.",
-        "Can you repeat the sentence ten times, this is a test.",
-    ]
+    prompt_types = ["repeat", "sentence"]
+    num_prompts = 100
+    prompts = []
+
+    random.seed(0)
+    random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
+
+    # Generate a mixed batch of prompts, some of which can be easily
+    # predicted by n-gram matching and some which likely cannot.
+    for kind in random_prompt_type_choices:
+        word_choices = ["test", "temp", "hello", "where"]
+        word = random.choice(word_choices)
+        if kind == "repeat":
+            prompt = f"""
+            please repeat the word '{word}' 10 times.
+            give no other output than the word at least ten times in a row,
+            in lowercase with spaces between each word and without quotes.
+            """
+        elif kind == "sentence":
+            prompt = f"""
+            please give a ten-word sentence that
+            uses the word {word} at least once.
+            give no other output than that simple sentence without quotes.
+            """
+        else:
+            raise ValueError(f"Unknown prompt type: {kind}")
+        prompts.append([{"role": "user", "content": prompt}])
+
+    return prompts
 
 
 @pytest.fixture
 def sampling_config():
     # Only support greedy for now
-    return SamplingParams(temperature=0, max_tokens=30, ignore_eos=False)
+    return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)
 
 
 @pytest.fixture
@@ -32,18 +59,28 @@ def test_ngram_correctness(monkeypatch, test_prompts, sampling_config,
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
 
-        ref_llm = LLM(model=model_name)
-        ref_outputs = ref_llm.generate(test_prompts, sampling_config)
+        ref_llm = LLM(model=model_name, max_model_len=1024)
+        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
         del ref_llm
 
         spec_llm = LLM(model=model_name,
                        speculative_model='[ngram]',
                        ngram_prompt_lookup_max=5,
                        ngram_prompt_lookup_min=3,
-                       num_speculative_tokens=3)
-        spec_outputs = spec_llm.generate(test_prompts, sampling_config)
+                       num_speculative_tokens=3,
+                       max_model_len=1024)
+        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
+        matches = 0
+        misses = 0
         for ref_output, spec_output in zip(ref_outputs, spec_outputs):
-            assert ref_output.outputs[0].text == spec_output.outputs[0].text, \
-                (f"ref_output: {ref_output.outputs[0].text},"
-                 f"spec_output: {spec_output.outputs[0].text}")
+            if ref_output.outputs[0].text == spec_output.outputs[0].text:
+                matches += 1
+            else:
+                misses += 1
+                print(f"ref_output: {ref_output.outputs[0].text}")
+                print(f"spec_output: {spec_output.outputs[0].text}")
+
+        # Heuristic: expect at least 70% of the prompts to match exactly
+        # Upon failure, inspect the outputs to check for inaccuracy.
+        assert matches > int(0.7 * len(ref_outputs))
         del spec_llm
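
For context on the `[ngram]` speculative method this test exercises: the drafter proposes tokens by matching the trailing n-gram of the sequence against earlier text, which is why the "repeat the word" prompts are easy to speculate while the free-form "sentence" prompts usually are not. Below is a minimal illustrative sketch of that idea, not vLLM's implementation; the function name and exact matching policy are assumptions, with `ngram_prompt_lookup_min`/`max` bounding the matched n-gram length and `num_speculative_tokens` capping how many tokens are proposed per step.

# Illustrative sketch of prompt-lookup (n-gram) drafting; not vLLM's code.
def propose_draft_tokens(token_ids: list[int], n: int, k: int) -> list[int]:
    """Propose up to k draft tokens by matching the trailing n-gram."""
    if len(token_ids) <= n:
        return []
    tail = token_ids[-n:]
    # Scan right-to-left for an earlier occurrence of the trailing n-gram.
    for start in range(len(token_ids) - n - 1, -1, -1):
        if token_ids[start:start + n] == tail:
            # Speculate that the tokens which followed it will repeat.
            return token_ids[start + n:start + n + k]
    return []

# A repetitive sequence gives accurate proposals, mirroring the "repeat" prompts:
assert propose_draft_tokens([7, 8, 9, 7, 8, 9], n=3, k=3) == [7, 8, 9]
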
4 changes: 1 addition & 3 deletions vllm/v1/worker/gpu_model_runner.py
@@ -1015,11 +1015,9 @@ def execute_model(
         else:
             target_probs = self.model.sampler.compute_probs(
                 logits, sampling_metadata)
-            scheduled_request_ids = scheduler_output.num_scheduled_tokens.keys(
-            )
             draft_token_ids = [
                 scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
-                for req_id in scheduled_request_ids
+                for req_id in self.input_batch.req_ids
             ]
             sampler_output = self.rejection_sampler(draft_token_ids,
                                                     target_probs,
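
The gpu_model_runner change presumably keeps the draft lists aligned, row for row, with the requests in the persistent input batch that produced `target_probs`; iterating the scheduler dict's keys does not guarantee that order. A toy illustration with made-up request ids (hypothetical data, not vLLM internals):

# Hypothetical data for illustration only; not vLLM internals.
batch_req_ids = ["req-2", "req-0", "req-1"]   # row order of the input batch
scheduled_spec_decode_tokens = {              # scheduler output, keyed by request id
    "req-0": [11, 12],
    "req-1": [],
    "req-2": [7],
}

# Aligned with the batch rows, as the updated comprehension does:
draft_token_ids = [
    scheduled_spec_decode_tokens.get(req_id, []) for req_id in batch_req_ids
]
assert draft_token_ids == [[7], [11, 12], []]

# Iterating the dict's keys instead yields [[11, 12], [], [7]], pairing each
# row of target_probs with the wrong request's draft tokens.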