
Add logits_processor option to generate_step function #983

Merged
awni merged 8 commits into ml-explore:main from nathanrchn:logits_processor on Sep 28, 2024

Conversation

@nathanrchn
Contributor

This update introduces token masking capabilities to the generate_step function via a new logits_processor parameter. This enhancement supports constrained decoding scenarios that require token masking prior to sampling.

Updates include:

  1. New logits_processor parameter in generate_step function
  2. Token masking logic implemented within _step function
  3. Updated docstring for generate_step to describe logits_processor
  4. Created an mx.array of all tokens, including the prompt

Usage example:

def logits_processor(input_ids: mx.array, logits: mx.array) -> mx.array:
    return grammar_processor(input_ids, logits)

Here, grammar_processor could represent a custom constrained decoding approach.
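As a sketch of what such a grammar_processor might do (all names here are hypothetical, and numpy stands in for mx.array so the snippet is self-contained): set the logit of every token the grammar currently disallows to -inf, so sampling can only pick an allowed continuation.

```python
import numpy as np

# Hypothetical grammar state: only these token ids may come next.
allowed_next = {2, 7, 11}

def grammar_processor(input_ids: np.ndarray, logits: np.ndarray) -> np.ndarray:
    """Mask every disallowed token to -inf before sampling."""
    vocab = np.arange(logits.shape[-1])
    return np.where(np.isin(vocab, list(allowed_next)), logits, -np.inf)

logits = np.random.randn(16).astype(np.float32)
out = grammar_processor(np.array([1, 2, 3]), logits)
# only the allowed ids remain finite, so sampling cannot pick a disallowed token
```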

@awni
Member

awni commented Sep 27, 2024

Sorry for the delayed review. This is cool! And I think we can include it, but I want to clarify something first. We have the logit_bias argument already, which can do some of what the logit_processor does but in a less flexible way. Is that insufficient for your use cases? If so, could you explain?

In addition, I don't think we need to support both arguments as it's a bit messy and redundant. Perhaps we can remove the logit_bias and keep the more flexible logit_processor if needed.

@nathanrchn
Contributor Author

I can indeed use the logit_bias argument, but since my project primarily uses the transformers library, it would be nice to have the same logits_processor interface. Additionally, I would need to rewrite the stream_generate or generate method, because the mask changes with each forward pass. Moreover, when masking tokens for constrained decoding, the majority of tokens are typically masked, so creating the logit_bias dictionary could slow down generation slightly, as it needs to cover thousands of tokens.

If you agree, I can remove the logit_bias from the arguments.

@awni
Member

awni commented Sep 28, 2024

Sounds good, thanks for clarifying. Let's remove logit_bias in favor of logit_processor then. Thanks!

else:
    y, logprobs = sample(logits)

tokens_ids = mx.concat([tokens_ids, y], axis=0)
Member
Is this a bug? Shouldn't it be tokens?

Contributor Author
Yes of course.

@awni awni left a comment
Member

Thanks for the addition!

@awni
Member

awni commented Sep 28, 2024

I added logit_bias back because it's part of the OpenAI API spec and I don't want to break compatibility with that.

As a follow-up, a nice thing to do would be to refactor logit_bias and repetition_penalty out of generate_step, since they can both be expressed via the logit_processor. That would simplify generate_step quite nicely.
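That refactor might look roughly like this (a sketch with hypothetical factory names, using numpy in place of mx.array): each existing argument becomes a factory that returns a logits processor with the standard (input_ids, logits) signature.

```python
import numpy as np

def make_logit_bias_processor(logit_bias: dict):
    """Wrap an OpenAI-style {token_id: bias} dict as a logits processor."""
    ids = np.array(list(logit_bias.keys()), dtype=np.int64)
    vals = np.array(list(logit_bias.values()), dtype=np.float32)

    def processor(input_ids: np.ndarray, logits: np.ndarray) -> np.ndarray:
        out = logits.copy()
        out[ids] += vals  # one vectorized update instead of a special-case branch
        return out

    return processor

def make_repetition_penalty_processor(penalty: float):
    """One common formulation: damp logits of tokens already generated."""
    def processor(input_ids: np.ndarray, logits: np.ndarray) -> np.ndarray:
        out = logits.copy()
        seen = np.unique(input_ids).astype(np.int64)
        out[seen] = np.where(out[seen] > 0, out[seen] / penalty, out[seen] * penalty)
        return out
    return processor
```

generate_step would then only need to thread a processor (or list of processors) through its sampling loop, with no dedicated logit_bias or repetition-penalty branches.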

@awni awni merged commit ace2bb5 into ml-explore:main Sep 28, 2024
@nathanrchn
Contributor Author

Do you mean adding logit_bias and repetition_penalty as arguments to generate and stream_generate, and creating a logits_processor to handle the biases and the repetition penalty?

If so, it might be beneficial to modify the logits_processor type from a single function to a list of functions to apply sequentially.
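Applying a list of processors sequentially could be as simple as the following toy sketch (the lambdas exist only to make the ordering visible):

```python
import numpy as np

def apply_processors(processors, input_ids, logits):
    """Run each logits processor in order; later ones see earlier edits."""
    for proc in processors:
        logits = proc(input_ids, logits)
    return logits

# Toy processors to show ordering: add 1, then double.
procs = [
    lambda ids, logits: logits + 1.0,
    lambda ids, logits: logits * 2.0,
]
result = apply_processors(procs, np.array([]), np.zeros(3))
# (0 + 1) * 2 = 2 for every position
```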

@awni
Member

awni commented Sep 28, 2024

I hadn't thought it through too carefully, but yes, what you're describing is more or less what I had in mind.

If so, it might be beneficial to modify the logits_processor type from a single function to a list of functions to apply sequentially.

Yea, that may be cleaner.
