From a1b8d416cb86a7e98d3f6da578494b114b94e21f Mon Sep 17 00:00:00 2001 From: jmamou Date: Thu, 13 Mar 2025 04:16:58 -0700 Subject: [PATCH 01/16] initial commit --- .../generation/candidate_generator.py | 125 +++++++++++++++--- .../generation/configuration_utils.py | 4 + src/transformers/generation/utils.py | 3 +- 3 files changed, 116 insertions(+), 16 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index ebb9a18d559f..fa85074740ec 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -19,7 +19,9 @@ import numpy as np import torch +import torch.nn as nn +from ..pytorch_utils import prune_linear_layer from ..utils import is_sklearn_available @@ -611,6 +613,65 @@ def _process_assistant_outputs( return new_target_ids +class PruneReindexingLMHead(nn.Module): + """ + A class to prune and reindex the language model head. + + This class prunes the language model head to only include the specified token IDs and reindexes the logits + to map back to the original vocabulary. + + Args: + original_lm_head (nn.Module): The original language model head. + token_ids (list[int]): The list of token IDs to keep. + filter_value (float, optional): The value to use for filtering out pruned logits. Defaults to -float("Inf"). + """ + def __init__(self, original_lm_head, token_ids, filter_value: float = -float("Inf")): + super().__init__() + self.token_ids = token_ids + self.filter_value = filter_value + self.original_vocab_size = original_lm_head.out_features + self.pruned_lm_head = prune_linear_layer(original_lm_head, self.token_ids).to(original_lm_head.weight.dtype) + #print(f'{original_lm_head=}') + #print(f'{self.pruned_lm_head=}') + + def forward(self, hidden_states): + pruned_logits = self.pruned_lm_head(hidden_states) + #print(f'{torch.argmax(pruned_logits)=}') + return pruned_logits + +class MapInputEmbedding(nn.Module): + def __init__(self, original_embedding: nn.Embedding, token_ids): + """ + Wraps an existing embedding layer and remaps token IDs before lookup. + + Args: + original_embedding (nn.Embedding): Pre-trained or existing embedding layer. + id_map (dict): Mapping from original token IDs to new token IDs. + Example: {old_id: new_id} + """ + super().__init__() + self.original_embedding = original_embedding + self.token_ids = token_ids + self.first = True + + def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor: + """ + Args: + input_ids (torch.LongTensor): Tensor of token IDs (batch_size, seq_len). + + Returns: + torch.FloatTensor: Corresponding input embeddings. 
+ """ + #print(f'A {input_ids.squeeze(0).tolist()=}') + #print(f'{self.first}') + if self.first: #input_ids.shape[-1] > 1: + self.first = False + else: + # Get the last item from input_ids + input_ids[0, -1] = self.token_ids[input_ids[0, -1]] + + #print(f'B {input_ids.squeeze(0).tolist()=}') + return self.original_embedding(input_ids) class AssistantToTargetTranslator: """ @@ -638,23 +699,49 @@ def __init__( self, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase", - target_vocab_size: int, # required since target_vocab_size can be different from the length of target_tokenizer.get_vocab() - assistant_model_device: str = "cpu", + assistant_model: "PreTrainedModel", + target_vocab_size: Optional[int], + assistant_prune_LM_head: Optional[bool] = False ): self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer - self._assistant_model_device: str = assistant_model_device + self._assistant_model_device: str = assistant_model.device self.target_vocab_size: int = target_vocab_size self._assistant_to_target_input_ids, self.target_to_assistant_input_ids = ( self._get_assistant_to_target_input_ids() ) self._suppress_input_ids: list[int] = self._get_suppress_input_ids() self.logits_processors: Optional[LogitsProcessorList] = None + self.assistant_prune_LM_head = assistant_prune_LM_head if len(self._suppress_input_ids) > 0: - # len(self._suppress_input_ids) = 0 if the assistant vocab is a subset of the target vocab - self.logits_processors = LogitsProcessorList( - [SuppressTokensLogitsProcessor(self._get_suppress_input_ids(), self._assistant_model_device)] - ) + # the assistant vocab is not a subset of the target vocab + if assistant_prune_LM_head: + + self.assistant_overlap_token_ids = torch.tensor( + list(self.target_to_assistant_input_ids.values()), dtype=torch.long, device=self._assistant_model_device + ) + original_lm_head = assistant_model.get_output_embeddings() + pruned_lm_head = PruneReindexingLMHead( + original_lm_head, self.assistant_overlap_token_ids, self.FILTER_VALUE + ) + del original_lm_head + assistant_model.set_output_embeddings(pruned_lm_head) + + originial_input_embeddings = assistant_model.get_input_embeddings() + map_input_embeddings = MapInputEmbedding( + originial_input_embeddings, self.assistant_overlap_token_ids + ) + assistant_model.set_input_embeddings(map_input_embeddings) + self._assistant_model = assistant_model + del originial_input_embeddings + else: + self.logits_processors = LogitsProcessorList( + [SuppressTokensLogitsProcessor(self._get_suppress_input_ids(), self._assistant_model_device)] + ) + + def set_first(self): + if self.assistant_prune_LM_head: + self._assistant_model.get_input_embeddings().first = True def _get_assistant_to_target_input_ids(self): target_vocab = self._target_tokenizer.get_vocab() @@ -710,7 +797,12 @@ def get_target_ids( if num_new_tokens == 0: return target_input_ids else: - transformed_slice = self._assistant_to_target_input_ids[assistant_candidate_ids[0, -num_new_tokens:]] + # Get last `num_new_tokens` candidate IDs + last_candidate_ids = assistant_candidate_ids[0, -num_new_tokens:] + if self.assistant_prune_LM_head: + # Map assistant IDs -> target input IDs + last_candidate_ids = self.assistant_overlap_token_ids[last_candidate_ids] + transformed_slice = self._assistant_to_target_input_ids[last_candidate_ids] return torch.cat((target_input_ids, transformed_slice.unsqueeze(0)), dim=1) def get_target_logits(self, 
assistant_logits: torch.FloatTensor) -> torch.FloatTensor: @@ -724,10 +816,12 @@ def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatT assistant_indices_mask = self._assistant_to_target_input_ids != self.SUPPRESS_TOKEN_ID # Exclude invalid indices target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_indices_mask] - valid_assistant_logits = assistant_logits[..., : self._assistant_to_target_input_ids.shape[0]] - - target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask] + if self.assistant_prune_LM_head: + target_logits[..., target_logits_supported_indices] = assistant_logits + else: + valid_assistant_logits = assistant_logits[..., : self._assistant_to_target_input_ids.shape[0]] + target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask] return target_logits @@ -744,8 +838,9 @@ def get_translator( cls, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase", - target_vocab_size: int, - assistant_model_device: str = "cpu", + assistant_model: "PreTrainedModel", + target_vocab_size: Optional[int] = None, + assistant_prune_LM_head = False ) -> AssistantToTargetTranslator: assistant_dict = cls._cache.get(target_tokenizer) if assistant_dict is None: @@ -755,7 +850,7 @@ def get_translator( mapping = assistant_dict.get(assistant_tokenizer) if mapping is None: mapping = AssistantToTargetTranslator( - target_tokenizer, assistant_tokenizer, target_vocab_size, assistant_model_device + target_tokenizer, assistant_tokenizer, assistant_model, target_vocab_size, assistant_prune_LM_head ) assistant_dict[assistant_tokenizer] = mapping @@ -892,7 +987,7 @@ def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> to self._prev_assistant_ids = self._prev_assistant_ids[:, :-tokens_to_remove] assistant_input_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1) assistant_input_ids = assistant_input_ids.to(dtype=torch.long) - + self._atm_translator.set_first() return assistant_input_ids, len(assistant_new_ids[0]) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 6ee48ab3f1a5..4e813f6bc3ab 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -372,6 +372,8 @@ class GenerationConfig(PushToHubMixin): If set to a positive integer, the re-encodeing process will additionally consider the last `target_lookbehind` target tokens to correctly align tokens. Can only be used with different tokenizers in speculative decoding. See this [blog](https://huggingface.co/blog/universal_assisted_generation) for more details. + assistant_prune_LM_head(`bool`, *optional*, defaults to `True`): + If set to `True`, LM head of the assistant model will be pruned. Can only be used with different tokenizers in speculative decoding with `do_sample=True`. 
> Parameters related to performances and compilation @@ -482,6 +484,8 @@ def __init__(self, **kwargs): ## assistant generation for different tokenizers, the windows size for assistant/target model self.assistant_lookbehind = kwargs.pop("assistant_lookbehind", 10) self.target_lookbehind = kwargs.pop("target_lookbehind", 10) + ## assistant generation for different tokenizers, pruning of the LM head of the assistant model + self.assistant_prune_LM_head = kwargs.pop("assistant_prune_LM_head", True) # Performance self.compile_config = kwargs.pop("compile_config", CompileConfig()) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 8d5be7d7a05e..a43131aa79c6 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -863,7 +863,8 @@ def _get_candidate_generator( elif different_tokenizers: if generation_config.do_sample is True: atm_translator = AssistantVocabTranslatorCache.get_translator( - target_tokenizer, assistant_tokenizer, self.config.vocab_size, assistant_model.device + target_tokenizer, assistant_tokenizer, assistant_model, + self.config.vocab_size, assistant_prune_LM_head=generation_config.assistant_prune_LM_head ) candidate_generator = UniversalSpeculativeDecodingGenerator( input_ids=input_ids, From 67863567b7d617a91b3066a40eb47a8b8610e661 Mon Sep 17 00:00:00 2001 From: jmamou Date: Mon, 24 Mar 2025 02:43:46 -0700 Subject: [PATCH 02/16] fix --- .../generation/candidate_generator.py | 47 +++++++++---------- src/transformers/generation/utils.py | 1 + 2 files changed, 22 insertions(+), 26 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index fa85074740ec..c4a0447e0ced 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -623,24 +623,17 @@ class PruneReindexingLMHead(nn.Module): Args: original_lm_head (nn.Module): The original language model head. token_ids (list[int]): The list of token IDs to keep. - filter_value (float, optional): The value to use for filtering out pruned logits. Defaults to -float("Inf"). """ - def __init__(self, original_lm_head, token_ids, filter_value: float = -float("Inf")): + def __init__(self, original_lm_head, assistant_overlap_token_ids): super().__init__() - self.token_ids = token_ids - self.filter_value = filter_value - self.original_vocab_size = original_lm_head.out_features - self.pruned_lm_head = prune_linear_layer(original_lm_head, self.token_ids).to(original_lm_head.weight.dtype) - #print(f'{original_lm_head=}') - #print(f'{self.pruned_lm_head=}') + self.pruned_lm_head = prune_linear_layer(original_lm_head, assistant_overlap_token_ids).to(original_lm_head.weight.dtype) def forward(self, hidden_states): pruned_logits = self.pruned_lm_head(hidden_states) - #print(f'{torch.argmax(pruned_logits)=}') return pruned_logits class MapInputEmbedding(nn.Module): - def __init__(self, original_embedding: nn.Embedding, token_ids): + def __init__(self, original_embedding: nn.Embedding, assistant_overlap_token_ids): """ Wraps an existing embedding layer and remaps token IDs before lookup. 
@@ -651,8 +644,9 @@ def __init__(self, original_embedding: nn.Embedding, token_ids): """ super().__init__() self.original_embedding = original_embedding - self.token_ids = token_ids - self.first = True + self.weight = original_embedding.weight + self.assistant_overlap_token_ids = assistant_overlap_token_ids + self.map = False def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor: """ @@ -663,15 +657,16 @@ def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor: torch.FloatTensor: Corresponding input embeddings. """ #print(f'A {input_ids.squeeze(0).tolist()=}') - #print(f'{self.first}') - if self.first: #input_ids.shape[-1] > 1: - self.first = False + #print(f'{self.map}') + if self.map: + # Get the last item from input_ids + my_input_ids = self.assistant_overlap_token_ids[input_ids[0, -1]].unsqueeze(0).unsqueeze(0) else: - # Get the last item from input_ids - input_ids[0, -1] = self.token_ids[input_ids[0, -1]] + self.map = True + my_input_ids = input_ids #print(f'B {input_ids.squeeze(0).tolist()=}') - return self.original_embedding(input_ids) + return self.original_embedding(my_input_ids) class AssistantToTargetTranslator: """ @@ -715,14 +710,14 @@ def __init__( self.assistant_prune_LM_head = assistant_prune_LM_head if len(self._suppress_input_ids) > 0: # the assistant vocab is not a subset of the target vocab - if assistant_prune_LM_head: + if self.assistant_prune_LM_head: self.assistant_overlap_token_ids = torch.tensor( list(self.target_to_assistant_input_ids.values()), dtype=torch.long, device=self._assistant_model_device ) original_lm_head = assistant_model.get_output_embeddings() pruned_lm_head = PruneReindexingLMHead( - original_lm_head, self.assistant_overlap_token_ids, self.FILTER_VALUE + original_lm_head, self.assistant_overlap_token_ids ) del original_lm_head assistant_model.set_output_embeddings(pruned_lm_head) @@ -731,17 +726,17 @@ def __init__( map_input_embeddings = MapInputEmbedding( originial_input_embeddings, self.assistant_overlap_token_ids ) - assistant_model.set_input_embeddings(map_input_embeddings) - self._assistant_model = assistant_model del originial_input_embeddings + assistant_model.set_input_embeddings(map_input_embeddings) + self.map_input_embeddings = map_input_embeddings else: self.logits_processors = LogitsProcessorList( [SuppressTokensLogitsProcessor(self._get_suppress_input_ids(), self._assistant_model_device)] ) - def set_first(self): + def set_unmap(self): if self.assistant_prune_LM_head: - self._assistant_model.get_input_embeddings().first = True + self.map_input_embeddings.map = False def _get_assistant_to_target_input_ids(self): target_vocab = self._target_tokenizer.get_vocab() @@ -973,7 +968,7 @@ def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> to ) assistant_new_ids = self.assistant_tokenizer( target_new_text, add_special_tokens=False, return_tensors="pt" - )["input_ids"].to(self.assistant_model.device) + )["input_ids"].to(self.assistant_model.device) else: assistant_new_ids = torch.tensor([[assistant_new_ids]], device=self.assistant_model.device) @@ -987,7 +982,7 @@ def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> to self._prev_assistant_ids = self._prev_assistant_ids[:, :-tokens_to_remove] assistant_input_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1) assistant_input_ids = assistant_input_ids.to(dtype=torch.long) - self._atm_translator.set_first() + self._atm_translator.set_unmap() return assistant_input_ids, len(assistant_new_ids[0]) diff 
--git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 716dc894c6bd..0c15ad8283b7 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -866,6 +866,7 @@ def _get_candidate_generator( target_tokenizer, assistant_tokenizer, assistant_model, self.config.vocab_size, assistant_prune_LM_head=generation_config.assistant_prune_LM_head ) + assistant_model.generation_config.repetition_penalty = None candidate_generator = UniversalSpeculativeDecodingGenerator( input_ids=input_ids, assistant_model=assistant_model, From 88ee8248be71d256175a87f4026746aa68c081e1 Mon Sep 17 00:00:00 2001 From: jmamou Date: Mon, 24 Mar 2025 02:48:08 -0700 Subject: [PATCH 03/16] fix style --- .../generation/candidate_generator.py | 36 +++++++++---------- src/transformers/generation/utils.py | 7 ++-- 2 files changed, 23 insertions(+), 20 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index c4a0447e0ced..7804399806f7 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -613,6 +613,7 @@ def _process_assistant_outputs( return new_target_ids + class PruneReindexingLMHead(nn.Module): """ A class to prune and reindex the language model head. @@ -624,19 +625,23 @@ class PruneReindexingLMHead(nn.Module): original_lm_head (nn.Module): The original language model head. token_ids (list[int]): The list of token IDs to keep. """ + def __init__(self, original_lm_head, assistant_overlap_token_ids): super().__init__() - self.pruned_lm_head = prune_linear_layer(original_lm_head, assistant_overlap_token_ids).to(original_lm_head.weight.dtype) + self.pruned_lm_head = prune_linear_layer(original_lm_head, assistant_overlap_token_ids).to( + original_lm_head.weight.dtype + ) def forward(self, hidden_states): pruned_logits = self.pruned_lm_head(hidden_states) return pruned_logits + class MapInputEmbedding(nn.Module): def __init__(self, original_embedding: nn.Embedding, assistant_overlap_token_ids): """ Wraps an existing embedding layer and remaps token IDs before lookup. - + Args: original_embedding (nn.Embedding): Pre-trained or existing embedding layer. id_map (dict): Mapping from original token IDs to new token IDs. @@ -652,22 +657,20 @@ def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor: """ Args: input_ids (torch.LongTensor): Tensor of token IDs (batch_size, seq_len). - + Returns: torch.FloatTensor: Corresponding input embeddings. """ - #print(f'A {input_ids.squeeze(0).tolist()=}') - #print(f'{self.map}') if self.map: - # Get the last item from input_ids + # Get the last item from input_ids my_input_ids = self.assistant_overlap_token_ids[input_ids[0, -1]].unsqueeze(0).unsqueeze(0) else: self.map = True my_input_ids = input_ids - #print(f'B {input_ids.squeeze(0).tolist()=}') return self.original_embedding(my_input_ids) + class AssistantToTargetTranslator: """ Translates token ids and logits between assistant and target model vocabularies. 
This class is used to handle @@ -696,7 +699,7 @@ def __init__( assistant_tokenizer: "PreTrainedTokenizerBase", assistant_model: "PreTrainedModel", target_vocab_size: Optional[int], - assistant_prune_LM_head: Optional[bool] = False + assistant_prune_LM_head: Optional[bool] = False, ): self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer @@ -711,21 +714,18 @@ def __init__( if len(self._suppress_input_ids) > 0: # the assistant vocab is not a subset of the target vocab if self.assistant_prune_LM_head: - self.assistant_overlap_token_ids = torch.tensor( - list(self.target_to_assistant_input_ids.values()), dtype=torch.long, device=self._assistant_model_device + list(self.target_to_assistant_input_ids.values()), + dtype=torch.long, + device=self._assistant_model_device, ) original_lm_head = assistant_model.get_output_embeddings() - pruned_lm_head = PruneReindexingLMHead( - original_lm_head, self.assistant_overlap_token_ids - ) + pruned_lm_head = PruneReindexingLMHead(original_lm_head, self.assistant_overlap_token_ids) del original_lm_head assistant_model.set_output_embeddings(pruned_lm_head) originial_input_embeddings = assistant_model.get_input_embeddings() - map_input_embeddings = MapInputEmbedding( - originial_input_embeddings, self.assistant_overlap_token_ids - ) + map_input_embeddings = MapInputEmbedding(originial_input_embeddings, self.assistant_overlap_token_ids) del originial_input_embeddings assistant_model.set_input_embeddings(map_input_embeddings) self.map_input_embeddings = map_input_embeddings @@ -835,7 +835,7 @@ def get_translator( assistant_tokenizer: "PreTrainedTokenizerBase", assistant_model: "PreTrainedModel", target_vocab_size: Optional[int] = None, - assistant_prune_LM_head = False + assistant_prune_LM_head=False, ) -> AssistantToTargetTranslator: assistant_dict = cls._cache.get(target_tokenizer) if assistant_dict is None: @@ -968,7 +968,7 @@ def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> to ) assistant_new_ids = self.assistant_tokenizer( target_new_text, add_special_tokens=False, return_tensors="pt" - )["input_ids"].to(self.assistant_model.device) + )["input_ids"].to(self.assistant_model.device) else: assistant_new_ids = torch.tensor([[assistant_new_ids]], device=self.assistant_model.device) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index e6a69f9cba8c..2e5aac83b1ca 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -877,8 +877,11 @@ def _get_candidate_generator( elif different_tokenizers: if generation_config.do_sample is True: atm_translator = AssistantVocabTranslatorCache.get_translator( - target_tokenizer, assistant_tokenizer, assistant_model, - self.config.vocab_size, assistant_prune_LM_head=generation_config.assistant_prune_LM_head + target_tokenizer, + assistant_tokenizer, + assistant_model, + self.config.vocab_size, + assistant_prune_LM_head=generation_config.assistant_prune_LM_head, ) assistant_model.generation_config.repetition_penalty = None candidate_generator = UniversalSpeculativeDecodingGenerator( From 07ead6ea36ec58ad670f287426f1d1fdce87c24d Mon Sep 17 00:00:00 2001 From: jmamou Date: Tue, 25 Mar 2025 01:38:01 -0700 Subject: [PATCH 04/16] set default to prune --- src/transformers/generation/candidate_generator.py | 10 +++++----- src/transformers/generation/utils.py | 4 +++- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git 
a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 7804399806f7..fe879fef7490 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -699,7 +699,7 @@ def __init__( assistant_tokenizer: "PreTrainedTokenizerBase", assistant_model: "PreTrainedModel", target_vocab_size: Optional[int], - assistant_prune_LM_head: Optional[bool] = False, + assistant_prune_LM_head: Optional[bool] = True, ): self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer @@ -724,9 +724,9 @@ def __init__( del original_lm_head assistant_model.set_output_embeddings(pruned_lm_head) - originial_input_embeddings = assistant_model.get_input_embeddings() - map_input_embeddings = MapInputEmbedding(originial_input_embeddings, self.assistant_overlap_token_ids) - del originial_input_embeddings + original_input_embeddings = assistant_model.get_input_embeddings() + map_input_embeddings = MapInputEmbedding(original_input_embeddings, self.assistant_overlap_token_ids) + del original_input_embeddings assistant_model.set_input_embeddings(map_input_embeddings) self.map_input_embeddings = map_input_embeddings else: @@ -835,7 +835,7 @@ def get_translator( assistant_tokenizer: "PreTrainedTokenizerBase", assistant_model: "PreTrainedModel", target_vocab_size: Optional[int] = None, - assistant_prune_LM_head=False, + assistant_prune_LM_head=True, ) -> AssistantToTargetTranslator: assistant_dict = cls._cache.get(target_tokenizer) if assistant_dict is None: diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 2e5aac83b1ca..d15709e8d45e 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -883,7 +883,9 @@ def _get_candidate_generator( self.config.vocab_size, assistant_prune_LM_head=generation_config.assistant_prune_LM_head, ) - assistant_model.generation_config.repetition_penalty = None + if generation_config.assistant_prune_LM_head: + # If we prune the LM head, we cannot use the repetition penalty on the assistant model due to mismaches between token ids and logits index + assistant_model.generation_config.repetition_penalty = None candidate_generator = UniversalSpeculativeDecodingGenerator( input_ids=input_ids, assistant_model=assistant_model, From b06d6980adbd0dc4fb27a992f7b3cb12cdaa4bb1 Mon Sep 17 00:00:00 2001 From: jmamou Date: Tue, 25 Mar 2025 02:28:45 -0700 Subject: [PATCH 05/16] add tests --- tests/generation/test_candidate_generator.py | 133 +++++++++++-------- 1 file changed, 81 insertions(+), 52 deletions(-) diff --git a/tests/generation/test_candidate_generator.py b/tests/generation/test_candidate_generator.py index 38df48ab08d2..67bad4118740 100644 --- a/tests/generation/test_candidate_generator.py +++ b/tests/generation/test_candidate_generator.py @@ -20,6 +20,7 @@ def setUp(self): # Create mock tokenizers with predefined vocabularies self.target_tokenizer = MagicMock() self.assistant_tokenizer = MagicMock() + self.assistant_model = MagicMock(device=torch_device) # Define mock vocabularies for the tokenizers self.target_vocab = {"hello": 0, "world": 1, "foo": 2, "bar": 3} @@ -27,15 +28,15 @@ def setUp(self): self.target_tokenizer.get_vocab.return_value = self.target_vocab self.assistant_tokenizer.get_vocab.return_value = self.assistant_vocab - self.assistant_model_device = torch_device self.target_vocab_size = 6 # Instantiate the class 
under test self.translator = AssistantToTargetTranslator( target_tokenizer=self.target_tokenizer, assistant_tokenizer=self.assistant_tokenizer, - assistant_model_device=self.assistant_model_device, + assistant_model=self.assistant_model, target_vocab_size=self.target_vocab_size, + assistant_prune_LM_head=False, ) def test_get_assistant_to_target_input_ids(self): @@ -53,19 +54,19 @@ def test_get_suppress_input_ids(self): def test_get_target_ids(self): """Test the translation of assistant candidate IDs to target candidate IDs.""" assistant_input_ids = torch.LongTensor([[0, 1, 2]]).to( - self.assistant_model_device + self.assistant_model.device ) # 'hello world foo' in assistant tokenizer target_input_ids = torch.LongTensor([[0, 1, 2]]).to( - self.assistant_model_device + self.assistant_model.device ) # 'hello world foo' in target tokenizer assistant_candidate_ids = torch.LongTensor([[0, 1, 2, 4]]).to( - self.assistant_model_device + self.assistant_model.device ) # 'hello world foo baz' in assistant tokenizer expected_target_ids = torch.LongTensor( [[0, 1, 2, self.translator.SUPPRESS_TOKEN_ID]] ).to( - self.assistant_model_device + self.assistant_model.device ) # 'hello world foo baz' in target tokenizer (baz is mapped to self.translator.suppress_tokens_id since it does not exist in target vocab) actual_target_ids = self.translator.get_target_ids( @@ -77,12 +78,12 @@ def test_get_target_logits(self): """Test the conversion of assistant logits to target logits.""" # Assistant logits for IDs 0, 1, 2 assistant_logits = torch.FloatTensor([[[0.1, 0.2, 0.3, 0.4, self.translator.FILTER_VALUE]]]).to( - self.assistant_model_device + self.assistant_model.device ) # Shape (1, 1, 5) # Expected target logits (target_vocab_size = 4) expected_target_logits = torch.full((1, 1, self.target_vocab_size), self.translator.FILTER_VALUE).to( - self.assistant_model_device + self.assistant_model.device ) expected_target_logits[0, 0, 0] = 0.1 # 'hello' expected_target_logits[0, 0, 1] = 0.2 # 'world' @@ -119,7 +120,8 @@ def setUp(self): self.assistant_tokenizer = MockTokenizer({"hello": 0, "world": 1, "foo": 2}) self.other_target_tokenizer = MockTokenizer({"foo": 2, "bar": 3}) self.other_assistant_tokenizer = MockTokenizer({"baz": 4, "qux": 5}) - self.assistant_model_device = torch_device + self.assistant_model = MagicMock(device=torch_device) + self.target_vocab_size = 6 def test_same_instance_for_same_tokenizers(self): @@ -127,14 +129,16 @@ def test_same_instance_for_same_tokenizers(self): translator1 = AssistantVocabTranslatorCache.get_translator( self.target_tokenizer, self.assistant_tokenizer, - assistant_model_device=self.assistant_model_device, + assistant_model=self.assistant_model, target_vocab_size=self.target_vocab_size, + assistant_prune_LM_head=False, ) translator2 = AssistantVocabTranslatorCache.get_translator( self.target_tokenizer, self.assistant_tokenizer, - assistant_model_device=self.assistant_model_device, + assistant_model=self.assistant_model, target_vocab_size=self.target_vocab_size, + assistant_prune_LM_head=False, ) self.assertIs(translator1, translator2, "Translators should be cached and identical") @@ -143,14 +147,16 @@ def test_different_instances_for_different_tokenizers(self): translator1 = AssistantVocabTranslatorCache.get_translator( self.target_tokenizer, self.assistant_tokenizer, - assistant_model_device=self.assistant_model_device, + assistant_model=self.assistant_model, target_vocab_size=self.target_vocab_size, + assistant_prune_LM_head=False, ) translator2 = 
AssistantVocabTranslatorCache.get_translator( self.other_target_tokenizer, self.other_assistant_tokenizer, - assistant_model_device=self.assistant_model_device, + assistant_model=self.assistant_model, target_vocab_size=self.target_vocab_size, + assistant_prune_LM_head=False, ) self.assertIsNot(translator1, translator2, "Translators should differ for different tokenizers") @@ -164,8 +170,9 @@ def test_cache_with_weakref_key(self): translator = AssistantVocabTranslatorCache.get_translator( target_tokenizer, assistant_tokenizer, - assistant_model_device=self.assistant_model_device, + assistant_model=self.assistant_model, target_vocab_size=self.target_vocab_size, + assistant_prune_LM_head=False, ) self.assertEqual(len(AssistantVocabTranslatorCache._cache), initial_cache_size + 1) @@ -192,8 +199,9 @@ def create_translator(): translator = AssistantVocabTranslatorCache.get_translator( target_tokenizer, assistant_tokenizer, - assistant_model_device=self.assistant_model_device, + assistant_model=self.assistant_model, target_vocab_size=self.target_vocab_size, + assistant_prune_LM_head=False, ) # Create weak references before returning refs = (weakref.ref(translator), weakref.ref(target_tokenizer), weakref.ref(assistant_tokenizer)) @@ -239,7 +247,7 @@ def setUp(self): self.target_tokenizer.bos_token_id = self.target_tokenizer.eos_token_id if self.assistant_tokenizer.pad_token_id is None: self.assistant_tokenizer.pad_token_id = self.assistant_tokenizer.eos_token_id - if self.target_tokenizer.bos_token_id is None: + if self.assistant_tokenizer.bos_token_id is None: self.assistant_tokenizer.bos_token_id = self.assistant_tokenizer.eos_token_id self.input_ids = torch.tensor([[1, 2, 3]]).to(torch_device) @@ -247,8 +255,13 @@ def setUp(self): "attention_mask": torch.ones_like(self.input_ids).to(torch_device), } + def setUpGenerator(self, assistant_prune_LM_head): atm_translator = AssistantVocabTranslatorCache.get_translator( - self.target_tokenizer, self.assistant_tokenizer, self.target_config.vocab_size, torch_device + target_tokenizer=self.target_tokenizer, + assistant_tokenizer=self.assistant_tokenizer, + assistant_model=self.assistant_model, + target_vocab_size=self.target_config.vocab_size, + assistant_prune_LM_head=assistant_prune_LM_head, ) self.generator = UniversalSpeculativeDecodingGenerator( input_ids=self.input_ids, @@ -262,58 +275,62 @@ def setUp(self): def test_basic_generation(self): """Test basic speculative decoding works""" - input_text = "The quick brown fox" - input_ids = self.target_tokenizer.encode(input_text, return_tensors="pt") - self.generator.input_ids = input_ids - candidates, scores = self.generator.get_candidates(input_ids) + for assistant_prune_LM_head in [False, True]: + self.setUpGenerator(assistant_prune_LM_head=assistant_prune_LM_head) + input_text = "The quick brown fox" + input_ids = self.target_tokenizer.encode(input_text, return_tensors="pt") + self.generator.input_ids = input_ids + candidates, scores = self.generator.get_candidates(input_ids) - self.assertIsNotNone(candidates) - self.assertIsNotNone(scores) - self.assertTrue(torch.is_tensor(candidates)) - self.assertTrue(torch.is_tensor(scores)) + self.assertIsNotNone(candidates) + self.assertIsNotNone(scores) + self.assertTrue(torch.is_tensor(candidates)) + self.assertTrue(torch.is_tensor(scores)) def test_mismatched_vocabularies(self): """Test handling of mismatched vocabularies between models""" # Create input with tokens present in main but not assistant vocab # Find a token that is not in the assistant tokenizer 
but in # the main tokenizer. - missing_token = next( - token - for token in self.target_tokenizer.get_vocab() - if token not in self.assistant_tokenizer.get_vocab() - and token not in self.target_tokenizer.all_special_tokens - and "reserved_" not in token - ) - input_ids = torch.tensor([[self.target_tokenizer.convert_tokens_to_ids(missing_token)]]) - self.generator.input_ids = input_ids - candidates, scores = self.generator.get_candidates(input_ids) - self.assertIsNotNone(candidates) + for assistant_prune_LM_head in [False, True]: + self.setUpGenerator(assistant_prune_LM_head=assistant_prune_LM_head) + missing_token = next( + token + for token in self.target_tokenizer.get_vocab() + if token not in self.assistant_tokenizer.get_vocab() + and token not in self.target_tokenizer.all_special_tokens + and "reserved_" not in token + ) + input_ids = torch.tensor([[self.target_tokenizer.convert_tokens_to_ids(missing_token)]]) + self.generator.input_ids = input_ids + candidates, _ = self.generator.get_candidates(input_ids) + self.assertIsNotNone(candidates) def test_speculation_depth(self): """Test different speculation depths""" - input_ids = self.target_tokenizer.encode("Test text", return_tensors="pt") - self.generator.input_ids = input_ids + for assistant_prune_LM_head in [False, True]: + self.setUpGenerator(assistant_prune_LM_head=assistant_prune_LM_head) + input_ids = self.target_tokenizer.encode("Test text", return_tensors="pt") + self.generator.input_ids = input_ids - for depth in [1, 8, 17]: - self.generator.num_assistant_tokens = depth - candidates, scores = self.generator.get_candidates(input_ids) - self.assertLessEqual(candidates.shape[1] - input_ids.shape[1], depth) + for depth in [1, 8, 17]: + self.generator.num_assistant_tokens = depth + candidates, _ = self.generator.get_candidates(input_ids) + self.assertLessEqual(candidates.shape[1] - input_ids.shape[1], depth) def test_device_consistency(self): """Test handling of inputs on different devices""" - input_ids = torch.tensor([[1, 2, 3]]).to(torch_device) - self.generator.input_ids = input_ids - candidates, _ = self.generator.get_candidates(input_ids) - self.assertEqual(candidates.device, input_ids.device) + for assistant_prune_LM_head in [False, True]: + self.setUpGenerator(assistant_prune_LM_head=assistant_prune_LM_head) + input_ids = torch.tensor([[1, 2, 3]]).to(torch_device) + self.generator.input_ids = input_ids + candidates, _ = self.generator.get_candidates(input_ids) + self.assertEqual(candidates.device, input_ids.device) def test_usd_vs_vanilla_sampling(cls): """Test that USD matches vanilla sampling with temperature set to nearly 0""" prompt = "Test text" - pipe_usd = pipeline("text-generation", model=cls.target_name, assistant_model=cls.assistant_name) - pipe_usd_output = pipe_usd(prompt, max_new_tokens=5, do_sample=True, temperature=1e-9) # Nearly 0 temperature - usd_text = pipe_usd_output[0]["generated_text"] - pipe_vanilla = pipeline( "text-generation", model=cls.target_name, @@ -321,5 +338,17 @@ def test_usd_vs_vanilla_sampling(cls): pipe_vanilla_output = pipe_vanilla(prompt, max_new_tokens=5, do_sample=False) vanilla_text = pipe_vanilla_output[0]["generated_text"] - # Assert that the outputs match - cls.assertEqual(usd_text, vanilla_text) + for assistant_prune_LM_head in [False, True]: + pipe_usd = pipeline( + "text-generation", + model=cls.target_name, + assistant_model=cls.assistant_name, + assistant_prune_LM_head=assistant_prune_LM_head, + ) + pipe_usd_output = pipe_usd( + prompt, max_new_tokens=5, do_sample=True, 
temperature=1e-9 + ) # Nearly 0 temperature + usd_text = pipe_usd_output[0]["generated_text"] + + # Assert that the outputs match + cls.assertEqual(usd_text, vanilla_text) From 763d284a5ef18a9b0388d8dfacea01607efbaa04 Mon Sep 17 00:00:00 2001 From: jmamou Date: Tue, 25 Mar 2025 03:31:13 -0700 Subject: [PATCH 06/16] comment --- src/transformers/generation/candidate_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index fe879fef7490..770192a2fb5a 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -644,7 +644,7 @@ def __init__(self, original_embedding: nn.Embedding, assistant_overlap_token_ids Args: original_embedding (nn.Embedding): Pre-trained or existing embedding layer. - id_map (dict): Mapping from original token IDs to new token IDs. + assistant_overlap_token_ids (dict): Mapping from original token IDs to new token IDs. Example: {old_id: new_id} """ super().__init__() From 34c2d0023dac042346462548a95c8d833d19d9a7 Mon Sep 17 00:00:00 2001 From: jmamou Date: Tue, 1 Apr 2025 15:02:37 +0300 Subject: [PATCH 07/16] remove prune flag from generate --- .../generation/configuration_utils.py | 4 - src/transformers/generation/utils.py | 7 +- tests/generation/test_candidate_generator.py | 93 ++++++++----------- 3 files changed, 42 insertions(+), 62 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 5dda2164147d..a6b0a72162fc 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -372,8 +372,6 @@ class GenerationConfig(PushToHubMixin): If set to a positive integer, the re-encodeing process will additionally consider the last `target_lookbehind` target tokens to correctly align tokens. Can only be used with different tokenizers in speculative decoding. See this [blog](https://huggingface.co/blog/universal_assisted_generation) for more details. - assistant_prune_LM_head(`bool`, *optional*, defaults to `True`): - If set to `True`, LM head of the assistant model will be pruned. Can only be used with different tokenizers in speculative decoding with `do_sample=True`. 
> Parameters related to performances and compilation @@ -483,8 +481,6 @@ def __init__(self, **kwargs): ## assistant generation for different tokenizers, the windows size for assistant/target model self.assistant_lookbehind = kwargs.pop("assistant_lookbehind", 10) self.target_lookbehind = kwargs.pop("target_lookbehind", 10) - ## assistant generation for different tokenizers, pruning of the LM head of the assistant model - self.assistant_prune_LM_head = kwargs.pop("assistant_prune_LM_head", True) # Performance self.compile_config = kwargs.pop("compile_config", CompileConfig()) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 9a78377f6114..274f2cad34a8 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -881,11 +881,10 @@ def _get_candidate_generator( assistant_tokenizer, assistant_model, self.config.vocab_size, - assistant_prune_LM_head=generation_config.assistant_prune_LM_head, + assistant_prune_LM_head=True, # prune LM head of assistant model ) - if generation_config.assistant_prune_LM_head: - # If we prune the LM head, we cannot use the repetition penalty on the assistant model due to mismaches between token ids and logits index - assistant_model.generation_config.repetition_penalty = None + # Since we prune the LM head, we cannot use the repetition penalty on the assistant model due to mismaches between token ids and logits index + assistant_model.generation_config.repetition_penalty = None candidate_generator = UniversalSpeculativeDecodingGenerator( input_ids=input_ids, assistant_model=assistant_model, diff --git a/tests/generation/test_candidate_generator.py b/tests/generation/test_candidate_generator.py index 67bad4118740..1c2839e9c5cd 100644 --- a/tests/generation/test_candidate_generator.py +++ b/tests/generation/test_candidate_generator.py @@ -254,14 +254,11 @@ def setUp(self): self.model_kwargs = { "attention_mask": torch.ones_like(self.input_ids).to(torch_device), } - - def setUpGenerator(self, assistant_prune_LM_head): atm_translator = AssistantVocabTranslatorCache.get_translator( target_tokenizer=self.target_tokenizer, assistant_tokenizer=self.assistant_tokenizer, assistant_model=self.assistant_model, target_vocab_size=self.target_config.vocab_size, - assistant_prune_LM_head=assistant_prune_LM_head, ) self.generator = UniversalSpeculativeDecodingGenerator( input_ids=self.input_ids, @@ -275,57 +272,49 @@ def setUpGenerator(self, assistant_prune_LM_head): def test_basic_generation(self): """Test basic speculative decoding works""" - for assistant_prune_LM_head in [False, True]: - self.setUpGenerator(assistant_prune_LM_head=assistant_prune_LM_head) - input_text = "The quick brown fox" - input_ids = self.target_tokenizer.encode(input_text, return_tensors="pt") - self.generator.input_ids = input_ids - candidates, scores = self.generator.get_candidates(input_ids) - - self.assertIsNotNone(candidates) - self.assertIsNotNone(scores) - self.assertTrue(torch.is_tensor(candidates)) - self.assertTrue(torch.is_tensor(scores)) + input_text = "The quick brown fox" + input_ids = self.target_tokenizer.encode(input_text, return_tensors="pt") + self.generator.input_ids = input_ids + candidates, scores = self.generator.get_candidates(input_ids) + + self.assertIsNotNone(candidates) + self.assertIsNotNone(scores) + self.assertTrue(torch.is_tensor(candidates)) + self.assertTrue(torch.is_tensor(scores)) def test_mismatched_vocabularies(self): """Test handling of mismatched vocabularies between models""" # Create 
input with tokens present in main but not assistant vocab # Find a token that is not in the assistant tokenizer but in # the main tokenizer. - for assistant_prune_LM_head in [False, True]: - self.setUpGenerator(assistant_prune_LM_head=assistant_prune_LM_head) - missing_token = next( - token - for token in self.target_tokenizer.get_vocab() - if token not in self.assistant_tokenizer.get_vocab() - and token not in self.target_tokenizer.all_special_tokens - and "reserved_" not in token - ) - input_ids = torch.tensor([[self.target_tokenizer.convert_tokens_to_ids(missing_token)]]) - self.generator.input_ids = input_ids - candidates, _ = self.generator.get_candidates(input_ids) - self.assertIsNotNone(candidates) + missing_token = next( + token + for token in self.target_tokenizer.get_vocab() + if token not in self.assistant_tokenizer.get_vocab() + and token not in self.target_tokenizer.all_special_tokens + and "reserved_" not in token + ) + input_ids = torch.tensor([[self.target_tokenizer.convert_tokens_to_ids(missing_token)]]) + self.generator.input_ids = input_ids + candidates, _ = self.generator.get_candidates(input_ids) + self.assertIsNotNone(candidates) def test_speculation_depth(self): """Test different speculation depths""" - for assistant_prune_LM_head in [False, True]: - self.setUpGenerator(assistant_prune_LM_head=assistant_prune_LM_head) - input_ids = self.target_tokenizer.encode("Test text", return_tensors="pt") - self.generator.input_ids = input_ids + input_ids = self.target_tokenizer.encode("Test text", return_tensors="pt") + self.generator.input_ids = input_ids - for depth in [1, 8, 17]: - self.generator.num_assistant_tokens = depth - candidates, _ = self.generator.get_candidates(input_ids) - self.assertLessEqual(candidates.shape[1] - input_ids.shape[1], depth) + for depth in [1, 8, 17]: + self.generator.num_assistant_tokens = depth + candidates, _ = self.generator.get_candidates(input_ids) + self.assertLessEqual(candidates.shape[1] - input_ids.shape[1], depth) def test_device_consistency(self): """Test handling of inputs on different devices""" - for assistant_prune_LM_head in [False, True]: - self.setUpGenerator(assistant_prune_LM_head=assistant_prune_LM_head) - input_ids = torch.tensor([[1, 2, 3]]).to(torch_device) - self.generator.input_ids = input_ids - candidates, _ = self.generator.get_candidates(input_ids) - self.assertEqual(candidates.device, input_ids.device) + input_ids = torch.tensor([[1, 2, 3]]).to(torch_device) + self.generator.input_ids = input_ids + candidates, _ = self.generator.get_candidates(input_ids) + self.assertEqual(candidates.device, input_ids.device) def test_usd_vs_vanilla_sampling(cls): """Test that USD matches vanilla sampling with temperature set to nearly 0""" @@ -338,17 +327,13 @@ def test_usd_vs_vanilla_sampling(cls): pipe_vanilla_output = pipe_vanilla(prompt, max_new_tokens=5, do_sample=False) vanilla_text = pipe_vanilla_output[0]["generated_text"] - for assistant_prune_LM_head in [False, True]: - pipe_usd = pipeline( - "text-generation", - model=cls.target_name, - assistant_model=cls.assistant_name, - assistant_prune_LM_head=assistant_prune_LM_head, - ) - pipe_usd_output = pipe_usd( - prompt, max_new_tokens=5, do_sample=True, temperature=1e-9 - ) # Nearly 0 temperature - usd_text = pipe_usd_output[0]["generated_text"] + pipe_usd = pipeline( + "text-generation", + model=cls.target_name, + assistant_model=cls.assistant_name, + ) + pipe_usd_output = pipe_usd(prompt, max_new_tokens=5, do_sample=True, temperature=1e-9) # Nearly 0 temperature + usd_text 
= pipe_usd_output[0]["generated_text"] - # Assert that the outputs match - cls.assertEqual(usd_text, vanilla_text) + # Assert that the outputs match + cls.assertEqual(usd_text, vanilla_text) From f57ae18342234f64242a6f0aa67260686209cfcb Mon Sep 17 00:00:00 2001 From: jmamou Date: Wed, 2 Apr 2025 23:19:14 -0700 Subject: [PATCH 08/16] address Joao's comments --- .../generation/candidate_generator.py | 38 +++++++++++-------- src/transformers/generation/utils.py | 2 +- tests/generation/test_candidate_generator.py | 12 +++--- 3 files changed, 30 insertions(+), 22 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index adce08462640..4cdc84f1fe10 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -38,6 +38,8 @@ from ..tokenization_utils_base import PreTrainedTokenizerBase from .configuration_utils import GenerationConfig +from transformers.utils import deprecate_kwarg + class CandidateGenerator: """Abstract base class for all candidate generators that can be applied during assisted generation.""" @@ -614,7 +616,7 @@ def _process_assistant_outputs( return new_target_ids -class PruneReindexingLMHead(nn.Module): +class _PruneReindexingLMHead(nn.Module): """ A class to prune and reindex the language model head. @@ -637,7 +639,7 @@ def forward(self, hidden_states): return pruned_logits -class MapInputEmbedding(nn.Module): +class _MapInputEmbeddingg(nn.Module): def __init__(self, original_embedding: nn.Embedding, assistant_overlap_token_ids): """ Wraps an existing embedding layer and remaps token IDs before lookup. @@ -693,13 +695,17 @@ class AssistantToTargetTranslator: FILTER_VALUE: float = -float("Inf") # The value used to filter out unmapped tokens in the logits. SUPPRESS_TOKEN_ID: int = -1 # The ID used to mark suppressed tokens in the mapping. 
+ @deprecate_kwarg("assistant_model_device", version="4.53") def __init__( self, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase", - assistant_model: "PreTrainedModel", - target_vocab_size: Optional[int], - assistant_prune_LM_head: Optional[bool] = True, + target_vocab_size: Optional[ + int + ], # required since target_vocab_size can be different from the length of target_tokenizer.get_vocab() + assistant_model_device: str = "cpu", + assistant_model: "PreTrainedModel" = None, + assistant_prune_lm_head: Optional[bool] = True, ): self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer @@ -710,22 +716,22 @@ def __init__( ) self._suppress_input_ids: list[int] = self._get_suppress_input_ids() self.logits_processors: Optional[LogitsProcessorList] = None - self.assistant_prune_LM_head = assistant_prune_LM_head + self.assistant_prune_lm_head = assistant_prune_lm_head if len(self._suppress_input_ids) > 0: # the assistant vocab is not a subset of the target vocab - if self.assistant_prune_LM_head: + if self.assistant_prune_lm_head: self.assistant_overlap_token_ids = torch.tensor( list(self.target_to_assistant_input_ids.values()), dtype=torch.long, device=self._assistant_model_device, ) original_lm_head = assistant_model.get_output_embeddings() - pruned_lm_head = PruneReindexingLMHead(original_lm_head, self.assistant_overlap_token_ids) + pruned_lm_head = _PruneReindexingLMHead(original_lm_head, self.assistant_overlap_token_ids) del original_lm_head assistant_model.set_output_embeddings(pruned_lm_head) original_input_embeddings = assistant_model.get_input_embeddings() - map_input_embeddings = MapInputEmbedding(original_input_embeddings, self.assistant_overlap_token_ids) + map_input_embeddings = _MapInputEmbeddingg(original_input_embeddings, self.assistant_overlap_token_ids) del original_input_embeddings assistant_model.set_input_embeddings(map_input_embeddings) self.map_input_embeddings = map_input_embeddings @@ -735,7 +741,7 @@ def __init__( ) def set_unmap(self): - if self.assistant_prune_LM_head: + if self.assistant_prune_lm_head: self.map_input_embeddings.map = False def _get_assistant_to_target_input_ids(self): @@ -794,7 +800,7 @@ def get_target_ids( else: # Get last `num_new_tokens` candidate IDs last_candidate_ids = assistant_candidate_ids[0, -num_new_tokens:] - if self.assistant_prune_LM_head: + if self.assistant_prune_lm_head: # Map assistant IDs -> target input IDs last_candidate_ids = self.assistant_overlap_token_ids[last_candidate_ids] transformed_slice = self._assistant_to_target_input_ids[last_candidate_ids] @@ -812,7 +818,7 @@ def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatT # Exclude invalid indices target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_indices_mask] - if self.assistant_prune_LM_head: + if self.assistant_prune_lm_head: target_logits[..., target_logits_supported_indices] = assistant_logits else: valid_assistant_logits = assistant_logits[..., : self._assistant_to_target_input_ids.shape[0]] @@ -829,13 +835,15 @@ class AssistantVocabTranslatorCache: _cache = weakref.WeakKeyDictionary() @classmethod + @deprecate_kwarg("assistant_model_device", version="4.53") def get_translator( cls, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase", - assistant_model: "PreTrainedModel", target_vocab_size: Optional[int] = None, - assistant_prune_LM_head=True, + 
assistant_model_device: str = "cpu", + assistant_model: "PreTrainedModel" = None, + assistant_prune_lm_head=True, ) -> AssistantToTargetTranslator: assistant_dict = cls._cache.get(target_tokenizer) if assistant_dict is None: @@ -845,7 +853,7 @@ def get_translator( mapping = assistant_dict.get(assistant_tokenizer) if mapping is None: mapping = AssistantToTargetTranslator( - target_tokenizer, assistant_tokenizer, assistant_model, target_vocab_size, assistant_prune_LM_head + target_tokenizer, assistant_tokenizer, assistant_model, target_vocab_size, assistant_prune_lm_head ) assistant_dict[assistant_tokenizer] = mapping diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 274f2cad34a8..9e6c71fac3d6 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -881,7 +881,7 @@ def _get_candidate_generator( assistant_tokenizer, assistant_model, self.config.vocab_size, - assistant_prune_LM_head=True, # prune LM head of assistant model + assistant_prune_lm_head=True, # prune LM head of assistant model ) # Since we prune the LM head, we cannot use the repetition penalty on the assistant model due to mismaches between token ids and logits index assistant_model.generation_config.repetition_penalty = None diff --git a/tests/generation/test_candidate_generator.py b/tests/generation/test_candidate_generator.py index 1c2839e9c5cd..4da67721236e 100644 --- a/tests/generation/test_candidate_generator.py +++ b/tests/generation/test_candidate_generator.py @@ -36,7 +36,7 @@ def setUp(self): assistant_tokenizer=self.assistant_tokenizer, assistant_model=self.assistant_model, target_vocab_size=self.target_vocab_size, - assistant_prune_LM_head=False, + assistant_prune_lm_head=False, ) def test_get_assistant_to_target_input_ids(self): @@ -131,14 +131,14 @@ def test_same_instance_for_same_tokenizers(self): self.assistant_tokenizer, assistant_model=self.assistant_model, target_vocab_size=self.target_vocab_size, - assistant_prune_LM_head=False, + assistant_prune_lm_head=False, ) translator2 = AssistantVocabTranslatorCache.get_translator( self.target_tokenizer, self.assistant_tokenizer, assistant_model=self.assistant_model, target_vocab_size=self.target_vocab_size, - assistant_prune_LM_head=False, + assistant_prune_lm_head=False, ) self.assertIs(translator1, translator2, "Translators should be cached and identical") @@ -149,14 +149,14 @@ def test_different_instances_for_different_tokenizers(self): self.assistant_tokenizer, assistant_model=self.assistant_model, target_vocab_size=self.target_vocab_size, - assistant_prune_LM_head=False, + assistant_prune_lm_head=False, ) translator2 = AssistantVocabTranslatorCache.get_translator( self.other_target_tokenizer, self.other_assistant_tokenizer, assistant_model=self.assistant_model, target_vocab_size=self.target_vocab_size, - assistant_prune_LM_head=False, + assistant_prune_lm_head=False, ) self.assertIsNot(translator1, translator2, "Translators should differ for different tokenizers") @@ -172,7 +172,7 @@ def test_cache_with_weakref_key(self): assistant_tokenizer, assistant_model=self.assistant_model, target_vocab_size=self.target_vocab_size, - assistant_prune_LM_head=False, + assistant_prune_lm_head=False, ) self.assertEqual(len(AssistantVocabTranslatorCache._cache), initial_cache_size + 1) From ed10e3ce42dac142fd8327aeae0db834e12cd592 Mon Sep 17 00:00:00 2001 From: jmamou Date: Thu, 3 Apr 2025 04:20:45 -0700 Subject: [PATCH 09/16] deprecate_kwarg --- src/transformers/generation/candidate_generator.py 
| 13 +++++--------
 src/transformers/generation/utils.py         |  2 +-
 tests/generation/test_candidate_generator.py |  4 ++--
 3 files changed, 8 insertions(+), 11 deletions(-)

diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index 65c3903be8e1..74c5d7982ffb 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -38,8 +38,7 @@
 from ..tokenization_utils_base import PreTrainedTokenizerBase
 from .configuration_utils import GenerationConfig
 
-from transformers.utils import deprecate_kwarg
-
+from ..utils.deprecation import deprecate_kwarg
 
 class CandidateGenerator:
     """Abstract base class for all candidate generators that can be applied during assisted generation."""
@@ -703,8 +702,7 @@ def __init__(
         target_vocab_size: Optional[
             int
         ],  # required since target_vocab_size can be different from the length of target_tokenizer.get_vocab()
-        assistant_model_device: str = "cpu",
-        assistant_model: "PreTrainedModel" = None,
+        assistant_model: "PreTrainedModel",
         assistant_prune_lm_head: Optional[bool] = True,
     ):
         self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer
@@ -842,9 +840,8 @@ def get_translator(
         cls,
         target_tokenizer: "PreTrainedTokenizerBase",
         assistant_tokenizer: "PreTrainedTokenizerBase",
-        target_vocab_size: Optional[int] = None,
-        assistant_model_device: str = "cpu",
-        assistant_model: "PreTrainedModel" = None,
+        target_vocab_size: Optional[int],
+        assistant_model: "PreTrainedModel",
         assistant_prune_lm_head=True,
     ) -> AssistantToTargetTranslator:
         assistant_dict = cls._cache.get(target_tokenizer)
@@ -855,7 +852,7 @@ def get_translator(
         mapping = assistant_dict.get(assistant_tokenizer)
         if mapping is None:
             mapping = AssistantToTargetTranslator(
-                target_tokenizer, assistant_tokenizer, assistant_model, target_vocab_size, assistant_prune_lm_head
+                target_tokenizer, assistant_tokenizer, target_vocab_size, assistant_model, assistant_prune_lm_head
             )
             assistant_dict[assistant_tokenizer] = mapping
 
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 190cd059ffd0..69b6fc2ba817 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -964,8 +964,8 @@ def _get_candidate_generator(
             atm_translator = AssistantVocabTranslatorCache.get_translator(
                 target_tokenizer,
                 assistant_tokenizer,
-                assistant_model,
                 self.config.vocab_size,
+                assistant_model,
                 assistant_prune_lm_head=True,  # prune LM head of assistant model
             )
             # Since we prune the LM head, we cannot use the repetition penalty on the assistant model due to mismaches between token ids and logits index
diff --git a/tests/generation/test_candidate_generator.py b/tests/generation/test_candidate_generator.py
index 4da67721236e..011fc3bbb08c 100644
--- a/tests/generation/test_candidate_generator.py
+++ b/tests/generation/test_candidate_generator.py
@@ -199,9 +199,9 @@ def create_translator():
             translator = AssistantVocabTranslatorCache.get_translator(
                 target_tokenizer,
                 assistant_tokenizer,
-                assistant_model=self.assistant_model,
                 target_vocab_size=self.target_vocab_size,
-                assistant_prune_LM_head=False,
+                assistant_model=self.assistant_model,
+                assistant_prune_lm_head=False,
             )
             # Create weak references before returning
             refs = (weakref.ref(translator), weakref.ref(target_tokenizer), weakref.ref(assistant_tokenizer))

From 98965b2227bc477315ce9e9670ddb6392da36743 Mon Sep 17 00:00:00 2001
From: jmamou
Date: Thu, 3 Apr 2025 04:49:27 -0700
Subject: [PATCH 10/16] add doc

---
 .../generation/candidate_generator.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index 74c5d7982ffb..bdfd0efe8e4b 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -40,6 +40,7 @@
 
 from ..utils.deprecation import deprecate_kwarg
 
+
 class CandidateGenerator:
     """Abstract base class for all candidate generators that can be applied during assisted generation."""
 
@@ -638,7 +639,7 @@ def forward(self, hidden_states):
         return pruned_logits
 
 
-class _MapInputEmbeddingg(nn.Module):
+class _MapInputEmbedding(nn.Module):
     def __init__(self, original_embedding: nn.Embedding, assistant_overlap_token_ids):
         """
         Wraps an existing embedding layer and remaps token IDs before lookup.
@@ -729,7 +730,7 @@ def __init__(
                 assistant_model.set_output_embeddings(pruned_lm_head)
 
                 original_input_embeddings = assistant_model.get_input_embeddings()
-                map_input_embeddings = _MapInputEmbeddingg(original_input_embeddings, self.assistant_overlap_token_ids)
+                map_input_embeddings = _MapInputEmbedding(original_input_embeddings, self.assistant_overlap_token_ids)
                 del original_input_embeddings
                 assistant_model.set_input_embeddings(map_input_embeddings)
                 self.map_input_embeddings = map_input_embeddings
@@ -738,7 +739,13 @@ def __init__(
                     [SuppressTokensLogitsProcessor(self._get_suppress_input_ids(), self._assistant_model_device)]
                 )
 
-    def set_unmap(self):
+    def unmap_input_ids(self):
+        """
+        Disables the mapping of input ids despite the assistant pruning for the language model head being enabled.
+
+        This method is required for the first forward pass of `_MapInputEmbedding` where input ids are already in the assistant vocabulary space. By disabling the mapping, it ensures that the input ids are processed correctly without remapping.
+
+        """
         if self.assistant_prune_lm_head:
             self.map_input_embeddings.map = False
 
@@ -989,7 +996,7 @@ def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> to
             self._prev_assistant_ids = self._prev_assistant_ids[:, :-tokens_to_remove]
         assistant_input_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1)
         assistant_input_ids = assistant_input_ids.to(dtype=torch.long)
-        self._atm_translator.set_unmap()
+        self._atm_translator.unmap_input_ids()
 
         return assistant_input_ids, len(assistant_new_ids[0])
 

From 77a4b0b49c33ff1614be89d59294de8fcfc263b2 Mon Sep 17 00:00:00 2001
From: jmamou
Date: Thu, 3 Apr 2025 05:08:23 -0700
Subject: [PATCH 11/16] fix target_vocab_size

---
 src/transformers/generation/candidate_generator.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index bdfd0efe8e4b..c43f314b539b 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -700,9 +700,7 @@ def __init__(
         self,
         target_tokenizer: "PreTrainedTokenizerBase",
         assistant_tokenizer: "PreTrainedTokenizerBase",
-        target_vocab_size: Optional[
-            int
-        ],  # required since target_vocab_size can be different from the length of target_tokenizer.get_vocab()
+        target_vocab_size: int,  # required since target_vocab_size can be different from the length of target_tokenizer.get_vocab()
         assistant_model: "PreTrainedModel",
         assistant_prune_lm_head: Optional[bool] = True,
     ):
         self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer
@@ -847,7 +845,7 @@ def get_translator(
         cls,
         target_tokenizer: "PreTrainedTokenizerBase",
         assistant_tokenizer: "PreTrainedTokenizerBase",
-        target_vocab_size: Optional[int],
+        target_vocab_size: int,
         assistant_model: "PreTrainedModel",
         assistant_prune_lm_head=True,
     ) -> AssistantToTargetTranslator:
         assistant_dict = cls._cache.get(target_tokenizer)

From acf7c88307b1e44313c22cd7c2a6e0ade996a011 Mon Sep 17 00:00:00 2001
From: Jonathan Mamou
Date: Mon, 7 Apr 2025 11:09:50 +0300
Subject: [PATCH 12/16] Update src/transformers/generation/candidate_generator.py

Co-authored-by: Joao Gante
---
 src/transformers/generation/candidate_generator.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index c43f314b539b..d27d5e1a3518 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -701,6 +701,7 @@ def __init__(
         target_tokenizer: "PreTrainedTokenizerBase",
         assistant_tokenizer: "PreTrainedTokenizerBase",
         target_vocab_size: int,  # required since target_vocab_size can be different from the length of target_tokenizer.get_vocab()
+        assistant_model_device: str = "cpu",
         assistant_model: "PreTrainedModel",
         assistant_prune_lm_head: Optional[bool] = True,
     ):

From ac833ced5383f646a62ae0e06f0afb84e8cd7c23 Mon Sep 17 00:00:00 2001
From: Jonathan Mamou
Date: Mon, 7 Apr 2025 11:10:00 +0300
Subject: [PATCH 13/16] Update src/transformers/generation/candidate_generator.py

Co-authored-by: Joao Gante
---
 src/transformers/generation/candidate_generator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index d27d5e1a3518..88bde1f79386 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -707,7 +707,7 @@ def __init__(
     ):
         self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer
         self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer
-        self._assistant_model_device: str = assistant_model.device
+        self._assistant_model_device: str = assistant_model.device or assistant_model_device
         self.target_vocab_size: int = target_vocab_size
         self._assistant_to_target_input_ids, self.target_to_assistant_input_ids = (
             self._get_assistant_to_target_input_ids()

From 9d2a1e3769f7041fbfdd25b8186c748bd54eddb8 Mon Sep 17 00:00:00 2001
From: Jonathan Mamou
Date: Mon, 7 Apr 2025 11:10:21 +0300
Subject: [PATCH 14/16] Update src/transformers/generation/candidate_generator.py

Co-authored-by: Joao Gante
---
 src/transformers/generation/candidate_generator.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index 88bde1f79386..f823d768fba6 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -847,6 +847,7 @@ def get_translator(
         target_tokenizer: "PreTrainedTokenizerBase",
         assistant_tokenizer: "PreTrainedTokenizerBase",
         target_vocab_size: int,
+        assistant_model_device: str = "cpu",
         assistant_model: "PreTrainedModel",
         assistant_prune_lm_head=True,
     ) -> AssistantToTargetTranslator:

From d343966bedd889056c6e722dbc64e7db7ae7b155 Mon Sep 17 00:00:00 2001
From: Jonathan Mamou
Date: Mon, 7 Apr 2025 11:10:34 +0300
Subject: [PATCH 15/16] Update src/transformers/generation/candidate_generator.py

Co-authored-by: Joao Gante
---
 src/transformers/generation/candidate_generator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index f823d768fba6..e1729f341676 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -859,7 +859,7 @@ def get_translator(
         mapping = assistant_dict.get(assistant_tokenizer)
         if mapping is None:
             mapping = AssistantToTargetTranslator(
-                target_tokenizer, assistant_tokenizer, target_vocab_size, assistant_model, assistant_prune_lm_head
+                target_tokenizer, assistant_tokenizer, target_vocab_size, assistant_model_device, assistant_model, assistant_prune_lm_head
             )
             assistant_dict[assistant_tokenizer] = mapping
 

From c9e844f163bd635705159c4b127a50e7ce873400 Mon Sep 17 00:00:00 2001
From: jmamou
Date: Tue, 8 Apr 2025 07:08:38 -0700
Subject: [PATCH 16/16] fix deprecated argument assistant_model_device

---
 .../generation/candidate_generator.py        | 32 +++++++++++++------
 src/transformers/generation/utils.py         |  2 +-
 tests/generation/test_candidate_generator.py | 12 +++----
 3 files changed, 29 insertions(+), 17 deletions(-)

diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index e1729f341676..3425a0234b42 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -686,10 +686,15 @@ class AssistantToTargetTranslator:
             The tokenizer used by the target (main) model.
         assistant_tokenizer (`PreTrainedTokenizerBase`):
             The tokenizer used by the assistant model.
-        assistant_model_device (`str`, defaults to "cpu"):
-            The device where the assistant model is located. Used for placing tensors.
-        target_vocab_size (`int`, *optional*):
+        target_vocab_size (`int`):
             The size of the target model's vocabulary. If not provided, will be inferred from the target tokenizer.
+        assistant_model_device (str, optional): The device on which the assistant model is loaded.
+            Defaults to "cpu".
+        assistant_model_device (`str`, defaults to "cpu"): The device where the assistant model is located. Used for placing tensors.
+        assistant_model (Optional[PreTrainedModel], optional): The assistant model to be used. Defaults to None for backward compatibility.
+        assistant_prune_lm_head (bool): Whether to prune the assistant model's language model
+            head to match the target vocabulary. This is only applicable if `assistant_model` is provided.
+            Defaults to False for backward compatibility.
     """
 
     FILTER_VALUE: float = -float("Inf")  # The value used to filter out unmapped tokens in the logits.
@@ -702,19 +707,21 @@ def __init__(
         assistant_tokenizer: "PreTrainedTokenizerBase",
         target_vocab_size: int,  # required since target_vocab_size can be different from the length of target_tokenizer.get_vocab()
         assistant_model_device: str = "cpu",
-        assistant_model: "PreTrainedModel",
-        assistant_prune_lm_head: Optional[bool] = True,
+        assistant_model: Optional["PreTrainedModel"] = None,
+        assistant_prune_lm_head: bool = False,
     ):
         self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer
         self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer
-        self._assistant_model_device: str = assistant_model.device or assistant_model_device
+        self._assistant_model_device: str = (
+            assistant_model_device if assistant_model is None else assistant_model.device
+        )
         self.target_vocab_size: int = target_vocab_size
         self._assistant_to_target_input_ids, self.target_to_assistant_input_ids = (
             self._get_assistant_to_target_input_ids()
         )
         self._suppress_input_ids: list[int] = self._get_suppress_input_ids()
         self.logits_processors: Optional[LogitsProcessorList] = None
-        self.assistant_prune_lm_head = assistant_prune_lm_head
+        self.assistant_prune_lm_head = assistant_prune_lm_head and assistant_model is not None
         if len(self._suppress_input_ids) > 0:
             # the assistant vocab is not a subset of the target vocab
             if self.assistant_prune_lm_head:
@@ -848,8 +855,8 @@ def get_translator(
         assistant_tokenizer: "PreTrainedTokenizerBase",
         target_vocab_size: int,
         assistant_model_device: str = "cpu",
-        assistant_model: "PreTrainedModel",
-        assistant_prune_lm_head=True,
+        assistant_model: Optional["PreTrainedModel"] = None,
+        assistant_prune_lm_head: bool = False,
     ) -> AssistantToTargetTranslator:
         assistant_dict = cls._cache.get(target_tokenizer)
         if assistant_dict is None:
@@ -859,7 +866,12 @@ def get_translator(
         mapping = assistant_dict.get(assistant_tokenizer)
         if mapping is None:
             mapping = AssistantToTargetTranslator(
-                target_tokenizer, assistant_tokenizer, target_vocab_size, assistant_model_device, assistant_model, assistant_prune_lm_head
+                target_tokenizer,
+                assistant_tokenizer,
+                target_vocab_size,
+                assistant_model_device,
+                assistant_model,
+                assistant_prune_lm_head,
             )
             assistant_dict[assistant_tokenizer] = mapping
 
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 302da7134dc0..c23d595084f7 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -965,7 +965,7 @@ def _get_candidate_generator(
                 target_tokenizer,
                 assistant_tokenizer,
                 self.config.vocab_size,
-                assistant_model,
+                assistant_model=assistant_model,
                 assistant_prune_lm_head=True,  # prune LM head of assistant model
             )
             # Since we prune the LM head, we cannot use the repetition penalty on the assistant model due to mismaches between token ids and logits index
diff --git a/tests/generation/test_candidate_generator.py b/tests/generation/test_candidate_generator.py
index 011fc3bbb08c..3a50a963a9a2 100644
--- a/tests/generation/test_candidate_generator.py
+++ b/tests/generation/test_candidate_generator.py
@@ -34,8 +34,8 @@ def setUp(self):
         self.translator = AssistantToTargetTranslator(
             target_tokenizer=self.target_tokenizer,
             assistant_tokenizer=self.assistant_tokenizer,
-            assistant_model=self.assistant_model,
             target_vocab_size=self.target_vocab_size,
+            assistant_model=self.assistant_model,
             assistant_prune_lm_head=False,
         )
 
@@ -129,15 +129,15 @@ def test_same_instance_for_same_tokenizers(self):
         translator1 = AssistantVocabTranslatorCache.get_translator(
             self.target_tokenizer,
             self.assistant_tokenizer,
-            assistant_model=self.assistant_model,
             target_vocab_size=self.target_vocab_size,
+            assistant_model=self.assistant_model,
             assistant_prune_lm_head=False,
         )
         translator2 = AssistantVocabTranslatorCache.get_translator(
             self.target_tokenizer,
             self.assistant_tokenizer,
-            assistant_model=self.assistant_model,
             target_vocab_size=self.target_vocab_size,
+            assistant_model=self.assistant_model,
             assistant_prune_lm_head=False,
         )
         self.assertIs(translator1, translator2, "Translators should be cached and identical")
@@ -147,15 +147,15 @@ def test_different_instances_for_different_tokenizers(self):
         translator1 = AssistantVocabTranslatorCache.get_translator(
             self.target_tokenizer,
             self.assistant_tokenizer,
-            assistant_model=self.assistant_model,
             target_vocab_size=self.target_vocab_size,
+            assistant_model=self.assistant_model,
             assistant_prune_lm_head=False,
         )
         translator2 = AssistantVocabTranslatorCache.get_translator(
             self.other_target_tokenizer,
             self.other_assistant_tokenizer,
-            assistant_model=self.assistant_model,
             target_vocab_size=self.target_vocab_size,
+            assistant_model=self.assistant_model,
             assistant_prune_lm_head=False,
         )
         self.assertIsNot(translator1, translator2, "Translators should differ for different tokenizers")
@@ -170,8 +170,8 @@ def test_cache_with_weakref_key(self):
         translator = AssistantVocabTranslatorCache.get_translator(
             target_tokenizer,
             assistant_tokenizer,
-            assistant_model=self.assistant_model,
             target_vocab_size=self.target_vocab_size,
+            assistant_model=self.assistant_model,
             assistant_prune_lm_head=False,
         )
         self.assertEqual(len(AssistantVocabTranslatorCache._cache), initial_cache_size + 1)
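
For reference, the end state of this series exposes `assistant_model` and `assistant_prune_lm_head` as optional keyword arguments on the cached translator. A minimal usage sketch follows; it is illustrative only and not part of any commit above, and the names `target_tok`, `assistant_tok`, `target_model`, and `assistant_model` are assumed placeholders for already-loaded tokenizers and models:

    from transformers.generation.candidate_generator import AssistantVocabTranslatorCache

    # Cached translator between the two vocabularies; pruning the assistant LM head
    # is opt-in and only takes effect when an assistant model is actually passed.
    translator = AssistantVocabTranslatorCache.get_translator(
        target_tok,
        assistant_tok,
        target_vocab_size=target_model.config.vocab_size,
        assistant_model=assistant_model,
        assistant_prune_lm_head=True,
    )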