diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index d2a246c81f03..980505e8979d 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -641,7 +641,7 @@ def __init__( self.target_vocab_size: int = len(self._target_tokenizer.get_vocab()) self.filter_value: float = filter_value self.suppress_tokens_id: int = suppress_tokens_id - self._assistant_to_target_input_ids = self._get_assistant_to_target_input_ids() + self._assistant_to_target_input_ids, self.target_to_assistant_input_ids = self._get_assistant_to_target_input_ids() self._suppress_input_ids: list[int] = self._get_suppress_input_ids() self.logits_processors: Optional[LogitsProcessorList] = None if len(self._suppress_input_ids) > 0: @@ -677,10 +677,13 @@ def _get_assistant_to_target_input_ids(self): max_assistant_index = max(assistant_vocab.values()) assistant_to_target_input_ids = torch.full((max_assistant_index + 1,), self.suppress_tokens_id, dtype=int) - for tok, idx in assistant_vocab.items(): - if tok in target_vocab: - assistant_to_target_input_ids[idx] = target_vocab[tok] - return assistant_to_target_input_ids.to(self._assistant_model_device) + target_to_assistant_input_id: Dict[int, int] = {} + for tok, assistant_id in assistant_vocab.items(): + target_id = target_vocab.get(tok) + if target_id is not None: + assistant_to_target_input_ids[assistant_id] = target_id + target_to_assistant_input_ids[target_id] = assistant_id + return assistant_to_target_input_ids.to(self._assistant_model_device), target_to_assistant_input_ids def _get_suppress_input_ids(self) -> list[int]: """ @@ -864,13 +867,20 @@ def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> to new_token_count = 1 target_new_ids = target_input_ids[:, -new_token_count:] - # Convert only the new tokens - target_new_text = self.target_tokenizer.batch_decode( - target_new_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True - ) - assistant_new_ids = self.assistant_tokenizer(target_new_text, add_special_tokens=False, return_tensors="pt")[ - "input_ids" - ].to(self.assistant_model.device) + # Convert the new tokens + assistant_new_ids = None + if self._target_seq_len_with_candidates > 0: + # we have only one new token and we can directly convert it + assistant_new_ids = self._atm_translator.target_to_assistant_input_ids.get(target_new_ids[0].item()) + if assistant_new_ids is None: + target_new_text = self.target_tokenizer.batch_decode( + target_new_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True + ) + assistant_new_ids = self.assistant_tokenizer( + target_new_text, add_special_tokens=False, return_tensors="pt" + )["input_ids"].to(self.assistant_model.device) + else: + assistant_new_ids = torch.tensor([[assistant_new_ids]], device=self.assistant_model.device) # Update or initialize assistant IDs if self._prev_assistant_ids is None: