diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index 26c81da0c1a5..5ba09a9d518d 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -627,15 +627,18 @@ def __init__(
         self,
         target_tokenizer: "PreTrainedTokenizerBase",
         assistant_tokenizer: "PreTrainedTokenizerBase",
-        assistant_model_device,
-        target_vocab_size: int,
+        assistant_model_device: str = "cpu",
+        target_vocab_size: Optional[int] = None,
         filter_value: float = -float("Inf"),
         suppress_tokens_id: int = -1,
     ):
         self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer
         self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer
-        self._assistant_model_device = assistant_model_device
-        self.target_vocab_size: int = target_vocab_size
+        self._assistant_model_device: str = assistant_model_device
+        if target_vocab_size is not None:
+            self.target_vocab_size: int = target_vocab_size
+        else:
+            self.target_vocab_size: int = len(self._target_tokenizer.get_vocab())
         self.filter_value: float = filter_value
         self.suppress_tokens_id: int = suppress_tokens_id
         self._assistant_to_target_input_ids = self._get_assistant_to_target_input_ids()
@@ -646,6 +649,28 @@ def __init__(
     def _get_assistant_to_target_input_ids(self):
         target_vocab = self._target_tokenizer.get_vocab()
        assistant_vocab = self._assistant_tokenizer.get_vocab()
+
+        space_str = " "
+        target_space_ids = self._target_tokenizer(space_str, add_special_tokens=False)["input_ids"]
+        if len(target_space_ids) > 0:
+            target_space_sign = self._target_tokenizer.convert_ids_to_tokens(target_space_ids)[0][0]
+
+            assistant_space_ids = self._assistant_tokenizer(space_str, add_special_tokens=False)["input_ids"]
+            if len(assistant_space_ids) > 0:
+                assistant_space_sign = self._assistant_tokenizer.convert_ids_to_tokens(assistant_space_ids)[0][0]
+
+                if target_space_sign != assistant_space_sign:
+                    # If the assistant tokenizer uses a different leading-space marker than the target
+                    # tokenizer, replace it with the target's marker throughout the assistant_vocab.
+                    assistant_vocab = {
+                        (
+                            tok.replace(assistant_space_sign, target_space_sign, 1)
+                            if tok.startswith(assistant_space_sign)
+                            else tok
+                        ): idx
+                        for tok, idx in assistant_vocab.items()
+                    }
+
         max_assistant_index = max(assistant_vocab.values())
         assistant_to_target_input_ids = torch.full((max_assistant_index + 1,), self.suppress_tokens_id, dtype=int)
         for tok, idx in assistant_vocab.items():
@@ -707,8 +732,8 @@ def get_translator(
         cls,
         target_tokenizer: "PreTrainedTokenizerBase",
         assistant_tokenizer: "PreTrainedTokenizerBase",
-        assistant_model_device,
-        target_vocab_size: int,
+        assistant_model_device: str = "cpu",
+        target_vocab_size: Optional[int] = None,
     ) -> AssistantToTargetTranslator:
         with cls._lock:
             assistant_dict = cls._cache.get(target_tokenizer)
diff --git a/tests/generation/test_candidate_generator.py b/tests/generation/test_candidate_generator.py
index 2fcda92a9e1b..7c65f3697425 100644
--- a/tests/generation/test_candidate_generator.py
+++ b/tests/generation/test_candidate_generator.py
@@ -129,6 +129,12 @@ def __init__(self, vocab=None):
     def get_vocab(self):
         return self._vocab
 
+    def __call__(self, text, add_special_tokens=True):
+        # Minimal whitespace tokenizer: map each token to its vocab id (0 if unknown)
+        tokens = text.split()
+        input_ids = [self._vocab.get(token, 0) for token in tokens]
+        return {"input_ids": input_ids}
+
 
 class TestAssistantVocabTranslatorCache(unittest.TestCase):
     def setUp(self):
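
Note: the space-sign remapping added to _get_assistant_to_target_input_ids can be illustrated standalone. Below is a minimal sketch of that remapping with toy vocabularies, assuming a GPT-2-style "Ġ" marker on the assistant side and a SentencePiece-style "▁" marker on the target side; the helper name normalize_space_sign and the vocab contents are illustrative, not part of the PR.

# Minimal sketch of the leading-space-marker remapping, mirroring the dict
# comprehension added in _get_assistant_to_target_input_ids. Toy data only.
def normalize_space_sign(assistant_vocab, assistant_space_sign, target_space_sign):
    """Replace the assistant's leading space marker with the target's (first occurrence only)."""
    if assistant_space_sign == target_space_sign:
        return assistant_vocab
    return {
        (
            tok.replace(assistant_space_sign, target_space_sign, 1)
            if tok.startswith(assistant_space_sign)
            else tok
        ): idx
        for tok, idx in assistant_vocab.items()
    }

assistant_vocab = {"Ġhello": 5, "hello": 6, "ĠĠworld": 7}  # GPT-2-style marker
print(normalize_space_sign(assistant_vocab, "Ġ", "▁"))
# {'▁hello': 5, 'hello': 6, '▁Ġworld': 7}; only the leading marker is swapped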
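
The new defaults also simplify call sites. A hypothetical before/after sketch, assuming the cache class owning get_translator is AssistantVocabTranslatorCache (the name implied by the test class above); the checkpoint names are placeholders, not from the PR.

from transformers import AutoTokenizer
from transformers.generation.candidate_generator import AssistantVocabTranslatorCache

target_tok = AutoTokenizer.from_pretrained("target-checkpoint")        # placeholder
assistant_tok = AutoTokenizer.from_pretrained("assistant-checkpoint")  # placeholder

# Before this change, both extra arguments were required:
translator = AssistantVocabTranslatorCache.get_translator(
    target_tok,
    assistant_tok,
    assistant_model_device="cpu",
    target_vocab_size=len(target_tok.get_vocab()),
)

# After: "cpu" and len(target_tokenizer.get_vocab()) apply by default.
translator = AssistantVocabTranslatorCache.get_translator(target_tok, assistant_tok)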