
Commit 5ce75b2

hmellor authored and lulmer committed
Reinstate best_of for V0 (vllm-project#14356)
Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: Louis Ulmer <[email protected]>
1 parent 631aa3e commit 5ce75b2

6 files changed: +50 −3 lines changed


tests/v1/sample/test_sampling_params_e2e.py

Lines changed: 8 additions & 0 deletions
@@ -25,6 +25,14 @@ def test_n_gt_1(model):
     assert len(outputs[0].outputs) == 3
 
 
+def test_best_of(model):
+    """Raise a ValueError since best_of is deprecated."""
+
+    params = SamplingParams(n=2, best_of=3)
+    with pytest.raises(ValueError):
+        _ = model.generate(PROMPT, params)
+
+
 def test_penalties(model):
     """Check that we do not get errors if applied."""

vllm/entrypoints/llm.py

Lines changed: 5 additions & 1 deletion
@@ -97,7 +97,11 @@ class LLM:
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
-            Too small values may cause out-of-memory (OOM) errors.
+            This can be used for temporarily storing the states of the requests
+            when their `best_of` sampling parameters are larger than 1. If all
+            requests will have `best_of=1`, you can safely set this to 0.
+            Noting that `best_of` is only supported in V0. Otherwise, too small
+            values may cause out-of-memory (OOM) errors.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
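For context, a minimal offline-inference sketch (not part of this commit) of how the reinstated parameter fits with swap_space on the V0 engine; the model name and the VLLM_USE_V1=0 toggle are illustrative assumptions, not something this diff introduces.

import os

os.environ["VLLM_USE_V1"] = "0"  # best_of is only supported by the V0 engine

from vllm import LLM, SamplingParams

# A non-zero swap_space gives V0 CPU room to park request state while the
# extra best_of candidates are being generated.
llm = LLM(model="facebook/opt-125m", swap_space=4)

params = SamplingParams(n=2, best_of=3, temperature=0.8)
outputs = llm.generate(["The capital of France is"], params)

# Only the top n=2 of the 3 generated candidates are returned per prompt.
print(len(outputs[0].outputs))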

vllm/entrypoints/openai/protocol.py

Lines changed: 4 additions & 0 deletions
@@ -242,6 +242,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     user: Optional[str] = None
 
     # doc: begin-chat-completion-sampling-params
+    best_of: Optional[int] = None
     use_beam_search: bool = False
     top_k: Optional[int] = None
     min_p: Optional[float] = None
@@ -478,6 +479,7 @@ def to_sampling_params(
 
         return SamplingParams.from_optional(
             n=self.n,
+            best_of=self.best_of,
             presence_penalty=self.presence_penalty,
             frequency_penalty=self.frequency_penalty,
             repetition_penalty=repetition_penalty,
@@ -648,6 +650,7 @@ class CompletionRequest(OpenAIBaseModel):
     # https://platform.openai.com/docs/api-reference/completions/create
     model: Optional[str] = None
     prompt: Union[list[int], list[list[int]], str, list[str]]
+    best_of: Optional[int] = None
     echo: Optional[bool] = False
     frequency_penalty: Optional[float] = 0.0
     logit_bias: Optional[dict[str, float]] = None
@@ -845,6 +848,7 @@ def to_sampling_params(
 
         return SamplingParams.from_optional(
             n=self.n,
+            best_of=self.best_of,
             presence_penalty=self.presence_penalty,
             frequency_penalty=self.frequency_penalty,
             repetition_penalty=repetition_penalty,
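As a usage sketch (not part of this commit), once these request fields exist the parameter can be exercised against a V0 server through the OpenAI-compatible Completions endpoint; the base URL, API key, and model name below are placeholders.

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# best_of is a field of the legacy Completions API, so the client accepts it
# directly; vLLM now forwards it into SamplingParams.from_optional(...).
completion = client.completions.create(
    model="facebook/opt-125m",
    prompt="The capital of France is",
    n=2,
    best_of=3,
    max_tokens=16,
)
for choice in completion.choices:
    print(choice.text)

For the Chat Completions route, where best_of is not a standard OpenAI field, it would presumably be passed through the client's extra_body (e.g. extra_body={"best_of": 3}) instead.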

vllm/entrypoints/openai/serving_completion.py

Lines changed: 6 additions & 2 deletions
@@ -168,8 +168,12 @@ async def create_completion(
         model_name = self._get_model_name(request.model, lora_request)
         num_prompts = len(engine_prompts)
 
-        # We do not stream the results when use beam search.
-        stream = (request.stream and not request.use_beam_search)
+        # Similar to the OpenAI API, when n != best_of, we do not stream the
+        # results. Noting that best_of is only supported in V0. In addition,
+        # we do not stream the results when use beam search.
+        stream = (request.stream
+                  and (request.best_of is None or request.n == request.best_of)
+                  and not request.use_beam_search)
 
         # Streaming response
         if stream:
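A small stand-alone sketch of the gate added above, using stand-in request objects purely for illustration:

from types import SimpleNamespace

def should_stream(request) -> bool:
    # Mirrors the condition above: no streaming when best_of differs from n,
    # and no streaming when beam search is used.
    return (request.stream
            and (request.best_of is None or request.n == request.best_of)
            and not request.use_beam_search)

streamed = SimpleNamespace(stream=True, n=2, best_of=None, use_beam_search=False)
buffered = SimpleNamespace(stream=True, n=2, best_of=3, use_beam_search=False)

print(should_stream(streamed))  # True
print(should_stream(buffered))  # False: best_of > n, so results are buffered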

vllm/sampling_params.py

Lines changed: 24 additions & 0 deletions
@@ -116,6 +116,10 @@ class SamplingParams(
 
     Args:
         n: Number of output sequences to return for the given prompt.
+        best_of: Number of output sequences that are generated from the prompt.
+            From these `best_of` sequences, the top `n` sequences are returned.
+            `best_of` must be greater than or equal to `n`. By default,
+            `best_of` is set to `n`. Warning, this is only supported in V0.
         presence_penalty: Float that penalizes new tokens based on whether they
             appear in the generated text so far. Values > 0 encourage the model
             to use new tokens, while values < 0 encourage the model to repeat
@@ -183,6 +187,7 @@ class SamplingParams(
     """
 
     n: int = 1
+    best_of: Optional[int] = None
     _real_n: Optional[int] = None
     presence_penalty: float = 0.0
     frequency_penalty: float = 0.0
@@ -226,6 +231,7 @@ class SamplingParams(
     @staticmethod
     def from_optional(
         n: Optional[int] = 1,
+        best_of: Optional[int] = None,
         presence_penalty: Optional[float] = 0.0,
         frequency_penalty: Optional[float] = 0.0,
         repetition_penalty: Optional[float] = 1.0,
@@ -264,6 +270,7 @@ def from_optional(
 
         return SamplingParams(
             n=1 if n is None else n,
+            best_of=best_of,
             presence_penalty=0.0
             if presence_penalty is None else presence_penalty,
             frequency_penalty=0.0
@@ -296,6 +303,20 @@ def from_optional(
         )
 
     def __post_init__(self) -> None:
+        # how we deal with `best_of`:
+        # if `best_of` is not set, we default to `n`;
+        # if `best_of` is set, we set `n` to `best_of`,
+        # and set `_real_n` to the original `n`.
+        # when we return the result, we will check
+        # if we need to return `n` or `_real_n` results
+        if self.best_of:
+            if self.best_of < self.n:
+                raise ValueError(
+                    f"best_of must be greater than or equal to n, "
+                    f"got n={self.n} and best_of={self.best_of}.")
+            if not self._real_n:
+                self._real_n = self.n
+                self.n = self.best_of
 
         if 0 < self.temperature < _MAX_TEMP:
             logger.warning(
@@ -402,6 +423,9 @@ def _verify_args(self) -> None:
             raise ValueError(
                 "stop strings are only supported when detokenize is True. "
                 "Set detokenize=True to use stop.")
+        if self.best_of != self._real_n and self.output_kind == (
+                RequestOutputKind.DELTA):
+            raise ValueError("best_of must equal n to use output_kind=DELTA")
 
     def _verify_greedy_sampling(self) -> None:
         if self.n > 1:
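A quick illustration (not part of this commit) of the bookkeeping that __post_init__ now performs; the assertions simply restate what the added code does.

from vllm import SamplingParams

params = SamplingParams(n=2, best_of=3)
# best_of is copied into n so the engine generates 3 candidates, while the
# original n is stashed in _real_n for trimming the returned outputs.
assert params.n == 3
assert params._real_n == 2

# best_of smaller than n is rejected outright.
try:
    SamplingParams(n=4, best_of=2)
except ValueError as err:
    print(err)  # best_of must be greater than or equal to n, ...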

vllm/v1/engine/processor.py

Lines changed: 3 additions & 0 deletions
@@ -93,6 +93,9 @@ def _validate_supported_sampling_params(
         self,
         params: SamplingParams,
     ) -> None:
+        # Best of not yet supported.
+        if params.best_of is not None and params.best_of > 1:
+            raise ValueError("VLLM V1 does not yet support best_of.")
         # Bad words not yet supported.
         if params.bad_words:
             raise ValueError("VLLM V1 does not yet support bad_words.")
