From dc601db78c7981350bc9ed8f3bfef9c32a98ad20 Mon Sep 17 00:00:00 2001 From: jmamou Date: Mon, 23 Dec 2024 04:26:17 -0800 Subject: [PATCH 1/3] default values for AssistantToTargetTranslator fields --- .../generation/candidate_generator.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 26c81da0c1a5..73de769af061 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -627,15 +627,18 @@ def __init__( self, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase", - assistant_model_device, - target_vocab_size: int, + assistant_model_device:str = "cpu", + target_vocab_size: int = None, filter_value: float = -float("Inf"), suppress_tokens_id: int = -1, ): self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer - self._assistant_model_device = assistant_model_device - self.target_vocab_size: int = target_vocab_size + self._assistant_model_device:str = assistant_model_device + if target_vocab_size: + self.target_vocab_size: int = target_vocab_size + else: + self.target_vocab_size: int = len(self._target_tokenizer.get_vocab()) self.filter_value: float = filter_value self.suppress_tokens_id: int = suppress_tokens_id self._assistant_to_target_input_ids = self._get_assistant_to_target_input_ids() @@ -707,8 +710,8 @@ def get_translator( cls, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase", - assistant_model_device, - target_vocab_size: int, + assistant_model_device: str = "cpu", + target_vocab_size: int = None, ) -> AssistantToTargetTranslator: with cls._lock: assistant_dict = cls._cache.get(target_tokenizer) From 3fc2c63d3bb6893d5453deadd3fc8f91ed3a362f Mon Sep 17 00:00:00 2001 From: jmamou Date: Tue, 
24 Dec 2024 04:26:11 -0800 Subject: [PATCH 2/3] fix --- src/transformers/generation/candidate_generator.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 73de769af061..185ccc500e09 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -627,14 +627,14 @@ def __init__( self, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase", - assistant_model_device:str = "cpu", - target_vocab_size: int = None, + assistant_model_device: str = "cpu", + target_vocab_size: Optional[int] = None, filter_value: float = -float("Inf"), suppress_tokens_id: int = -1, ): self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer - self._assistant_model_device:str = assistant_model_device + self._assistant_model_device: str = assistant_model_device if target_vocab_size: self.target_vocab_size: int = target_vocab_size else: @@ -711,7 +711,7 @@ def get_translator( target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase", assistant_model_device: str = "cpu", - target_vocab_size: int = None, + target_vocab_size: Optional[int] = None, ) -> AssistantToTargetTranslator: with cls._lock: assistant_dict = cls._cache.get(target_tokenizer) From 060645385cdf77e30decd738908606528d7e8578 Mon Sep 17 00:00:00 2001 From: jmamou Date: Thu, 26 Dec 2024 03:48:10 -0800 Subject: [PATCH 3/3] add support for empty logits_processors --- src/transformers/generation/candidate_generator.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 185ccc500e09..d1bff45049cb 100644 --- 
a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -642,9 +642,13 @@ def __init__( self.filter_value: float = filter_value self.suppress_tokens_id: int = suppress_tokens_id self._assistant_to_target_input_ids = self._get_assistant_to_target_input_ids() - self.logits_processors: LogitsProcessorList = LogitsProcessorList( - [SuppressTokensLogitsProcessor(self._get_suppress_input_ids(), self._assistant_model_device)] - ) + self._suppress_input_ids: list[int] = self._get_suppress_input_ids() + self.logits_processors: Optional[LogitsProcessorList] = None + if len(self._suppress_input_ids) > 0: + # len(self._suppress_input_ids) = 0 if the assistant vocab is a subset of the target vocab + self.logits_processors = LogitsProcessorList( + [SuppressTokensLogitsProcessor(self._get_suppress_input_ids(), self._assistant_model_device)] + ) def _get_assistant_to_target_input_ids(self): target_vocab = self._target_tokenizer.get_vocab() @@ -804,7 +808,8 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, generation_args["generation_config"].return_dict_in_generate = True # Generate and process outputs using translator - generation_args["logits_processor"] = self._atm_translator.logits_processors + if self._atm_translator.logits_processors is not None: + generation_args["logits_processor"] = self._atm_translator.logits_processors self._prev_assistant_ids, assistant_candidate_logits = self._generate_candidates(generation_args) # Use translator to convert tokens and logits