
Commit 367e5e6

Implement Min P as a sampler option in HF loaders (#4449)
1 parent fcb7017 commit 367e5e6

7 files changed: 43 additions & 1 deletion


extensions/api/util.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -26,6 +26,7 @@ def build_parameters(body, chat=False):
         'do_sample': bool(body.get('do_sample', True)),
         'temperature': float(body.get('temperature', 0.5)),
         'top_p': float(body.get('top_p', 1)),
+        'min_p': float(body.get('min_p', 0)),
         'typical_p': float(body.get('typical_p', body.get('typical', 1))),
         'epsilon_cutoff': float(body.get('epsilon_cutoff', 0)),
```
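With this change, API clients can pass min_p alongside the other sampler settings parsed by build_parameters(). A minimal sketch of such a request, assuming the API extension's default port, the /api/v1/generate endpoint, and the KoboldAI-style response shape the blocking API used at the time (none of which appear in this diff):

```python
import requests

# Hypothetical request: min_p rides along with the other sampler
# parameters that build_parameters() reads from the JSON body.
response = requests.post(
    "http://localhost:5000/api/v1/generate",
    json={
        "prompt": "The quick brown fox",
        "max_new_tokens": 64,
        "temperature": 0.7,
        "min_p": 0.05,  # keep tokens with >= 5% of the top token's probability
    },
)
print(response.json()["results"][0]["text"])
```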

modules/loaders.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -149,6 +149,7 @@
     'Transformers': {
         'temperature',
         'top_p',
+        'min_p',
         'top_k',
         'typical_p',
         'epsilon_cutoff',
@@ -184,6 +185,7 @@
     'ExLlama_HF': {
         'temperature',
         'top_p',
+        'min_p',
         'top_k',
         'typical_p',
         'epsilon_cutoff',
@@ -244,6 +246,7 @@
     'ExLlamav2_HF': {
         'temperature',
         'top_p',
+        'min_p',
         'top_k',
         'typical_p',
         'epsilon_cutoff',
@@ -275,6 +278,7 @@
     'AutoGPTQ': {
         'temperature',
         'top_p',
+        'min_p',
         'top_k',
         'typical_p',
         'epsilon_cutoff',
@@ -310,6 +314,7 @@
     'GPTQ-for-LLaMa': {
         'temperature',
         'top_p',
+        'min_p',
         'top_k',
         'typical_p',
         'epsilon_cutoff',
@@ -361,6 +366,7 @@
     'llamacpp_HF': {
         'temperature',
         'top_p',
+        'min_p',
         'top_k',
         'typical_p',
         'epsilon_cutoff',
@@ -399,6 +405,7 @@
     'AutoAWQ': {
         'temperature',
         'top_p',
+        'min_p',
         'top_k',
         'typical_p',
         'epsilon_cutoff',
```
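Each set above lists the sampler controls exposed for one loader; min_p is added only to the HF-based loaders, matching the commit title. A sketch of how support could be checked against this mapping (the loaders_samplers name and the plain set-membership test are assumptions about how the dict is consumed, not code from this commit):

```python
from modules import loaders

def supports_min_p(loader_name: str) -> bool:
    # A sampler is shown for a loader iff its name appears in
    # that loader's sampler set.
    return 'min_p' in loaders.loaders_samplers.get(loader_name, set())

print(supports_min_p('Transformers'))  # True after this commit
print(supports_min_p('ExLlama'))       # False: the non-HF loader is untouched
```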

modules/presets.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -9,6 +9,7 @@ def default_preset():
         'do_sample': True,
         'temperature': 1,
         'top_p': 1,
+        'min_p': 0,
         'top_k': 0,
         'typical_p': 1,
         'epsilon_cutoff': 0,
```

modules/sampler_hijack.py

Lines changed: 31 additions & 0 deletions
```diff
@@ -12,6 +12,34 @@
 
 global_scores = None
 
+class MinPLogitsWarper(LogitsWarper):
+    def __init__(self, min_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
+        if min_p < 0 or min_p > 1.0:
+            raise ValueError(f"`min_p` has to be a float >= 0 and <= 1, but is {min_p}")
+        self.min_p = min_p
+        self.filter_value = filter_value
+        self.min_tokens_to_keep = min_tokens_to_keep
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        # Convert logits to probabilities
+        probs = torch.softmax(scores, dim=-1)
+        # Get the probability of the top token for each sequence in the batch
+        top_probs, _ = probs.max(dim=-1, keepdim=True)
+        # Calculate the actual min_p threshold by scaling min_p with the top token's probability
+        scaled_min_p = self.min_p * top_probs
+        # Create a mask for tokens that have a probability less than the scaled min_p
+        tokens_to_remove = probs < scaled_min_p
+
+        sorted_indices = torch.argsort(scores, descending=True, dim=-1)
+        sorted_indices_to_remove = torch.gather(tokens_to_remove, dim=-1, index=sorted_indices)
+
+        if self.min_tokens_to_keep > 1:
+            # Keep at least min_tokens_to_keep
+            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = False
+
+        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+        scores = scores.masked_fill(indices_to_remove, self.filter_value)
+        return scores
 
 class TailFreeLogitsWarper(LogitsWarper):
     def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
@@ -190,6 +218,8 @@ def get_logits_warper_patch(self, generation_config):
         warpers_to_add.append(TailFreeLogitsWarper(tfs=generation_config.tfs, min_tokens_to_keep=min_tokens_to_keep))
     if generation_config.top_a is not None and 0.0 <= generation_config.top_a <= 1.0:
         warpers_to_add.append(TopALogitsWarper(top_a=generation_config.top_a, min_tokens_to_keep=min_tokens_to_keep))
+    if generation_config.min_p is not None and 0.0 <= generation_config.min_p <= 1.0:
+        warpers_to_add.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))
 
     if warpers and isinstance(warpers[-1], LogitNormalization):
         warpers = warpers[:-1] + warpers_to_add + [warpers[-1]]
@@ -223,6 +253,7 @@ def get_logits_processor_patch(self, **kwargs):
 def generation_config_init_patch(self, **kwargs):
     self.__init___old(**kwargs)
     self.tfs = kwargs.pop("tfs", 1.0)
+    self.min_p = kwargs.pop("min_p", 0.0)
     self.top_a = kwargs.pop("top_a", 0.0)
     self.mirostat_mode = kwargs.pop("mirostat_mode", 0)
     self.mirostat_eta = kwargs.pop("mirostat_eta", 0.1)
```
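MinPLogitsWarper implements Min P sampling: a token survives only if its probability is at least min_p times the probability of the most likely token, so the cutoff adapts to how peaked the distribution is (a confident distribution prunes aggressively, a flat one keeps more candidates). A self-contained sketch of that rule on toy logits (illustrative numbers, not code from the commit):

```python
import torch

# Toy logits for a 5-token vocabulary (hypothetical values).
scores = torch.tensor([[3.0, 2.5, 1.0, 0.2, -1.0]])
min_p = 0.1  # keep tokens with >= 10% of the top token's probability

probs = torch.softmax(scores, dim=-1)           # ~[0.55, 0.33, 0.07, 0.03, 0.01]
top_probs, _ = probs.max(dim=-1, keepdim=True)  # ~0.55
scaled_min_p = min_p * top_probs                # ~0.055
removed = probs < scaled_min_p                  # masks the last two tokens

filtered = scores.masked_fill(removed, -float("Inf"))
print(torch.softmax(filtered, dim=-1))  # renormalized over the three survivors
```

Note that with the default min_p = 0, scaled_min_p is 0 and probs < 0 never holds, so no token is removed; this is why get_logits_warper_patch() can register the warper for any 0.0 <= min_p <= 1.0 without changing output for users who leave the setting alone.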

modules/text_generation.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -274,7 +274,7 @@ def apply_stopping_strings(reply, all_stop_strings):
 
 def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
     generate_params = {}
-    for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']:
+    for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'min_p', 'typical_p', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']:
         generate_params[k] = state[k]
 
     if state['negative_prompt'] != '':
```

modules/ui.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -105,6 +105,7 @@ def list_interface_input_elements():
         'seed',
         'temperature',
         'top_p',
+        'min_p',
         'top_k',
         'typical_p',
         'epsilon_cutoff',
```

modules/ui_parameters.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -29,6 +29,7 @@ def create_ui(default_preset):
                 shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
                 shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature')
                 shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p')
+                shared.gradio['min_p'] = gr.Slider(0.0, 1.0, value=generate_params['min_p'], step=0.01, label='min_p')
                 shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k')
                 shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty')
                 shared.gradio['presence_penalty'] = gr.Slider(-2, 2, value=generate_params['presence_penalty'], step=0.05, label='presence_penalty')
```
