Merged
Commits (33 total; the diff below shows changes from 17 commits):
a1b8d41  initial commit (jmamou, Mar 13, 2025)
da0a7dd  Merge branch 'main' into prune_LMHead (jmamou, Mar 17, 2025)
6786356  fix (jmamou, Mar 24, 2025)
e5611c4  Merge branch 'main' into prune_LMHead (jmamou, Mar 24, 2025)
88ee824  fix style (jmamou, Mar 24, 2025)
07ead6e  set default to prune (jmamou, Mar 25, 2025)
12ea728  Merge branch 'main' into prune_LMHead (jmamou, Mar 25, 2025)
b06d698  add tests (jmamou, Mar 25, 2025)
c310795  Merge branch 'main' into prune_LMHead (jmamou, Mar 25, 2025)
3ca6af0  Merge branch 'main' into prune_LMHead (jmamou, Mar 25, 2025)
763d284  comment (jmamou, Mar 25, 2025)
c28144e  Merge branch 'main' into prune_LMHead (jmamou, Mar 30, 2025)
6b1e514  Merge branch 'main' into prune_LMHead (jmamou, Apr 1, 2025)
f20812f  Merge branch 'main' into prune_LMHead (jmamou, Apr 1, 2025)
34c2d00  remove prune flag from generate (jmamou, Apr 1, 2025)
5957c7c  Merge branch 'main' into prune_LMHead (jmamou, Apr 1, 2025)
ffdff86  Merge branch 'main' into prune_LMHead (jmamou, Apr 1, 2025)
7767d72  Merge branch 'main' into prune_LMHead (jmamou, Apr 2, 2025)
f57ae18  address Joao's comments (jmamou, Apr 3, 2025)
515911f  Merge branch 'main' into prune_LMHead (jmamou, Apr 3, 2025)
ed10e3c  deprecate_kwarg (jmamou, Apr 3, 2025)
98965b2  add doc (jmamou, Apr 3, 2025)
2883380  Merge branch 'main' into prune_LMHead (jmamou, Apr 3, 2025)
77a4b0b  fix target_vocab_size (jmamou, Apr 3, 2025)
df3bd2d  Merge branch 'main' into prune_LMHead (jmamou, Apr 3, 2025)
e5055c1  Merge branch 'main' into prune_LMHead (jmamou, Apr 7, 2025)
acf7c88  Update src/transformers/generation/candidate_generator.py (jmamou, Apr 7, 2025)
ac833ce  Update src/transformers/generation/candidate_generator.py (jmamou, Apr 7, 2025)
9d2a1e3  Update src/transformers/generation/candidate_generator.py (jmamou, Apr 7, 2025)
d343966  Update src/transformers/generation/candidate_generator.py (jmamou, Apr 7, 2025)
c9e844f  fix deprecated argument assistant_model_device (jmamou, Apr 8, 2025)
773661b  Merge branch 'main' into prune_LMHead (jmamou, Apr 8, 2025)
1ae859e  Merge branch 'main' into prune_LMHead (jmamou, Apr 8, 2025)
120 changes: 105 additions & 15 deletions src/transformers/generation/candidate_generator.py
@@ -19,7 +19,9 @@

import numpy as np
import torch
import torch.nn as nn

from ..pytorch_utils import prune_linear_layer
from ..utils import is_sklearn_available


@@ -612,6 +614,63 @@ def _process_assistant_outputs(
return new_target_ids


class PruneReindexingLMHead(nn.Module):
"""
A class to prune and reindex the language model head.

This class prunes the language model head so that it only produces logits for the specified token IDs;
those logits can later be mapped back to their positions in the original vocabulary.

Args:
original_lm_head (nn.Module): The original language model head.
assistant_overlap_token_ids (torch.LongTensor): The token IDs to keep, i.e. the assistant tokens that also exist in the target vocabulary.
"""

def __init__(self, original_lm_head, assistant_overlap_token_ids):
super().__init__()
self.pruned_lm_head = prune_linear_layer(original_lm_head, assistant_overlap_token_ids).to(
original_lm_head.weight.dtype
)

def forward(self, hidden_states):
pruned_logits = self.pruned_lm_head(hidden_states)
return pruned_logits
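
For intuition, here is a minimal sketch (not part of this PR's diff) of what pruning the LM head with prune_linear_layer does; the sizes and token IDs below are made up for illustration.

# Illustrative sketch only, not part of the PR.
import torch
import torch.nn as nn
from transformers.pytorch_utils import prune_linear_layer

hidden_size, vocab_size = 8, 100                       # made-up sizes
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
keep_ids = torch.tensor([3, 17, 42])                   # made-up assistant token IDs that also exist in the target vocab
pruned_head = prune_linear_layer(lm_head, keep_ids)    # keeps only the weight rows for keep_ids

hidden_states = torch.randn(1, 1, hidden_size)
print(lm_head(hidden_states).shape)      # torch.Size([1, 1, 100])
print(pruned_head(hidden_states).shape)  # torch.Size([1, 1, 3]); logit i now corresponds to token keep_ids[i]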


class MapInputEmbedding(nn.Module):
def __init__(self, original_embedding: nn.Embedding, assistant_overlap_token_ids):
"""
Wraps an existing embedding layer and remaps token IDs before lookup.

Args:
original_embedding (nn.Embedding): Pre-trained or existing embedding layer.
assistant_overlap_token_ids (torch.LongTensor): 1-D tensor mapping each pruned-vocabulary index back to its
original assistant token ID (entry `i` is the assistant ID of the `i`-th kept token).
"""
super().__init__()
self.original_embedding = original_embedding
self.weight = original_embedding.weight
self.assistant_overlap_token_ids = assistant_overlap_token_ids
self.map = False

def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
"""
Args:
input_ids (torch.LongTensor): Tensor of token IDs (batch_size, seq_len).

Returns:
torch.FloatTensor: Corresponding input embeddings.
"""
if self.map:
# Map the last generated token (an index into the pruned vocabulary) back to its original assistant token ID
my_input_ids = self.assistant_overlap_token_ids[input_ids[0, -1]].unsqueeze(0).unsqueeze(0)
else:
self.map = True
my_input_ids = input_ids

return self.original_embedding(my_input_ids)
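
A minimal sketch (illustrative only, with made-up IDs) of why this remapping is needed: once the LM head is pruned, the assistant samples indices into the pruned vocabulary, and those indices must be translated back to real assistant token IDs before the next embedding lookup.

# Illustrative sketch only, not part of the PR.
import torch

assistant_overlap_token_ids = torch.tensor([3, 17, 42])  # pruned index -> original assistant token ID
sampled_ids = torch.tensor([[2]])                        # token index sampled from the pruned logits
original_ids = assistant_overlap_token_ids[sampled_ids]  # tensor([[42]]): the real assistant token ID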


class AssistantToTargetTranslator:
"""
Translates token ids and logits between assistant and target model vocabularies. This class is used to handle
@@ -638,23 +697,46 @@ def __init__(
self,
target_tokenizer: "PreTrainedTokenizerBase",
assistant_tokenizer: "PreTrainedTokenizerBase",
target_vocab_size: int, # required since target_vocab_size can be different from the length of target_tokenizer.get_vocab()
assistant_model_device: str = "cpu",
assistant_model: "PreTrainedModel",
target_vocab_size: Optional[int],
assistant_prune_LM_head: Optional[bool] = True,
):
self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer
self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer
self._assistant_model_device: str = assistant_model_device
self._assistant_model_device: str = assistant_model.device
self.target_vocab_size: int = target_vocab_size
self._assistant_to_target_input_ids, self.target_to_assistant_input_ids = (
self._get_assistant_to_target_input_ids()
)
self._suppress_input_ids: list[int] = self._get_suppress_input_ids()
self.logits_processors: Optional[LogitsProcessorList] = None
self.assistant_prune_LM_head = assistant_prune_LM_head
Contributor review comment (suggested change: snake case :)):
-    self.assistant_prune_LM_head = assistant_prune_LM_head
+    self.assistant_prune_lm_head = assistant_prune_lm_head

if len(self._suppress_input_ids) > 0:
# len(self._suppress_input_ids) = 0 if the assistant vocab is a subset of the target vocab
self.logits_processors = LogitsProcessorList(
[SuppressTokensLogitsProcessor(self._get_suppress_input_ids(), self._assistant_model_device)]
)
# the assistant vocab is not a subset of the target vocab
if self.assistant_prune_LM_head:
self.assistant_overlap_token_ids = torch.tensor(
list(self.target_to_assistant_input_ids.values()),
dtype=torch.long,
device=self._assistant_model_device,
)
original_lm_head = assistant_model.get_output_embeddings()
pruned_lm_head = PruneReindexingLMHead(original_lm_head, self.assistant_overlap_token_ids)
del original_lm_head
assistant_model.set_output_embeddings(pruned_lm_head)

original_input_embeddings = assistant_model.get_input_embeddings()
map_input_embeddings = MapInputEmbedding(original_input_embeddings, self.assistant_overlap_token_ids)
del original_input_embeddings
assistant_model.set_input_embeddings(map_input_embeddings)
self.map_input_embeddings = map_input_embeddings
else:
self.logits_processors = LogitsProcessorList(
[SuppressTokensLogitsProcessor(self._get_suppress_input_ids(), self._assistant_model_device)]
)

def set_unmap(self):
Contributor review comment: missing docs for this public method; it should state what it does and why it is needed.
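"""
Disable the input-ID remapping performed by `MapInputEmbedding` for the next forward pass.

Called when a fresh assistant prompt is prepared: the prompt already contains original assistant
token IDs, so it must not be remapped. Remapping is re-enabled automatically after that first pass.
"""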

if self.assistant_prune_LM_head:
self.map_input_embeddings.map = False

def _get_assistant_to_target_input_ids(self):
target_vocab = self._target_tokenizer.get_vocab()
@@ -710,7 +792,12 @@ def get_target_ids(
if num_new_tokens == 0:
return target_input_ids
else:
transformed_slice = self._assistant_to_target_input_ids[assistant_candidate_ids[0, -num_new_tokens:]]
# Get last `num_new_tokens` candidate IDs
last_candidate_ids = assistant_candidate_ids[0, -num_new_tokens:]
if self.assistant_prune_LM_head:
# Map assistant IDs -> target input IDs
last_candidate_ids = self.assistant_overlap_token_ids[last_candidate_ids]
transformed_slice = self._assistant_to_target_input_ids[last_candidate_ids]
return torch.cat((target_input_ids, transformed_slice.unsqueeze(0)), dim=1)

def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatTensor:
@@ -724,10 +811,12 @@ def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatTensor:
assistant_indices_mask = self._assistant_to_target_input_ids != self.SUPPRESS_TOKEN_ID
# Exclude invalid indices
target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_indices_mask]
valid_assistant_logits = assistant_logits[..., : self._assistant_to_target_input_ids.shape[0]]

target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask]

if self.assistant_prune_LM_head:
target_logits[..., target_logits_supported_indices] = assistant_logits
else:
valid_assistant_logits = assistant_logits[..., : self._assistant_to_target_input_ids.shape[0]]
target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask]
return target_logits
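
A minimal sketch (illustrative only, with made-up sizes; the fill value for unsupported tokens is assumed here to be -inf) of the pruned branch: since the pruned head already emits exactly one logit per shared token, its logits can be scattered directly into the supported target positions.

# Illustrative sketch only, not part of the PR.
import torch

target_vocab_size = 10
target_logits_supported_indices = torch.tensor([1, 4, 7])   # target IDs of the shared tokens
assistant_logits = torch.tensor([[0.3, 1.2, -0.5]])         # pruned head: one logit per shared token

target_logits = torch.full((1, target_vocab_size), float("-inf"))
target_logits[..., target_logits_supported_indices] = assistant_logits
# target_logits: [-inf, 0.3, -inf, -inf, 1.2, -inf, -inf, -0.5, -inf, -inf]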


@@ -744,8 +833,9 @@ def get_translator(
cls,
target_tokenizer: "PreTrainedTokenizerBase",
assistant_tokenizer: "PreTrainedTokenizerBase",
target_vocab_size: int,
assistant_model_device: str = "cpu",
assistant_model: "PreTrainedModel",
target_vocab_size: Optional[int] = None,
assistant_prune_LM_head=True,
) -> AssistantToTargetTranslator:
assistant_dict = cls._cache.get(target_tokenizer)
if assistant_dict is None:
Expand All @@ -755,7 +845,7 @@ def get_translator(
mapping = assistant_dict.get(assistant_tokenizer)
if mapping is None:
mapping = AssistantToTargetTranslator(
target_tokenizer, assistant_tokenizer, target_vocab_size, assistant_model_device
target_tokenizer, assistant_tokenizer, assistant_model, target_vocab_size, assistant_prune_LM_head
)
assistant_dict[assistant_tokenizer] = mapping

@@ -892,7 +982,7 @@ def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> to
self._prev_assistant_ids = self._prev_assistant_ids[:, :-tokens_to_remove]
assistant_input_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1)
assistant_input_ids = assistant_input_ids.to(dtype=torch.long)

self._atm_translator.set_unmap()
return assistant_input_ids, len(assistant_new_ids[0])


8 changes: 7 additions & 1 deletion src/transformers/generation/utils.py
@@ -877,8 +877,14 @@ def _get_candidate_generator(
elif different_tokenizers:
if generation_config.do_sample is True:
atm_translator = AssistantVocabTranslatorCache.get_translator(
target_tokenizer, assistant_tokenizer, self.config.vocab_size, assistant_model.device
target_tokenizer,
assistant_tokenizer,
assistant_model,
self.config.vocab_size,
assistant_prune_LM_head=True, # prune LM head of assistant model
)
# Since we prune the LM head, we cannot use the repetition penalty on the assistant model due to mismatches between token IDs and logit indices
assistant_model.generation_config.repetition_penalty = None
candidate_generator = UniversalSpeculativeDecodingGenerator(
input_ids=input_ids,
assistant_model=assistant_model,
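
For context, a minimal end-to-end usage sketch of the path touched here: universal assisted generation with different tokenizers and do_sample=True, which now runs the assistant model with a pruned LM head. The checkpoints below are arbitrary examples, not ones prescribed by this PR.

# Illustrative sketch only, not part of the PR.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

target_ckpt = "meta-llama/Llama-3.1-8B-Instruct"   # example target model
assistant_ckpt = "Qwen/Qwen2.5-0.5B-Instruct"      # example assistant model with a different tokenizer

target_tokenizer = AutoTokenizer.from_pretrained(target_ckpt)
assistant_tokenizer = AutoTokenizer.from_pretrained(assistant_ckpt)
target_model = AutoModelForCausalLM.from_pretrained(target_ckpt, torch_dtype=torch.bfloat16, device_map="auto")
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_ckpt, torch_dtype=torch.bfloat16, device_map="auto")

inputs = target_tokenizer("Alice and Bob", return_tensors="pt").to(target_model.device)
outputs = target_model.generate(
    **inputs,
    assistant_model=assistant_model,
    tokenizer=target_tokenizer,
    assistant_tokenizer=assistant_tokenizer,  # different tokenizers trigger universal assisted decoding
    do_sample=True,
    max_new_tokens=32,
)
print(target_tokenizer.decode(outputs[0], skip_special_tokens=True))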