From 39e5152ed8f13ecc421f1c9bd5ddb6f36ebd42d4 Mon Sep 17 00:00:00 2001 From: Nathan Ranchin Date: Sun, 29 Sep 2024 18:01:37 +0200 Subject: [PATCH 1/2] refactor of repetition_penalty and logits_bias to use logits_processor --- llms/mlx_lm/utils.py | 51 +++++++++++++++---------------------- llms/tests/test_generate.py | 2 +- 2 files changed, 22 insertions(+), 31 deletions(-) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 16271c3ef..536371ba4 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -158,7 +158,7 @@ def generate_step( max_kv_size: Optional[int] = None, cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None, logit_bias: Optional[Dict[int, float]] = None, - logits_processor: Optional[Callable[[mx.array, mx.array], mx.array]] = None, + logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = [], ) -> Generator[Tuple[mx.array, mx.array], None, None]: """ A generator producing token ids based on the given prompt from the model. @@ -182,9 +182,9 @@ def generate_step( max_kv_size (int, optional): Maximum size of the key-value cache. Old entries (except the first 4 tokens) will be overwritten. logit_bias (dictionary, optional): Additive logit bias. - logits_processor (Callable[[mx.array, mx.array], mx.array], optional): - A function that takes tokens and logits and returns the processed - logits. Default: ``None``. + logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional): + A list of functions that take tokens and logits and return the processed + logits. Default: ``[]``. Yields: Generator[Tuple[mx.array, mx.array], None, None]: A generator producing @@ -212,6 +212,19 @@ def sample(logits: mx.array) -> Tuple[mx.array, float]: raise ValueError( f"repetition_penalty must be a non-negative float, got {repetition_penalty}" ) + + if repetition_penalty: + def repetition_penalty_processor(tokens: mx.array, logits: mx.array) -> mx.array: + return apply_repetition_penalty(logits, tokens[-repetition_context_size:], repetition_penalty) + logits_processor.append(repetition_penalty_processor) + + if logit_bias: + def logit_bias_processor(_: mx.array, logits: mx.array) -> mx.array: + indices = mx.array(list(logit_bias.keys())) + values = mx.array(list(logit_bias.values())) + logits[:, indices] += values + return logits + logits_processor.append(logit_bias_processor) y = prompt tokens = None @@ -229,40 +242,18 @@ def sample(logits: mx.array) -> Tuple[mx.array, float]: c.update_and_fetch(h[0], h[1]) mx.eval([c.state for c in cache]) - repetition_context = prompt.tolist() - - if repetition_context_size: - repetition_context = repetition_context[-repetition_context_size:] - - if logit_bias: - indices = mx.array(list(logit_bias.keys())) - values = mx.array(list(logit_bias.values())) - def _step(y): - nonlocal repetition_context logits = model(y[None], cache=cache) logits = logits[:, -1, :] if logits_processor: nonlocal tokens tokens = mx.concat([tokens, y]) if tokens is not None else y - logits = logits_processor(tokens, logits) - - if logit_bias: - logits[:, indices] += values - - if repetition_penalty: - logits = apply_repetition_penalty( - logits, repetition_context, repetition_penalty - ) - y, logprobs = sample(logits) - repetition_context.append(y.item()) - else: - y, logprobs = sample(logits) + + for processor in logits_processor: + logits = processor(tokens, logits) - if repetition_context_size: - if len(repetition_context) > repetition_context_size: - repetition_context = repetition_context[-repetition_context_size:] + y, logprobs = sample(logits) return y, logprobs.squeeze(0) while y.size > prefill_step_size: diff --git a/llms/tests/test_generate.py b/llms/tests/test_generate.py index bc9698445..68f1670bd 100644 --- a/llms/tests/test_generate.py +++ b/llms/tests/test_generate.py @@ -46,7 +46,7 @@ def logits_processor(toks, logits): "hello", max_tokens=5, verbose=False, - logits_processor=logits_processor, + logits_processor=[logits_processor], ) self.assertEqual(len(all_toks), len(init_toks) + 5) From 1209d4357dab632896591ff3b105924b667d82ac Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 30 Sep 2024 08:31:09 -0700 Subject: [PATCH 2/2] nits --- llms/mlx_lm/utils.py | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 536371ba4..d2d4a7d50 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -101,7 +101,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path return model_path -def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: float): +def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float): """ Apply repetition penalty to specific logits based on the given context. @@ -109,19 +109,18 @@ def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: f Args: logits (mx.array): The logits produced by the language model. - generated_tokens (any): A list of N previous tokens. + tokens (mx.array): A list of N previous tokens. penalty (float): The repetition penalty factor to be applied. Returns: logits (mx.array): Logits with repetition penalty applied to generated tokens. """ - if len(generated_tokens) > 0: - indices = mx.array([token for token in generated_tokens]) - selected_logits = logits[:, indices] + if len(tokens) > 0: + selected_logits = logits[:, tokens] selected_logits = mx.where( selected_logits < 0, selected_logits * penalty, selected_logits / penalty ) - logits[:, indices] = selected_logits + logits[:, tokens] = selected_logits return logits @@ -158,7 +157,7 @@ def generate_step( max_kv_size: Optional[int] = None, cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None, logit_bias: Optional[Dict[int, float]] = None, - logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = [], + logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, ) -> Generator[Tuple[mx.array, mx.array], None, None]: """ A generator producing token ids based on the given prompt from the model. @@ -184,7 +183,7 @@ def generate_step( logit_bias (dictionary, optional): Additive logit bias. logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional): A list of functions that take tokens and logits and return the processed - logits. Default: ``[]``. + logits. Default: ``None``. Yields: Generator[Tuple[mx.array, mx.array], None, None]: A generator producing @@ -212,18 +211,26 @@ def sample(logits: mx.array) -> Tuple[mx.array, float]: raise ValueError( f"repetition_penalty must be a non-negative float, got {repetition_penalty}" ) - + + logits_processor = logits_processor or [] + if repetition_penalty: - def repetition_penalty_processor(tokens: mx.array, logits: mx.array) -> mx.array: - return apply_repetition_penalty(logits, tokens[-repetition_context_size:], repetition_penalty) + + def repetition_penalty_processor(tokens, logits): + return apply_repetition_penalty( + logits, tokens[-repetition_context_size:], repetition_penalty + ) + logits_processor.append(repetition_penalty_processor) if logit_bias: - def logit_bias_processor(_: mx.array, logits: mx.array) -> mx.array: - indices = mx.array(list(logit_bias.keys())) - values = mx.array(list(logit_bias.values())) + indices = mx.array(list(logit_bias.keys())) + values = mx.array(list(logit_bias.values())) + + def logit_bias_processor(_, logits): logits[:, indices] += values return logits + logits_processor.append(logit_bias_processor) y = prompt @@ -249,7 +256,7 @@ def _step(y): if logits_processor: nonlocal tokens tokens = mx.concat([tokens, y]) if tokens is not None else y - + for processor in logits_processor: logits = processor(tokens, logits)