diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index 25148b3a6b83..a37481018086 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -24,7 +24,6 @@
 from ..cache_utils import DynamicCache
 from ..pytorch_utils import isin_mps_friendly
 from .logits_process import (
-    LogitNormalization,
     LogitsProcessorList,
     MinLengthLogitsProcessor,
     SuppressTokensLogitsProcessor,
@@ -245,18 +244,21 @@ def _calculate_new_tokens(self, input_ids: torch.LongTensor) -> Tuple[int, int]:
         min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
         return min_new_tokens, max_new_tokens
 
-    def _update_past_and_masks(self, input_ids: torch.LongTensor, remove_from_pkv: int = 0) -> bool:
+    def _update_past_and_masks(
+        self, input_ids: torch.LongTensor, remove_from_pkv: int = 0, num_added_tokens: int = 1
+    ) -> bool:
         """Update past key values and attention masks for subsequent generation rounds."""
         has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
         if has_past_key_values:
             new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv
             self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
-                self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
+                self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - num_added_tokens
             )
             self.assistant_kwargs = _prepare_attention_mask(
                 self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder
             )
             self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, input_ids.shape[-1])
+
         return has_past_key_values
 
     def _prepare_generation_args(self, input_ids: torch.LongTensor, min_new_tokens: int, max_new_tokens: int) -> Dict:
@@ -565,34 +567,41 @@ class AssistantToTargetTranslator:
     """
     Translate the assistant into the target universe.
     """
-    def __init__(self, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase"):
+    def __init__(
+        self,
+        target_tokenizer: "PreTrainedTokenizerBase",
+        assistant_tokenizer: "PreTrainedTokenizerBase",
+        assistant_model_device,
+        target_vocab_size: int,
+        filter_value: float = -float("Inf"),
+        suppress_tokens_id: int = -1,
+    ):
         self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer
         self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer
-        self._assistant_to_target_input_ids: dict[int, int] = self._get_assistant_to_target_input_ids()
-        self.suppress_input_ids: list[int] = self._get_suppress_input_ids()
+        self._assistant_model_device = assistant_model_device
+        self.target_vocab_size: int = target_vocab_size
+        self.filter_value: float = filter_value
+        self.suppress_tokens_id: int = suppress_tokens_id
+        self._assistant_to_target_input_ids = self._get_assistant_to_target_input_ids()
         self.logits_processors: LogitsProcessorList = LogitsProcessorList(
-            [
-                SuppressTokensLogitsProcessor(self.suppress_input_ids),
-                LogitNormalization(),
-            ]
+            [SuppressTokensLogitsProcessor(self._get_suppress_input_ids(), self._assistant_model_device)]
         )
 
-    def _get_assistant_to_target_input_ids(self) -> dict[int, int]:
-        """
-        Get a mapping from assistant tokens to target tokens based on vocabularies.
-        """
+    def _get_assistant_to_target_input_ids(self):
         target_vocab = self._target_tokenizer.get_vocab()
         assistant_vocab = self._assistant_tokenizer.get_vocab()
-        return {
-            assistant_vocab[tok]: target_vocab[tok] for tok in set(target_vocab.keys()) & set(assistant_vocab.keys())
-        }
+        max_assistant_index = max(assistant_vocab.values())
+        assistant_to_target_input_ids = torch.full((max_assistant_index + 1,), self.suppress_tokens_id, dtype=int)
+        for tok, idx in assistant_vocab.items():
+            if tok in target_vocab:
+                assistant_to_target_input_ids[idx] = target_vocab[tok]
+        return assistant_to_target_input_ids.to(self._assistant_model_device)
 
     def _get_suppress_input_ids(self) -> list[int]:
         """
         Get the input ids that are in the assistant vocab but not in the target vocab.
         """
-        assistant_vocab = self._assistant_tokenizer.get_vocab()
-        return list(set(assistant_vocab.values()) - set(self._assistant_to_target_input_ids.keys()))
+        return torch.where(self._assistant_to_target_input_ids == self.suppress_tokens_id)[0]
 
     def get_target_ids(
         self, assistant_input_ids, target_input_ids, assistant_candidate_ids: torch.LongTensor
@@ -602,33 +611,29 @@ def get_target_ids(
         Note that we have already the target ids for the prompt and we only need to find the target ids for the new tokens.
         Moreover, assistant ids of the original prompt does not necessarily appear in _assistant_to_target_input_ids.
         """
-        device = assistant_candidate_ids.device
-        target_candidate_ids = (
-            assistant_candidate_ids[0, -(len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1]) :]
-            .cpu()
-            .apply_(lambda x: self._assistant_to_target_input_ids.get(x, x))
-            .to(device)
-        )
-        return torch.cat((target_input_ids, target_candidate_ids.unsqueeze(0)), dim=1)
+
+        num_new_tokens = len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1]
+        if num_new_tokens == 0:
+            return target_input_ids
+        else:
+            transformed_slice = self._assistant_to_target_input_ids[assistant_candidate_ids[0, -num_new_tokens:]]
+            return torch.cat((target_input_ids, transformed_slice.unsqueeze(0)), dim=1)
 
     def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatTensor:
         """
         Return the target logits that correspond to the assistant logits.
         """
-        device = assistant_logits.device
-        target_vocab_size: int = len(self._target_tokenizer.get_vocab())
-        target_shape: tuple[int, ...] = (*assistant_logits.shape[:-1], target_vocab_size)
-        target_logits: torch.FloatTensor = torch.full(target_shape, -float("inf")).to(device)
-        assistant_logits_supported_mask: torch.BoolTensor = assistant_logits > -float("inf")
-        assistant_logits_supported_indices: torch.IntTensor = assistant_logits_supported_mask.nonzero(as_tuple=True)[
-            -1
-        ]
-        target_logits_supported_indices: torch.IntTensor = (
-            assistant_logits_supported_indices.cpu()
-            .apply_(lambda x: self._assistant_to_target_input_ids[x])
-            .to(device)
-        )
-        target_logits[..., target_logits_supported_indices] = assistant_logits[..., assistant_logits_supported_mask]
+
+        target_shape: tuple[int, ...] = (*assistant_logits.shape[:-1], self.target_vocab_size)
+        target_logits: torch.FloatTensor = torch.full(target_shape, self.filter_value).to(self._assistant_model_device)
+        # Mask for valid indices
+        assistant_indices_mask = self._assistant_to_target_input_ids != self.suppress_tokens_id
+        # Exclude invalid indices
+        target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_indices_mask]
+        valid_assistant_logits = assistant_logits[..., : self._assistant_to_target_input_ids.shape[0]]
+
+        target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask]
+
         return target_logits
 
 
@@ -643,7 +648,11 @@ class AssistantVocabTranslatorCache:
 
     @classmethod
     def get_translator(
-        cls, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase"
+        cls,
+        target_tokenizer: "PreTrainedTokenizerBase",
+        assistant_tokenizer: "PreTrainedTokenizerBase",
+        assistant_model_device,
+        target_vocab_size: int,
     ) -> AssistantToTargetTranslator:
         with cls._lock:
             assistant_dict = cls._cache.get(target_tokenizer)
@@ -653,7 +662,9 @@ def get_translator(
 
             mapping = assistant_dict.get(assistant_tokenizer)
             if mapping is None:
-                mapping = AssistantToTargetTranslator(target_tokenizer, assistant_tokenizer)
+                mapping = AssistantToTargetTranslator(
+                    target_tokenizer, assistant_tokenizer, assistant_model_device, target_vocab_size
+                )
                 assistant_dict[assistant_tokenizer] = mapping
 
             return mapping
@@ -692,11 +703,14 @@ def __init__(
         assistant_tokenizer: "PreTrainedTokenizerBase",
         generation_config: "GenerationConfig",
         model_kwargs: Dict,
+        target_vocab_size: int,
         inputs_tensor: Optional[torch.Tensor] = None,
         logits_processor: "LogitsProcessorList" = None,
     ):
         # Initialize translator before parent class
-        self._atm_translator = AssistantVocabTranslatorCache.get_translator(target_tokenizer, assistant_tokenizer)
+        self._atm_translator = AssistantVocabTranslatorCache.get_translator(
+            target_tokenizer, assistant_tokenizer, assistant_model.device, target_vocab_size
+        )
         super().__init__(
             input_ids,
             assistant_model,
@@ -708,42 +722,49 @@ def __init__(
             logits_processor,
         )
         # Track sequence lengths and previous assistant IDs
-        self._prev_target_seq_len: int = 0
+        self._target_seq_len_with_candidates: int = 0
         self._prev_assistant_ids: Optional[torch.LongTensor] = None
+        self.target_vocab_size = target_vocab_size
 
     def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
         """
        Simplified version of get_candidates that uses the translator cache for token conversion.
         """
-        input_ids = input_ids.to(self.assistant_model.device)
-        target_input_ids = input_ids.clone()
-        assistant_input_ids = self._prepare_assistant_input_ids(target_input_ids)
+        target_input_ids = input_ids.to(self.assistant_model.device)
+        assistant_input_ids, num_added_tokens = self._prepare_assistant_input_ids(target_input_ids)
 
         min_new_tokens, max_new_tokens = self._calculate_new_tokens(target_input_ids)
 
         if max_new_tokens == 0:
             return input_ids, None
 
-        self._update_past_and_masks(assistant_input_ids)
+        self._update_past_and_masks(assistant_input_ids, num_added_tokens=num_added_tokens)
         generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens)
-        self.assistant_kwargs.pop("attention_mask", None)
         # Ensure scores are returned
         generation_args["generation_config"].output_scores = True
         generation_args["generation_config"].return_dict_in_generate = True
 
         # Generate and process outputs using translator
-        assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs)
-        self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
-
-        candidate_logits = torch.stack(assistant_output.scores, dim=1)
+        generation_args["logits_processor"] = self._atm_translator.logits_processors
+        self._prev_assistant_ids, assistant_candidate_logits = self._generate_candidates(generation_args)
 
         # Use translator to convert tokens and logits
-        candidate_ids = assistant_output.sequences
-        candidate_logits = self._atm_translator.logits_processors(input_ids=candidate_ids, scores=candidate_logits)
-        target_ids = self._atm_translator.get_target_ids(assistant_input_ids, target_input_ids, candidate_ids)
-        target_logits = self._atm_translator.get_target_logits(candidate_logits)
+        target_candidate_ids = self._atm_translator.get_target_ids(
+            assistant_input_ids, target_input_ids, self._prev_assistant_ids
+        )
+        self._target_seq_len_with_candidates = target_candidate_ids.shape[-1]
+        target_candidate_logits = self._atm_translator.get_target_logits(assistant_candidate_logits)
 
-        return target_ids, target_logits
+        return target_candidate_ids, target_candidate_logits
+
+    def _update_past_and_masks(self, assistant_input_ids: torch.LongTensor, num_added_tokens: int = 1) -> bool:
+        if self._prev_assistant_ids is None:
+            # Prepare attention mask for the first generation.
+            # For subsequent generations, the attention mask is updated in super()._update_past_and_masks.
+            self.assistant_kwargs = _prepare_attention_mask(
+                self.assistant_kwargs, assistant_input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder
+            )
+        return super()._update_past_and_masks(assistant_input_ids, num_added_tokens=num_added_tokens)
 
     def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> torch.LongTensor:
         """
@@ -751,9 +772,11 @@ def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> to
         """
         # Calculate new tokens since last call
         target_seq_len = target_input_ids.shape[-1]
-        new_token_count = target_seq_len - self._prev_target_seq_len
+        if self._target_seq_len_with_candidates == 0:
+            new_token_count = target_seq_len
+        else:
+            new_token_count = 1
         target_new_ids = target_input_ids[:, -new_token_count:]
-        self._prev_target_seq_len = target_seq_len
 
         # Convert only the new tokens
         target_new_text = self.target_tokenizer.batch_decode(
@@ -765,11 +788,16 @@ def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> to
 
         # Update or initialize assistant IDs
         if self._prev_assistant_ids is None:
-            self._prev_assistant_ids = assistant_new_ids
+            assistant_input_ids = assistant_new_ids
         else:
-            self._prev_assistant_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1)
-
-        return self._prev_assistant_ids
+            tokens_to_remove = self._target_seq_len_with_candidates + 1 - target_seq_len
+            # If the number of new tokens is greater than zero, truncate the previous assistant IDs
+            if tokens_to_remove > 0:
+                self._prev_assistant_ids = self._prev_assistant_ids[:, :-tokens_to_remove]
+            assistant_input_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1)
+        assistant_input_ids = assistant_input_ids.to(torch.int)
+
+        return assistant_input_ids, len(assistant_new_ids[0])
 
 
 class PromptLookupCandidateGenerator(CandidateGenerator):
diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py
index 5fcd35c921af..e818b266cd7b 100644
--- a/src/transformers/generation/logits_process.py
+++ b/src/transformers/generation/logits_process.py
@@ -1860,14 +1860,15 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
     ```
     """
 
-    def __init__(self, suppress_tokens, device: str = "cpu"):
+    def __init__(self, suppress_tokens, device: str = "cpu", filter_value: float = -float("Inf")):
         self.suppress_tokens = torch.tensor(list(suppress_tokens), device=device)
+        self.filter_value = filter_value
 
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
-        suppress_token_mask = isin_mps_friendly(vocab_tensor, self.suppress_tokens.to(scores.device))
-        scores = torch.where(suppress_token_mask, -float("inf"), scores)
+        suppress_token_mask = isin_mps_friendly(vocab_tensor, self.suppress_tokens)
+        scores = torch.where(suppress_token_mask, self.filter_value, scores)
         return scores
 
 
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 7a9d78168ac9..d7d9757d3e4f 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -858,6 +858,8 @@ def _get_candidate_generator(
                     logits_processor=logits_processor,
                     target_tokenizer=target_tokenizer,
                     assistant_tokenizer=assistant_tokenizer,
+                    # required in the case that self.config.vocab_size is different from the length of target_tokenizer.get_vocab()
+                    target_vocab_size=self.config.vocab_size,
                 )
             case False:
                 candidate_generator = AssistedCandidateGeneratorDifferentTokenizers(
diff --git a/tests/generation/test_candidate_generator.py b/tests/generation/test_candidate_generator.py
index 260f92109b7b..dd7e427a3bfd 100644
--- a/tests/generation/test_candidate_generator.py
+++ b/tests/generation/test_candidate_generator.py
@@ -79,7 +79,7 @@ def test_get_assistant_to_target_input_ids(self):
     def test_get_suppress_input_ids(self):
         """Test the suppression of assistant input IDs not present in the target vocabulary."""
         expected_suppress_ids = [4]
-        actual_suppress_ids = self.translator.suppress_input_ids
+        actual_suppress_ids = self.translator._suppress_input_ids
         self.assertEqual(actual_suppress_ids, expected_suppress_ids)
 
     def test_get_target_ids(self):
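
Note (illustrative, not part of the patch): the core change above replaces the per-token Python callbacks (`.cpu().apply_(...)`) with a dense lookup tensor that stays on the assistant model's device. Below is a minimal sketch of that technique on two toy vocabularies; the vocabularies, tensor values, and printed results are assumptions invented for the demonstration, and only the construction pattern mirrors `_get_assistant_to_target_input_ids`, `get_target_ids`, and `get_target_logits`.

import torch

# Toy vocabularies (assumed for illustration): token -> id
target_vocab = {"hello": 0, "world": 1, "!": 2}
assistant_vocab = {"hello": 0, "!": 1, "world": 2, "<extra>": 3}

suppress_tokens_id = -1       # marks assistant tokens with no target counterpart
filter_value = -float("Inf")  # logit value for unmapped target positions

# Dense lookup table: assistant id -> target id (suppress_tokens_id where unmapped).
max_assistant_index = max(assistant_vocab.values())
assistant_to_target = torch.full((max_assistant_index + 1,), suppress_tokens_id, dtype=torch.long)
for tok, idx in assistant_vocab.items():
    if tok in target_vocab:
        assistant_to_target[idx] = target_vocab[tok]
print(assistant_to_target)  # tensor([ 0,  2,  1, -1])

# Token translation becomes a single indexing op instead of a per-token apply_():
assistant_candidates = torch.tensor([0, 2, 1])  # "hello world !"
print(assistant_to_target[assistant_candidates])  # tensor([0, 1, 2])

# Logit translation: scatter assistant logits into a target-vocab-sized tensor,
# leaving target positions that have no assistant counterpart at filter_value.
assistant_logits = torch.tensor([[0.5, 1.5, 2.5, 3.5]])  # one score per assistant id
target_logits = torch.full((1, len(target_vocab)), filter_value)
valid_mask = assistant_to_target != suppress_tokens_id
target_logits[..., assistant_to_target[valid_mask]] = assistant_logits[..., valid_mask]
print(target_logits)  # tensor([[0.5000, 2.5000, 1.5000]])

The lookup-table form avoids the `.cpu()` round trip and the Python-level `apply_` loop of the old implementation, so translation stays vectorized on the assistant device; assistant tokens without a target counterpart (here `"<extra>"`) are both suppressed during assistant generation and excluded from the scatter.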