Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 32 additions & 34 deletions llms/mlx_lm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,27 +101,26 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
return model_path


def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: float):
def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float):
"""
Apply repetition penalty to specific logits based on the given context.

Paper: https://arxiv.org/abs/1909.05858

Args:
logits (mx.array): The logits produced by the language model.
generated_tokens (any): A list of N previous tokens.
tokens (mx.array): A list of N previous tokens.
penalty (float): The repetition penalty factor to be applied.

Returns:
logits (mx.array): Logits with repetition penalty applied to generated tokens.
"""
if len(generated_tokens) > 0:
indices = mx.array([token for token in generated_tokens])
selected_logits = logits[:, indices]
if len(tokens) > 0:
selected_logits = logits[:, tokens]
selected_logits = mx.where(
selected_logits < 0, selected_logits * penalty, selected_logits / penalty
)
logits[:, indices] = selected_logits
logits[:, tokens] = selected_logits
return logits


Expand Down Expand Up @@ -158,7 +157,7 @@ def generate_step(
max_kv_size: Optional[int] = None,
cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None,
logit_bias: Optional[Dict[int, float]] = None,
logits_processor: Optional[Callable[[mx.array, mx.array], mx.array]] = None,
logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
Expand All @@ -182,8 +181,8 @@ def generate_step(
max_kv_size (int, optional): Maximum size of the key-value cache. Old
entries (except the first 4 tokens) will be overwritten.
logit_bias (dictionary, optional): Additive logit bias.
logits_processor (Callable[[mx.array, mx.array], mx.array], optional):
A function that takes tokens and logits and returns the processed
logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed
logits. Default: ``None``.

Yields:
Expand Down Expand Up @@ -213,6 +212,27 @@ def sample(logits: mx.array) -> Tuple[mx.array, float]:
f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
)

logits_processor = logits_processor or []

if repetition_penalty:

def repetition_penalty_processor(tokens, logits):
return apply_repetition_penalty(
logits, tokens[-repetition_context_size:], repetition_penalty
)

logits_processor.append(repetition_penalty_processor)

if logit_bias:
indices = mx.array(list(logit_bias.keys()))
values = mx.array(list(logit_bias.values()))

def logit_bias_processor(_, logits):
logits[:, indices] += values
return logits

logits_processor.append(logit_bias_processor)

y = prompt
tokens = None

Expand All @@ -229,40 +249,18 @@ def sample(logits: mx.array) -> Tuple[mx.array, float]:
c.update_and_fetch(h[0], h[1])
mx.eval([c.state for c in cache])

repetition_context = prompt.tolist()

if repetition_context_size:
repetition_context = repetition_context[-repetition_context_size:]

if logit_bias:
indices = mx.array(list(logit_bias.keys()))
values = mx.array(list(logit_bias.values()))

def _step(y):
nonlocal repetition_context
logits = model(y[None], cache=cache)
logits = logits[:, -1, :]

if logits_processor:
nonlocal tokens
tokens = mx.concat([tokens, y]) if tokens is not None else y
logits = logits_processor(tokens, logits)

if logit_bias:
logits[:, indices] += values

if repetition_penalty:
logits = apply_repetition_penalty(
logits, repetition_context, repetition_penalty
)
y, logprobs = sample(logits)
repetition_context.append(y.item())
else:
y, logprobs = sample(logits)
for processor in logits_processor:
logits = processor(tokens, logits)

if repetition_context_size:
if len(repetition_context) > repetition_context_size:
repetition_context = repetition_context[-repetition_context_size:]
y, logprobs = sample(logits)
return y, logprobs.squeeze(0)

while y.size > prefill_step_size:
Expand Down
2 changes: 1 addition & 1 deletion llms/tests/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def logits_processor(toks, logits):
"hello",
max_tokens=5,
verbose=False,
logits_processor=logits_processor,
logits_processor=[logits_processor],
)
self.assertEqual(len(all_toks), len(init_toks) + 5)

Expand Down