@@ -116,6 +116,10 @@ class SamplingParams(
116116
117117 Args:
118118 n: Number of output sequences to return for the given prompt.
119+ best_of: Number of output sequences that are generated from the prompt.
120+ From these `best_of` sequences, the top `n` sequences are returned.
121+ `best_of` must be greater than or equal to `n`. By default,
122+ `best_of` is set to `n`. Warning, this is only supported in V0.
119123 presence_penalty: Float that penalizes new tokens based on whether they
120124 appear in the generated text so far. Values > 0 encourage the model
121125 to use new tokens, while values < 0 encourage the model to repeat
@@ -183,6 +187,7 @@ class SamplingParams(
183187 """
184188
185189 n : int = 1
190+ best_of : Optional [int ] = None
186191 _real_n : Optional [int ] = None
187192 presence_penalty : float = 0.0
188193 frequency_penalty : float = 0.0
@@ -226,6 +231,7 @@ class SamplingParams(
226231 @staticmethod
227232 def from_optional (
228233 n : Optional [int ] = 1 ,
234+ best_of : Optional [int ] = None ,
229235 presence_penalty : Optional [float ] = 0.0 ,
230236 frequency_penalty : Optional [float ] = 0.0 ,
231237 repetition_penalty : Optional [float ] = 1.0 ,
@@ -264,6 +270,7 @@ def from_optional(
264270
265271 return SamplingParams (
266272 n = 1 if n is None else n ,
273+ best_of = best_of ,
267274 presence_penalty = 0.0
268275 if presence_penalty is None else presence_penalty ,
269276 frequency_penalty = 0.0
@@ -296,6 +303,20 @@ def from_optional(
296303 )
297304
298305 def __post_init__ (self ) -> None :
306+ # how we deal with `best_of``:
307+ # if `best_of`` is not set, we default to `n`;
308+ # if `best_of`` is set, we set `n`` to `best_of`,
309+ # and set `_real_n`` to the original `n`.
310+ # when we return the result, we will check
311+ # if we need to return `n` or `_real_n` results
312+ if self .best_of :
313+ if self .best_of < self .n :
314+ raise ValueError (
315+ f"best_of must be greater than or equal to n, "
316+ f"got n={ self .n } and best_of={ self .best_of } ." )
317+ if not self ._real_n :
318+ self ._real_n = self .n
319+ self .n = self .best_of
299320
300321 if 0 < self .temperature < _MAX_TEMP :
301322 logger .warning (
@@ -402,6 +423,9 @@ def _verify_args(self) -> None:
402423 raise ValueError (
403424 "stop strings are only supported when detokenize is True. "
404425 "Set detokenize=True to use stop." )
426+ if self .best_of != self ._real_n and self .output_kind == (
427+ RequestOutputKind .DELTA ):
428+ raise ValueError ("best_of must equal n to use output_kind=DELTA" )
405429
406430 def _verify_greedy_sampling (self ) -> None :
407431 if self .n > 1 :
0 commit comments