Skip to content

Commit 9c6aeba

Browse files
authored
Document and validate typical_p in generation (#19128)
* Document and validate typical_p in generation
1 parent de359c4 commit 9c6aeba

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

src/transformers/generation_logits_process.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,19 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
236236

237237

238238
class TypicalLogitsWarper(LogitsWarper):
239+
r"""
240+
[`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language
241+
Generation](https://arxiv.org/abs/2202.00666) for more information.
242+
243+
Args:
244+
mass (`float`):
245+
Value of typical_p between 0 and 1 inclusive, defaults to 0.9.
246+
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
247+
All filtered values will be set to this float value.
248+
min_tokens_to_keep (`int`, *optional*, defaults to 1):
249+
Minimum number of tokens that cannot be filtered.
250+
"""
251+
239252
def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
240253
mass = float(mass)
241254
if not (mass > 0 and mass < 1):

src/transformers/generation_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1486,6 +1486,9 @@ def generate(
14861486
if stopping_criteria.max_length is None:
14871487
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
14881488

1489+
if typical_p is not None:
1490+
raise ValueError("Decoder argument `typical_p` is not supported with beam groups.")
1491+
14891492
# 10. prepare beam search scorer
14901493
beam_scorer = BeamSearchScorer(
14911494
batch_size=batch_size,

0 commit comments

Comments
 (0)