
Commit aac2fb8

southfreebird and Sergei Skvortsov authored; njhill committed
[V1] Logit processors for rejection sampler (vllm-project#19482)
Signed-off-by: southfreebird <[email protected]>
Signed-off-by: Sergei Skvortsov <[email protected]>
Signed-off-by: Sergei Skvortsov <[email protected]>
Co-authored-by: Sergei Skvortsov <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
1 parent 34abdd4 commit aac2fb8
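
This commit threads the V1 sampler's built-in logits processors (penalties, bad words, allowed token ids) through the rejection sampler used for speculative decoding, and makes engine initialization reject user-supplied custom logitsprocs when speculative decoding is enabled. A minimal sketch of the newly guarded configuration (the model name and processor FQCN are placeholders; the kwargs mirror the test below):

import pytest

from vllm import LLM

# Combining a custom logits processor with speculative decoding should now
# fail fast at LLM construction time, before the processor is ever loaded.
with pytest.raises(ValueError):
    LLM(
        model="facebook/opt-125m",  # placeholder model
        speculative_config={"model": "ngram", "num_speculative_tokens": 1},
        logits_processors=["my_pkg.MyLogitsProcessor"],  # hypothetical FQCN
    )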

File tree

12 files changed: +468 −89 lines

tests/v1/logits_processors/test_custom_offline.py

Lines changed: 38 additions & 21 deletions
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import random
 import sys
-from typing import Union
+from typing import Any, Union
 
 import pytest
 
@@ -25,6 +25,7 @@
 from vllm import LLM, SamplingParams
 from vllm.v1.sample.logits_processor import (
     STR_POOLING_REJECTS_LOGITSPROCS,
+    STR_SPEC_DEC_REJECTS_LOGITSPROCS,
     LogitsProcessor,
 )
 
@@ -205,6 +206,7 @@ def test_custom_logitsprocs_req(monkeypatch):
 
 
 @create_new_process_for_each_test()
+@pytest.mark.parametrize("model_scenario", ["pooling", "spec_dec"])
 @pytest.mark.parametrize(
     "logitproc_source",
     [
@@ -213,11 +215,12 @@ def test_custom_logitsprocs_req(monkeypatch):
         CustomLogitprocSource.LOGITPROC_SOURCE_CLASS,
     ],
 )
-def test_pooling_rejects_custom_logitsprocs(
-    monkeypatch, logitproc_source: CustomLogitprocSource
+def test_rejects_custom_logitsprocs(
+    monkeypatch, model_scenario: str, logitproc_source: CustomLogitprocSource
 ):
     """Validate that vLLM engine initialization properly rejects custom
-    logitsprocs when the model is a pooling model.
+    logitsprocs when the model is a pooling model or speculative decoding
+    is enabled.
 
     Use `LLM` entrypoint. We expect `LLM` initialization to fail before the
     logitproc is actually loaded.
@@ -241,8 +244,32 @@ def test_pooling_rejects_custom_logitsprocs(
     monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
     random.seed(40)
 
+    test_params: dict[str, dict[str, Any]] = {
+        "pooling": {
+            "runner": "pooling",
+            "model": POOLING_MODEL_NAME,
+            "error_message": STR_POOLING_REJECTS_LOGITSPROCS,
+            "speculative_config": None,
+        },
+        "spec_dec": {
+            "runner": "auto",
+            "model": MODEL_NAME,
+            "error_message": STR_SPEC_DEC_REJECTS_LOGITSPROCS,
+            "speculative_config": {"model": "ngram", "num_speculative_tokens": 1},
+        },
+    }
+
+    config = test_params[model_scenario]
+
+    llm_kwargs: dict[str, Any] = {
+        "runner": config["runner"],
+        "model": config["model"],
+        "gpu_memory_utilization": 0.1,
+        "speculative_config": config["speculative_config"],
+    }
+
     if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT:
-        # Scenario: vLLM loads a pooling model and ignores a logitproc that is
+        # Scenario: vLLM loads a model and ignores a logitproc that is
         # available at a preconfigured entrypoint
 
         # Patch in dummy logitproc entrypoint
@@ -254,30 +281,20 @@ def test_pooling_rejects_custom_logitsprocs(
         # although they should ignore the entrypoint patch anyway
         monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "fork")
 
-        llm = LLM(
-            runner="pooling",
-            model=POOLING_MODEL_NAME,
-            gpu_memory_utilization=0.1,
-        )
+        llm = LLM(**llm_kwargs)
         # Require that no logitsprocs have been loaded
         worker = llm.llm_engine.model_executor.driver_worker.worker
         assert sum([1 for _ in worker.model_runner.input_batch.logitsprocs.all]) == 0
         return
 
-    kwargs: dict[str, list[Union[str, type[LogitsProcessor]]]] = {}
     if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN:
         # Scenario: load logitproc based on fully-qualified class name (FQCN)
-        kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN]
+        llm_kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN]
     elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_CLASS:
         # Scenario: load logitproc from provided class object
-        kwargs["logits_processors"] = [DummyLogitsProcessor]
+        llm_kwargs["logits_processors"] = [DummyLogitsProcessor]
 
-    with pytest.raises(ValueError, match=STR_POOLING_REJECTS_LOGITSPROCS):
-        # Require that loading a pooling model alongside the logitproc raises
+    with pytest.raises(ValueError, match=config["error_message"]):
+        # Require that loading a model alongside the logitproc raises
         # the appropriate exception.
-        LLM(
-            runner="pooling",
-            model=POOLING_MODEL_NAME,
-            gpu_memory_utilization=0.1,
-            **kwargs,
-        )
+        LLM(**llm_kwargs)
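
Both scenarios funnel into the same assertion: the "pooling" row expects STR_POOLING_REJECTS_LOGITSPROCS, while the new "spec_dec" row builds an ngram speculative config and expects STR_SPEC_DEC_REJECTS_LOGITSPROCS. A condensed sketch of the shared pattern (the expect_rejection helper is hypothetical; the test inlines it):

import pytest

from vllm import LLM

def expect_rejection(llm_kwargs: dict, error_message: str) -> None:
    # LLM construction must raise the scenario-specific error before any
    # logits processor is loaded.
    with pytest.raises(ValueError, match=error_message):
        LLM(**llm_kwargs)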

tests/v1/sample/test_rejection_sampler.py

Lines changed: 171 additions & 9 deletions
@@ -6,6 +6,7 @@
 import torch
 import torch.nn.functional as F
 
+from tests.v1.sample.utils import create_allowed_token_ids
 from vllm.platforms import current_platform
 from vllm.v1.sample.logits_processor import LogitsProcessors
 from vllm.v1.sample.metadata import SamplingMetadata
@@ -21,7 +22,9 @@ def rejection_sampler():
 
 
 def create_logits_tensor(
-    output_token_ids: list[list[int]], vocab_size: int = 100
+    output_token_ids: list[list[int]],
+    vocab_size: int = 100,
+    token_idx_to_override: Optional[int] = None,
 ) -> torch.Tensor:
     """Helper function to create logits tensor that
     will produce desired token ids on argmax"""
@@ -33,15 +36,25 @@ def create_logits_tensor(
         for j, token_id in enumerate(tokens):
             logits[start_loc + j, token_id] = 100.0
         start_loc += len(tokens)
+    if token_idx_to_override is not None:
+        logits[:, token_idx_to_override] = 99.0
     return logits
 
 
 def create_sampling_metadata(
     all_greedy: bool,
+    output_token_ids: Optional[list[list[int]]] = None,
+    prompt_token_ids: Optional[torch.Tensor] = None,
+    spec_token_ids: Optional[list[list[int]]] = None,
     temperature: Optional[torch.Tensor] = None,
     top_k: Optional[torch.Tensor] = None,
     top_p: Optional[torch.Tensor] = None,
     generators: Optional[dict[int, Any]] = None,
+    frequency_penalties: Optional[list[float]] = None,
+    presence_penalties: Optional[list[float]] = None,
+    repetition_penalties: Optional[list[float]] = None,
+    bad_words_token_ids: Optional[dict[int, list[list[int]]]] = None,
+    allowed_token_ids_mask: Optional[torch.Tensor] = None,
 ) -> SamplingMetadata:
     """Create a v1 sampling metadata object with all_greedy set
     to the given value. Either all greedy or all random sampling
@@ -53,6 +66,21 @@ def create_sampling_metadata(
     else:
         assert temperature is not None
 
+    if any([frequency_penalties, presence_penalties, repetition_penalties]):
+        no_penalties = False
+
+        assert output_token_ids
+        assert len(output_token_ids) > 0
+
+        frequency_penalties = torch.tensor(frequency_penalties, device=DEVICE)
+        presence_penalties = torch.tensor(presence_penalties, device=DEVICE)
+        repetition_penalties = torch.tensor(repetition_penalties, device=DEVICE)
+    else:
+        no_penalties = True
+        frequency_penalties = torch.tensor([])
+        presence_penalties = torch.tensor([])
+        repetition_penalties = torch.tensor([])
+
     return SamplingMetadata(
         temperature=temperature,
         all_greedy=all_greedy,
@@ -61,14 +89,15 @@ def create_sampling_metadata(
         top_k=top_k,
         generators=generators,
         max_num_logprobs=0,
-        no_penalties=False,
-        prompt_token_ids=None,
-        frequency_penalties=torch.tensor([]),
-        presence_penalties=torch.tensor([]),
-        repetition_penalties=torch.tensor([]),
-        output_token_ids=[],
-        allowed_token_ids_mask=None,
-        bad_words_token_ids={},
+        no_penalties=no_penalties,
+        prompt_token_ids=prompt_token_ids,
+        frequency_penalties=frequency_penalties,
+        presence_penalties=presence_penalties,
+        repetition_penalties=repetition_penalties,
+        output_token_ids=[] if output_token_ids is None else output_token_ids,
+        spec_token_ids=[] if spec_token_ids is None else spec_token_ids,
+        allowed_token_ids_mask=allowed_token_ids_mask,
+        bad_words_token_ids={} if bad_words_token_ids is None else bad_words_token_ids,
         logitsprocs=LogitsProcessors(),
     )
 
@@ -611,3 +640,136 @@ def test_top_p(rejection_sampler, top_p):
         unmasked_indices=top_p_indices,
         sampling_metadata=sampling_metadata,
     )
+
+
+########################### Tests for Logit Processors ###################
+def test_frequency_penalties(rejection_sampler):
+    """Test rejection sampling with frequency penalties"""
+    spec_tokens = [[1, 1, 1], [], [1, 1, 1]]
+    output_tokens = [[1, 1, 1, 1], [7], [1, 1, 1, 1]]  # last tokens are the bonus tokens
+
+    num_requests = len(spec_tokens)
+    logits = create_logits_tensor(output_tokens, token_idx_to_override=15)
+    metadata = create_sampling_metadata(
+        all_greedy=True,
+        output_token_ids=[[2], [3], [4]],
+        spec_token_ids=spec_tokens,
+        prompt_token_ids=torch.tensor([[5, 6, 7], [6, 7, 8], [7, 8, 9]], device=DEVICE),
+        frequency_penalties=[1.5, 1.5, 0.7],
+        presence_penalties=[0.0] * num_requests,
+        repetition_penalties=[1.0] * num_requests,
+    )
+    bonus_token_tensor = torch.tensor(
+        [output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device
+    )
+    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
+        spec_tokens, device=logits.device
+    )
+    output = rejection_sampler(
+        spec_decode_metadata,
+        draft_probs=None,
+        target_logits=logits,
+        bonus_token_ids=bonus_token_tensor,
+        sampling_metadata=metadata,
+    )
+    expected = torch.tensor(
+        [[1, 15, -1, -1], [7, -1, -1, -1], [1, 1, 15, -1]],
+        dtype=torch.int,
+        device=logits.device,
+    )
+    assert torch.equal(output, expected)
+
+
+def test_bad_words(rejection_sampler):
+    """Test rejection sampling with bad words constraints"""
+    spec_tokens = [[1, 2, 3], [1, 15, 3], [1, 2, 3]]
+    output_tokens = [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]]
+
+    logits = create_logits_tensor(output_tokens, token_idx_to_override=15)
+    metadata = create_sampling_metadata(
+        all_greedy=True,
+        output_token_ids=[[2], [3], [4]],
+        spec_token_ids=spec_tokens,
+        bad_words_token_ids={
+            0: [
+                [
+                    2,
+                ]
+            ],
+            1: [
+                [
+                    2,
+                ]
+            ],
+            # Do not apply bad words to the last request
+        },
+    )
+    bonus_token_tensor = torch.tensor(
+        [output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device
+    )
+    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
+        spec_tokens, device=logits.device
+    )
+    output = rejection_sampler(
+        spec_decode_metadata,
+        draft_probs=None,
+        target_logits=logits,
+        bonus_token_ids=bonus_token_tensor,
+        sampling_metadata=metadata,
+    )
+
+    expected = torch.tensor(
+        [[1, 15, -1, -1], [1, 15, 3, 4], [1, 2, 3, 4]],
+        dtype=torch.int,
+        device=logits.device,
+    )
+    assert torch.equal(output, expected)
+
+
+def test_allowed_token_ids(rejection_sampler):
+    """Test rejection sampling with allowed token ids"""
+    spec_tokens = [[1, 2, 10], [10, 5, 3], [7, 10, 12]]
+    output_tokens = [[1, 2, 10, 5], [10, 5, 10, 5], [7, 10, 12, 5]]
+    # Disallowed tokens per request (the helper leaves odd-indexed
+    # request 1 unmasked):
+    # request 0: tokens 0-4
+    # request 2: tokens 2-6
+    num_allowed_token_ids = 5
+
+    # Token 15 is what the sampler picks when a draft token is rejected
+    logits = create_logits_tensor(output_tokens, token_idx_to_override=15)
+
+    batch_size = len(output_tokens)
+    _, vocab_size = logits.size()
+    mask = create_allowed_token_ids(
+        batch_size=batch_size,
+        vocab_size=vocab_size,
+        num_allowed_token_ids=num_allowed_token_ids,
+        device=logits.device,
+    )
+    metadata = create_sampling_metadata(
+        all_greedy=True,
+        output_token_ids=[[], [], []],
+        spec_token_ids=spec_tokens,
+        allowed_token_ids_mask=mask,
+    )
+    bonus_token_tensor = torch.tensor(
+        [output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device
+    )
+    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
+        spec_tokens, device=logits.device
+    )
+    output = rejection_sampler(
+        spec_decode_metadata,
+        draft_probs=None,
+        target_logits=logits,
+        bonus_token_ids=bonus_token_tensor,
+        sampling_metadata=metadata,
+    )
+
+    expected = torch.tensor(
+        [[15, -1, -1, -1], [10, 5, 10, -1], [7, 10, 12, 5]],
+        dtype=torch.int,
+        device=logits.device,
+    )
+    assert torch.equal(output, expected)
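
The expected tensors follow from the helper's logit layout: every target token gets logit 100.0 and the override token 15 gets 99.0, so under greedy verification a draft token survives only while its processed logit stays above 99.0. A hand-check of test_frequency_penalties, assuming OpenAI-style frequency penalties (logit reduced by penalty times the token's occurrence count), which is consistent with the expected outputs above:

def penalized_logit(base: float, count: int, freq_penalty: float) -> float:
    # Frequency penalty subtracts penalty * occurrence_count from the logit
    return base - count * freq_penalty

# Request 0 (penalty 1.5): after one accepted "1", 100 - 1.5 = 98.5 < 99.0,
# so token 15 wins at the second position -> [1, 15, -1, -1].
assert penalized_logit(100.0, 1, 1.5) < 99.0
# Request 2 (penalty 0.7): one occurrence still wins (99.3 > 99.0), two
# occurrences flip it (98.6 < 99.0) -> [1, 1, 15, -1].
assert penalized_logit(100.0, 1, 0.7) > 99.0
assert penalized_logit(100.0, 2, 0.7) < 99.0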
