prune LM Head for USD #36695 (Merged)
Commits (33, all by jmamou):
a1b8d41 initial commit
da0a7dd Merge branch 'main' into prune_LMHead
6786356 fix
e5611c4 Merge branch 'main' into prune_LMHead
88ee824 fix style
07ead6e set default to prune
12ea728 Merge branch 'main' into prune_LMHead
b06d698 add tests
c310795 Merge branch 'main' into prune_LMHead
3ca6af0 Merge branch 'main' into prune_LMHead
763d284 comment
c28144e Merge branch 'main' into prune_LMHead
6b1e514 Merge branch 'main' into prune_LMHead
f20812f Merge branch 'main' into prune_LMHead
34c2d00 remove prune flag from generate
5957c7c Merge branch 'main' into prune_LMHead
ffdff86 Merge branch 'main' into prune_LMHead
7767d72 Merge branch 'main' into prune_LMHead
f57ae18 address Joao's comments
515911f Merge branch 'main' into prune_LMHead
ed10e3c deprecate_kwarg
98965b2 add doc
2883380 Merge branch 'main' into prune_LMHead
77a4b0b fix target_vocab_size
df3bd2d Merge branch 'main' into prune_LMHead
e5055c1 Merge branch 'main' into prune_LMHead
acf7c88 Update src/transformers/generation/candidate_generator.py
ac833ce Update src/transformers/generation/candidate_generator.py
9d2a1e3 Update src/transformers/generation/candidate_generator.py
d343966 Update src/transformers/generation/candidate_generator.py
c9e844f fix deprecated argument assistant_model_device
773661b Merge branch 'main' into prune_LMHead
1ae859e Merge branch 'main' into prune_LMHead
Changes to src/transformers/generation/candidate_generator.py. The first hunk adds two imports used by the new classes below:

```diff
@@ -19,7 +19,9 @@
 import numpy as np
 import torch
+import torch.nn as nn
 
+from ..pytorch_utils import prune_linear_layer
 from ..utils import is_sklearn_available
```
The hunk `@@ -612,6 +614,63 @@`, following `_process_assistant_outputs` (whose closing `return new_target_ids` appears as context), adds two helper classes. The first wraps a pruned LM head:

```python
class PruneReindexingLMHead(nn.Module):
    """
    A class to prune and reindex the language model head.

    This class prunes the language model head to only include the specified token IDs and reindexes the
    logits to map back to the original vocabulary.

    Args:
        original_lm_head (nn.Module): The original language model head.
        assistant_overlap_token_ids (torch.LongTensor): The token IDs to keep.
    """

    def __init__(self, original_lm_head, assistant_overlap_token_ids):
        super().__init__()
        self.pruned_lm_head = prune_linear_layer(original_lm_head, assistant_overlap_token_ids).to(
            original_lm_head.weight.dtype
        )

    def forward(self, hidden_states):
        pruned_logits = self.pruned_lm_head(hidden_states)
        return pruned_logits
```
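To see what the pruning buys, here is a minimal, self-contained sketch (not from the PR) of `prune_linear_layer` applied to a toy LM head; the sizes and token IDs are invented for illustration:

```python
import torch
import torch.nn as nn

from transformers.pytorch_utils import prune_linear_layer

# Toy LM head: hidden size 4, "vocabulary" of 10 tokens.
lm_head = nn.Linear(4, 10, bias=False)

# Pretend only these assistant tokens overlap with the target vocabulary.
overlap_ids = torch.tensor([1, 3, 7], dtype=torch.long)

# prune_linear_layer keeps only the weight rows at the given indices (dim=0),
# so the pruned head scores 3 tokens per step instead of 10.
pruned_head = prune_linear_layer(lm_head, overlap_ids)

hidden = torch.randn(1, 1, 4)
print(lm_head(hidden).shape)      # torch.Size([1, 1, 10])
print(pruned_head(hidden).shape)  # torch.Size([1, 1, 3])
```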
The second wrapper remaps input token IDs before the embedding lookup:

```python
class MapInputEmbedding(nn.Module):
    def __init__(self, original_embedding: nn.Embedding, assistant_overlap_token_ids):
        """
        Wraps an existing embedding layer and remaps token IDs before lookup.

        Args:
            original_embedding (nn.Embedding): Pre-trained or existing embedding layer.
            assistant_overlap_token_ids (torch.LongTensor): Lookup tensor mapping original token IDs
                to new token IDs (the value at position old_id is new_id).
        """
        super().__init__()
        self.original_embedding = original_embedding
        self.weight = original_embedding.weight
        self.assistant_overlap_token_ids = assistant_overlap_token_ids
        self.map = False

    def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
        """
        Args:
            input_ids (torch.LongTensor): Tensor of token IDs (batch_size, seq_len).

        Returns:
            torch.FloatTensor: Corresponding input embeddings.
        """
        if self.map:
            # Remap only the last token ID; the first call embeds the prompt as-is.
            my_input_ids = self.assistant_overlap_token_ids[input_ids[0, -1]].unsqueeze(0).unsqueeze(0)
        else:
            self.map = True
            my_input_ids = input_ids

        return self.original_embedding(my_input_ids)
```
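A hedged sketch of the wrapper's two-phase behavior, using the `MapInputEmbedding` class defined above; the toy embedding and identity mapping tensor are invented for illustration:

```python
import torch
import torch.nn as nn

# Toy assistant embedding: 10 tokens, hidden size 4.
embedding = nn.Embedding(10, 4)

# Invented lookup tensor: position = original token ID, value = remapped ID.
# In the PR this is derived from the overlap of the two vocabularies.
overlap_map = torch.arange(10, dtype=torch.long)

wrapper = MapInputEmbedding(embedding, overlap_map)

prompt = torch.tensor([[2, 5, 7]])
print(wrapper(prompt).shape)      # first call, no remap: torch.Size([1, 3, 4])

next_token = torch.tensor([[3]])
print(wrapper(next_token).shape)  # later calls remap the last ID: torch.Size([1, 1, 4])
```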
The diff then reaches the existing `AssistantToTargetTranslator` class, whose docstring begins: "Translates token ids and logits between assistant and target model vocabularies. This class is used to handle ...". The next hunk reworks its constructor:
```diff
@@ -638,23 +697,46 @@ def __init__(
         self,
         target_tokenizer: "PreTrainedTokenizerBase",
         assistant_tokenizer: "PreTrainedTokenizerBase",
-        target_vocab_size: int,  # required since target_vocab_size can be different from the length of target_tokenizer.get_vocab()
-        assistant_model_device: str = "cpu",
+        assistant_model: "PreTrainedModel",
+        target_vocab_size: Optional[int],
+        assistant_prune_LM_head: Optional[bool] = True,
     ):
         self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer
         self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer
-        self._assistant_model_device: str = assistant_model_device
+        self._assistant_model_device: str = assistant_model.device
         self.target_vocab_size: int = target_vocab_size
         self._assistant_to_target_input_ids, self.target_to_assistant_input_ids = (
             self._get_assistant_to_target_input_ids()
         )
         self._suppress_input_ids: list[int] = self._get_suppress_input_ids()
         self.logits_processors: Optional[LogitsProcessorList] = None
+        self.assistant_prune_LM_head = assistant_prune_LM_head
```
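Commits ed10e3c (`deprecate_kwarg`) and c9e844f suggest the removed `assistant_model_device` argument is deprecated rather than dropped outright. A minimal sketch, assuming the standard `transformers` `deprecate_kwarg` decorator is used; the class name and version string below are illustrative, not the PR's exact code:

```python
from transformers.utils.deprecation import deprecate_kwarg


class TranslatorSketch:
    # Hypothetical stand-in for the translator's __init__: passing the old
    # assistant_model_device keyword still works but emits a deprecation warning.
    @deprecate_kwarg("assistant_model_device", version="4.53")
    def __init__(self, assistant_model, assistant_model_device=None):
        # The device is now read from the model itself.
        self._assistant_model_device = assistant_model.device
```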
Review comments:

On `self.assistant_prune_LM_head = assistant_prune_LM_head`, a reviewer suggested snake case:

```suggestion
        self.assistant_prune_lm_head = assistant_prune_lm_head
```

A contributor also noted: "Missing docs to public method: it should contain what it does and why it is needed."
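For context, this translator sits on the universal assisted decoding path, which `generate` takes when the assistant's tokenizer differs from the target's. A hedged end-to-end sketch; the checkpoints are placeholders, chosen only because they have mismatched vocabularies:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

target = AutoModelForCausalLM.from_pretrained("gpt2")
target_tok = AutoTokenizer.from_pretrained("gpt2")
assistant = AutoModelForCausalLM.from_pretrained("double7/vicuna-68m")
assistant_tok = AutoTokenizer.from_pretrained("double7/vicuna-68m")

inputs = target_tok("The capital of France is", return_tensors="pt")

# Passing both tokenizers triggers universal assisted decoding, which routes
# through AssistantToTargetTranslator (and, per this PR, a pruned assistant
# LM head by default).
out = target.generate(
    **inputs,
    assistant_model=assistant,
    tokenizer=target_tok,
    assistant_tokenizer=assistant_tok,
    max_new_tokens=20,
)
print(target_tok.decode(out[0], skip_special_tokens=True))
```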