138 changes: 137 additions & 1 deletion tests/samplers/test_sampler.py
@@ -1,5 +1,6 @@
import itertools
import random
from array import array
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from unittest.mock import Mock, patch
@@ -12,7 +13,8 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
SequenceData, SequenceGroupMetadata)
from vllm.utils import Counter, is_pin_memory_available


@@ -754,3 +756,137 @@ def test_sampler_include_gpu_probs_tensor(device: str):
assert sampler_output.sampled_token_probs is not None
assert sampler_output.logprobs is not None
assert sampler_output.sampled_token_ids is not None


@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_dry(device: str):
vocab_size = 8

def test_sampling_params(sampling_params: List[SamplingParams]):
seq_group_metadata_list: List[SequenceGroupMetadata] = []
seq_lens: List[int] = []
for i in range(2):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={
0:
SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3, 1, 2]))
},
sampling_params=sampling_params[i],
block_tables={0: [1]},
))
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
query_lens=seq_lens,
device=device,
pin_memory=is_pin_memory_available())

fake_logits = torch.full((2, vocab_size),
1e-2,
device=device,
dtype=torch.float16)
fake_logits[:, 3] = 1.0

sampler = MockLogitsSampler(fake_logits)
sampler_output = sampler(logits=fake_logits,
sampling_metadata=sampling_metadata)

generated_tokens = []
for output in sampler_output:
generated_tokens.append(output.samples[0].output_token)

return generated_tokens

# Test case 1: DRY disabled (multiplier = 0)
sampling_params_no_dry = SamplingParams(
temperature=0.0,
dry_multiplier=0.0,
)

# Test case 2: DRY enabled with full range
sampling_params_full_dry = SamplingParams(
temperature=0.0,
dry_multiplier=1.0,
dry_allowed_length=2,
dry_base=2.0,
dry_range=0,
)

# Test case 3: DRY enabled with limited range
sampling_params_limited_dry = SamplingParams(
temperature=0.0,
dry_multiplier=1.0,
dry_allowed_length=2,
dry_base=2.0,
dry_range=3,
)

tokens1 = test_sampling_params(
[sampling_params_no_dry, sampling_params_full_dry])

assert tokens1[0] == 3, "Without DRY, should choose highest logit token"
assert tokens1[1] != 3, \
    "With full-range DRY, should avoid repeating pattern"

tokens2 = test_sampling_params(
[sampling_params_full_dry, sampling_params_limited_dry])

assert tokens2[0] != 3, "Full-range DRY should detect full pattern"
assert tokens2[1] == 3, \
    "Limited-range DRY should only consider recent tokens"

tokens3 = test_sampling_params(
[sampling_params_full_dry, sampling_params_limited_dry])
assert tokens2 == tokens3, "DRY sampling should be deterministic"


@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_dry_sequence_breakers(device: str):
"""Test that DRY respects sequence breakers."""
vocab_size = 8

# 7 is a sequence breaker
input_sequence = [1, 2, 7, 1, 2]

seq_group_metadata = SequenceGroupMetadata(
request_id="test_0",
is_prompt=True,
seq_data={
0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, input_sequence))
},
sampling_params=SamplingParams(
temperature=0.0,
dry_multiplier=1.0,
dry_allowed_length=2,
dry_base=2.0,
dry_range=0,
dry_sequence_breaker_ids=[7],
),
block_tables={0: [1]},
)

sampling_metadata = SamplingMetadata.prepare(
[seq_group_metadata],
seq_lens=[len(input_sequence)],
query_lens=[len(input_sequence)],
device=device,
pin_memory=is_pin_memory_available())

fake_logits = torch.full((1, vocab_size),
1e-2,
device=device,
dtype=torch.float16)
fake_logits[0, 3] = 1.0

sampler = MockLogitsSampler(fake_logits)
sampler_output = sampler(logits=fake_logits,
sampling_metadata=sampling_metadata)

assert sampler_output[0].samples[0].output_token == 3, \
"DRY should not detect patterns across sequence breakers"
52 changes: 51 additions & 1 deletion vllm/entrypoints/openai/protocol.py
@@ -6,7 +6,8 @@
from typing import Any, Dict, List, Literal, Optional, Union

import torch
from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic import (AliasChoices, BaseModel, ConfigDict, Field,
model_validator)
from typing_extensions import Annotated

from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
@@ -15,6 +16,7 @@
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
RequestOutputKind, SamplingParams)
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid, resolve_obj_by_qualname

logger = init_logger(__name__)
@@ -235,6 +237,14 @@ class ChatCompletionRequest(OpenAIBaseModel):
top_k: Optional[int] = None
min_p: Optional[float] = None
repetition_penalty: Optional[float] = None
dry_multiplier: Optional[float] = None
dry_base: Optional[float] = None
dry_allowed_length: Optional[int] = None
dry_sequence_breakers: Optional[Union[List[str], List[int]]] = Field(
default=["\n", ":", "\"", "*"])
dry_range: Optional[int] = Field(default=0,
validation_alias=AliasChoices(
"dry_range", "dry_penalty_last_n"))
length_penalty: float = 1.0
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
include_stop_str_in_output: bool = False
@@ -393,6 +403,7 @@ def to_beam_search_params(

def to_sampling_params(
self,
tokenizer: AnyTokenizer,
default_max_tokens: int,
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None) -> SamplingParams:
@@ -437,6 +448,16 @@ def to_sampling_params(
if self.guided_decoding_backend is None:
self.guided_decoding_backend = "xgrammar"

dry_sequence_breaker_ids = []
if isinstance(self.dry_sequence_breakers,
list) and self.dry_sequence_breakers:
if isinstance(self.dry_sequence_breakers[0], str):
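# Prefix each breaker with 'a' so it tokenizes as it would
# mid-text, then keep only the last token id.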
for s in self.dry_sequence_breakers:
token_id = tokenizer.encode(f'a{s}')[-1]
dry_sequence_breaker_ids.append(token_id)
elif isinstance(self.dry_sequence_breakers[0], int):
dry_sequence_breaker_ids = list(self.dry_sequence_breakers)

guided_decoding = GuidedDecodingParams.from_optional(
json=self._get_guided_json_from_tool() or self.guided_json,
regex=self.guided_regex,
@@ -452,6 +473,11 @@
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=repetition_penalty,
dry_multiplier=self.dry_multiplier,
dry_base=self.dry_base,
dry_allowed_length=self.dry_allowed_length,
dry_sequence_breaker_ids=dry_sequence_breaker_ids,
dry_range=self.dry_range,
temperature=temperature,
top_p=top_p,
top_k=top_k,
@@ -641,6 +667,14 @@ class CompletionRequest(OpenAIBaseModel):
top_k: Optional[int] = None
min_p: Optional[float] = None
repetition_penalty: Optional[float] = None
dry_multiplier: Optional[float] = None
dry_base: Optional[float] = None
dry_allowed_length: Optional[int] = None
dry_sequence_breakers: Optional[Union[List[str], List[int]]] = Field(
default=["\n", ":", "\"", "*"])
dry_range: Optional[int] = Field(default=0,
validation_alias=AliasChoices(
"dry_range", "dry_penalty_last_n"))
length_penalty: float = 1.0
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
include_stop_str_in_output: bool = False
@@ -752,6 +786,7 @@ def to_beam_search_params(

def to_sampling_params(
self,
tokenizer: AnyTokenizer,
default_max_tokens: int,
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None) -> SamplingParams:
@@ -784,6 +819,16 @@ def to_sampling_params(
if prompt_logprobs is None and self.echo:
prompt_logprobs = self.logprobs

dry_sequence_breaker_ids = []
if isinstance(self.dry_sequence_breakers,
list) and self.dry_sequence_breakers:
if isinstance(self.dry_sequence_breakers[0], str):
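# Prefix each breaker with 'a' so it tokenizes as it would
# mid-text, then keep only the last token id.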
for s in self.dry_sequence_breakers:
token_id = tokenizer.encode(f'a{s}')[-1]
dry_sequence_breaker_ids.append(token_id)
elif isinstance(self.dry_sequence_breakers[0], int):
dry_sequence_breaker_ids = list(self.dry_sequence_breakers)

echo_without_generation = self.echo and self.max_tokens == 0

guided_json_object = None
@@ -806,6 +851,11 @@
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=repetition_penalty,
dry_multiplier=self.dry_multiplier,
dry_base=self.dry_base,
dry_allowed_length=self.dry_allowed_length,
dry_sequence_breaker_ids=dry_sequence_breaker_ids,
dry_range=self.dry_range,
temperature=temperature,
top_p=top_p,
top_k=top_k,
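
Both request models expose the same knobs, and AliasChoices lets clients send either dry_range or the text-generation-webui-style dry_penalty_last_n. A hedged client sketch against a locally running OpenAI-compatible server; the URL and model name are placeholders:

    import requests

    resp = requests.post(
        "http://localhost:8000/v1/completions",  # placeholder endpoint
        json={
            "model": "my-model",                 # placeholder model name
            "prompt": "Repeat after me: la la la la",
            "max_tokens": 32,
            "temperature": 0.0,
            "dry_multiplier": 1.0,
            "dry_allowed_length": 2,
            "dry_base": 2.0,
            "dry_penalty_last_n": 256,           # accepted alias for dry_range
            "dry_sequence_breakers": ["\n", ":", "\"", "*"],
        },
    )
    print(resp.json()["choices"][0]["text"])
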
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/serving_chat.py
@@ -203,7 +203,7 @@ async def create_chat_completion(
default_max_tokens, default_sampling_params)
else:
sampling_params = request.to_sampling_params(
default_max_tokens,
tokenizer, default_max_tokens,
self.model_config.logits_processor_pattern,
default_sampling_params)

2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/serving_completion.py
@@ -131,7 +131,7 @@ async def create_completion(
default_max_tokens, default_sampling_params)
else:
sampling_params = request.to_sampling_params(
default_max_tokens,
tokenizer, default_max_tokens,
self.model_config.logits_processor_pattern,
default_sampling_params)
