12 | 12 |
13 | 13 | global_scores = None
14 | 14 |
| 15 | +class MinPLogitsWarper(LogitsWarper): |
| 16 | + def __init__(self, min_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): |
| 17 | + if min_p < 0 or min_p > 1.0: |
| 18 | + raise ValueError(f"`min_p` has to be a float >= 0 and <= 1, but is {min_p}") |
| 19 | + self.min_p = min_p |
| 20 | + self.filter_value = filter_value |
| 21 | + self.min_tokens_to_keep = min_tokens_to_keep |
| 22 | + |
| 23 | + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: |
| 24 | + # Convert logits to probabilities |
| 25 | + probs = torch.softmax(scores, dim=-1) |
| 26 | + # Get the probability of the top token for each sequence in the batch |
| 27 | + top_probs, _ = probs.max(dim=-1, keepdim=True) |
| 28 | + # Calculate the actual min_p threshold by scaling min_p with the top token's probability |
| 29 | + scaled_min_p = self.min_p * top_probs |
| 30 | + # Create a mask for tokens that have a probability less than the scaled min_p |
| 31 | + tokens_to_remove = probs < scaled_min_p |
| 32 | + |
| 33 | + sorted_indices = torch.argsort(scores, descending=True, dim=-1) |
| 34 | + sorted_indices_to_remove = torch.gather(tokens_to_remove, dim=-1, index=sorted_indices) |
| 35 | + |
| 36 | + if self.min_tokens_to_keep > 1: |
| 37 | + # Keep at least min_tokens_to_keep |
| 38 | + sorted_indices_to_remove[..., : self.min_tokens_to_keep] = False |
| 39 | + |
| 40 | + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) |
| 41 | + scores = scores.masked_fill(indices_to_remove, self.filter_value) |
| 42 | + return scores |
15 | 43 |
16 | 44 | class TailFreeLogitsWarper(LogitsWarper): |
17 | 45 | def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): |
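For orientation, the sorted/scatter round trip in the new class exists so that min_tokens_to_keep can be enforced on the highest-ranked tokens before the keep/remove mask is mapped back to vocabulary order. The filtering rule itself is simple: a token survives only if its probability is at least min_p times the probability of the most likely token. A minimal standalone sketch of that rule (illustrative numbers, not part of the commit):

import torch

# Hypothetical next-token probabilities for one sequence (illustrative values only).
probs = torch.tensor([[0.50, 0.30, 0.15, 0.04, 0.01]])
min_p = 0.1

# The threshold scales with the top token's probability: 0.1 * 0.50 = 0.05.
scaled_min_p = min_p * probs.max(dim=-1, keepdim=True).values
keep = probs >= scaled_min_p
print(keep)  # tensor([[ True,  True,  True, False, False]])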
@@ -190,6 +218,8 @@ def get_logits_warper_patch(self, generation_config): |
190 | 218 | warpers_to_add.append(TailFreeLogitsWarper(tfs=generation_config.tfs, min_tokens_to_keep=min_tokens_to_keep)) |
191 | 219 | if generation_config.top_a is not None and 0.0 <= generation_config.top_a <= 1.0: |
192 | 220 | warpers_to_add.append(TopALogitsWarper(top_a=generation_config.top_a, min_tokens_to_keep=min_tokens_to_keep)) |
| 221 | + if generation_config.min_p is not None and 0.0 <= generation_config.min_p <= 1.0: |
| 222 | + warpers_to_add.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep)) |
193 | 223 |
194 | 224 | if warpers and isinstance(warpers[-1], LogitNormalization): |
195 | 225 | warpers = warpers[:-1] + warpers_to_add + [warpers[-1]] |
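This hunk registers the new warper alongside the other extra samplers, and the check on the last element slips the additions in before a trailing LogitNormalization so that normalization still runs last. A rough sketch of how the resulting chain could be exercised directly (the shapes, values, and chain composition below are assumptions for illustration, not code from this repository):

import torch
from transformers import LogitsProcessorList, TemperatureLogitsWarper

# Assumed chain: the stock temperature warper followed by the MinPLogitsWarper
# defined in the first hunk. Shapes and values are placeholders.
warpers = LogitsProcessorList([
    TemperatureLogitsWarper(0.7),
    MinPLogitsWarper(min_p=0.05, min_tokens_to_keep=1),
])

input_ids = torch.tensor([[0]])        # dummy prompt ids
scores = torch.randn(1, 32000)         # dummy next-token logits
filtered = warpers(input_ids, scores)  # pruned tokens are set to -inf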
@@ -223,6 +253,7 @@ def get_logits_processor_patch(self, **kwargs): |
223 | 253 | def generation_config_init_patch(self, **kwargs): |
224 | 254 | self.__init___old(**kwargs) |
225 | 255 | self.tfs = kwargs.pop("tfs", 1.0) |
| 256 | + self.min_p = kwargs.pop("min_p", 0.0) |
226 | 257 | self.top_a = kwargs.pop("top_a", 0.0) |
227 | 258 | self.mirostat_mode = kwargs.pop("mirostat_mode", 0) |
228 | 259 | self.mirostat_eta = kwargs.pop("mirostat_eta", 0.1) |
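The init patch gives min_p a default of 0.0, which keeps the new sampler inert unless a caller opts in: with min_p = 0 the scaled threshold is zero, the probs < 0 mask is all False, and no logits are filtered. A small sanity sketch using the class from the first hunk (assumed importable; the tensors are placeholders):

import torch

# With the default min_p of 0.0 the warper removes nothing, so it acts as a no-op.
warper = MinPLogitsWarper(min_p=0.0)
scores = torch.randn(1, 10)
out = warper(torch.tensor([[0]]), scores)
assert torch.equal(out, scores)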