diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index 294e92e3293e..a36280ea2244 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -14,8 +14,6 @@
 # limitations under the License.
 
 import copy
-import threading
-import weakref
 from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
 
 import numpy as np
@@ -29,10 +27,10 @@
 from ..cache_utils import DynamicCache
 from ..pytorch_utils import isin_mps_friendly
+from ..utils import AssistantToTargetTranslator
 from .logits_process import (
     LogitsProcessorList,
     MinLengthLogitsProcessor,
-    SuppressTokensLogitsProcessor,
 )
@@ -616,181 +614,6 @@ def _process_assistant_outputs(
         return new_target_ids
 
-
-class AssistantToTargetTranslator:
-    """
-    Translates token ids and logits between assistant and target model vocabularies. This class is used to handle
-    vocabulary mismatches when using different tokenizers for the assistant and target models in speculative decoding,
-    as introduced in the paper "Lossless Speculative Decoding Algorithms for Heterogeneous Vocabularies"
-    (https://www.arxiv.org/abs/2502.05202).
-    It maintains mappings between the two vocabularies and handles token/logit conversion.
-
-    Args:
-        target_tokenizer (`PreTrainedTokenizerBase`):
-            The tokenizer used by the target (main) model.
-        assistant_tokenizer (`PreTrainedTokenizerBase`):
-            The tokenizer used by the assistant model.
-        assistant_model_device (`str`, defaults to "cpu"):
-            The device where the assistant model is located. Used for placing tensors.
-        target_vocab_size (`int`, *optional*):
-            The size of the target model's vocabulary. If not provided, will be inferred from the target tokenizer.
-    """
-
-    FILTER_VALUE: float = -float("Inf")  # The value used to filter out unmapped tokens in the logits.
-    SUPPRESS_TOKEN_ID: int = -1  # The ID used to mark suppressed tokens in the mapping.
-
-    def __init__(
-        self,
-        target_tokenizer: "PreTrainedTokenizerBase",
-        assistant_tokenizer: "PreTrainedTokenizerBase",
-        assistant_model_device: str = "cpu",
-        target_vocab_size: Optional[int] = None,
-    ):
-        self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer
-        self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer
-        self._assistant_model_device: str = assistant_model_device
-        if target_vocab_size is None:
-            self.target_vocab_size: int = len(self._target_tokenizer.get_vocab())
-        else:
-            self.target_vocab_size: int = target_vocab_size
-        self._assistant_to_target_input_ids, self.target_to_assistant_input_ids = (
-            self._get_assistant_to_target_input_ids()
-        )
-        self._suppress_input_ids: list[int] = self._get_suppress_input_ids()
-        self.logits_processors: Optional[LogitsProcessorList] = None
-        if len(self._suppress_input_ids) > 0:
-            # len(self._suppress_input_ids) = 0 if the assistant vocab is a subset of the target vocab
-            self.logits_processors = LogitsProcessorList(
-                [SuppressTokensLogitsProcessor(self._get_suppress_input_ids(), self._assistant_model_device)]
-            )
-
-    def _get_assistant_to_target_input_ids(self):
-        target_vocab = self._target_tokenizer.get_vocab()
-        assistant_vocab = self._assistant_tokenizer.get_vocab()
-
-        space_str = " "
-        target_space_ids = self._target_tokenizer(space_str, add_special_tokens=False)["input_ids"]
-        if len(target_space_ids) > 0:
-            target_space_sign = self._target_tokenizer.convert_ids_to_tokens(target_space_ids)[0][0]
-
-            assistant_space_ids = self._assistant_tokenizer(space_str, add_special_tokens=False)["input_ids"]
-            if len(assistant_space_ids) > 0:
-                assistant_space_sign = self._assistant_tokenizer.convert_ids_to_tokens(assistant_space_ids)[0][0]
-
-                if target_space_sign != assistant_space_sign:
-                    # If the assistant tokenizer has a different space sign than the target tokenizer,
-                    # we need to replace the assistant space sign with the target space sign in the assistant_vocab.
-                    assistant_vocab = {
-                        (
-                            tok.replace(assistant_space_sign, target_space_sign, 1)
-                            if tok.startswith(assistant_space_sign)
-                            else tok
-                        ): idx
-                        for tok, idx in assistant_vocab.items()
-                    }
-
-        max_assistant_index = max(assistant_vocab.values())
-        assistant_to_target_input_ids = torch.full((max_assistant_index + 1,), self.SUPPRESS_TOKEN_ID, dtype=int)
-        target_to_assistant_input_ids: Dict[int, int] = {}
-        for tok, assistant_id in assistant_vocab.items():
-            target_id = target_vocab.get(tok)
-            if target_id is not None:
-                assistant_to_target_input_ids[assistant_id] = target_id
-                target_to_assistant_input_ids[target_id] = assistant_id
-        return assistant_to_target_input_ids.to(self._assistant_model_device), target_to_assistant_input_ids
-
-    def _get_suppress_input_ids(self) -> list[int]:
-        """
-        Get the input ids that are in the assistant vocab but not in the target vocab.
-        """
-        return torch.where(self._assistant_to_target_input_ids == self.SUPPRESS_TOKEN_ID)[0]
- """ - - num_new_tokens = len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1] - if num_new_tokens == 0: - return target_input_ids - else: - transformed_slice = self._assistant_to_target_input_ids[assistant_candidate_ids[0, -num_new_tokens:]] - return torch.cat((target_input_ids, transformed_slice.unsqueeze(0)), dim=1) - - def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatTensor: - """ - Return the target logits that correspond to the assistant logits. - """ - - target_shape: tuple[int, ...] = (*assistant_logits.shape[:-1], self.target_vocab_size) - target_logits: torch.FloatTensor = torch.full(target_shape, self.FILTER_VALUE).to(self._assistant_model_device) - # Mask for valid indices - assistant_indices_mask = self._assistant_to_target_input_ids != self.SUPPRESS_TOKEN_ID - # Exclude invalid indices - target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_indices_mask] - valid_assistant_logits = assistant_logits[..., : self._assistant_to_target_input_ids.shape[0]] - - target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask] - - return target_logits - - -class AssistantVocabTranslatorCache: - """ - Cache for `AssistantToTargetTranslator` instances. The instances are computed at - pre-processing time, and this cache allows us to avoid recomputing them. - """ - - _lock = threading.Lock() - _cache = weakref.WeakKeyDictionary() - - @classmethod - def get_translator( - cls, - target_tokenizer: "PreTrainedTokenizerBase", - assistant_tokenizer: "PreTrainedTokenizerBase", - assistant_model_device: str = "cpu", - target_vocab_size: Optional[int] = None, - ) -> AssistantToTargetTranslator: - with cls._lock: - assistant_dict = cls._cache.get(target_tokenizer) - if assistant_dict is None: - assistant_dict = weakref.WeakKeyDictionary() - cls._cache[target_tokenizer] = assistant_dict - - mapping = assistant_dict.get(assistant_tokenizer) - if mapping is None: - mapping = AssistantToTargetTranslator( - target_tokenizer, assistant_tokenizer, assistant_model_device, target_vocab_size - ) - assistant_dict[assistant_tokenizer] = mapping - - return mapping - - @classmethod - def cleanup(cls): - """ - Clean up dead references in the cache. - This removes entries where either the target_tokenizer or assistant_tokenizer - has been garbage collected. 
- """ - with cls._lock: - # Remove entries from the outer cache where the target_tokenizer is no longer alive - dead_keys = [key for key in cls._cache if key is None] - for key in dead_keys: - del cls._cache[key] - - # For each assistant_dict, remove entries where assistant_tokenizer is no longer alive - for assistant_dict in cls._cache.values(): - dead_keys = [key for key in assistant_dict if key is None] - for key in dead_keys: - del assistant_dict[key] - - class UniversalSpeculativeDecodingGenerator(AssistedCandidateGeneratorDifferentTokenizers): """ `CandidateGenerator` class to be used for Universal Speculative Decoding (USD): speculative decoding with different tokenizers @@ -805,14 +628,12 @@ def __init__( assistant_tokenizer: "PreTrainedTokenizerBase", generation_config: "GenerationConfig", model_kwargs: Dict, - target_vocab_size: int, + atm_translator: AssistantToTargetTranslator, inputs_tensor: Optional[torch.Tensor] = None, logits_processor: "LogitsProcessorList" = None, ): # Initialize translator before parent class - self._atm_translator = AssistantVocabTranslatorCache.get_translator( - target_tokenizer, assistant_tokenizer, assistant_model.device, target_vocab_size - ) + self._atm_translator = atm_translator super().__init__( input_ids, assistant_model, @@ -826,7 +647,6 @@ def __init__( # Track sequence lengths and previous assistant IDs self._target_seq_len_with_candidates: int = 0 self._prev_assistant_ids: Optional[torch.LongTensor] = None - self.target_vocab_size = target_vocab_size def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]: """ @@ -848,7 +668,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, # Generate and process outputs using translator if self._atm_translator.logits_processors is not None: - generation_args["logits_processor"] = self._atm_translator.logits_processors + generation_args["logits_processor"].append(self._atm_translator.logits_processors) self._prev_assistant_ids, assistant_candidate_logits = self._generate_candidates(generation_args) # Use translator to convert tokens and logits diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index ac09d5516f8a..a68c6292da40 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -17,6 +17,7 @@ import inspect import os import warnings +import weakref from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union @@ -339,6 +340,173 @@ class GenerateBeamEncoderDecoderOutput(ModelOutput): GenerateBeamOutput = Union[GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDecoderOutput] GenerateOutput = Union[GenerateNonBeamOutput, GenerateBeamOutput] +class AssistantToTargetTranslator: + """ + Translates token ids and logits between assistant and target model vocabularies. This class is used to handle + vocabulary mismatches when using different tokenizers for the assistant and target models in speculative decoding, + as introduced in the paper "Lossless Speculative Decoding Algorithms for Heterogeneous Vocabularies" + (https://www.arxiv.org/abs/2502.05202). + It maintains mappings between the two vocabularies and handles token/logit conversion. + + Args: + target_tokenizer (`PreTrainedTokenizerBase`): + The tokenizer used by the target (main) model. + assistant_tokenizer (`PreTrainedTokenizerBase`): + The tokenizer used by the assistant model. 
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index ac09d5516f8a..a68c6292da40 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -17,6 +17,7 @@
 import inspect
 import os
 import warnings
+import weakref
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
@@ -339,6 +340,173 @@ class GenerateBeamEncoderDecoderOutput(ModelOutput):
 
 GenerateBeamOutput = Union[GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDecoderOutput]
 GenerateOutput = Union[GenerateNonBeamOutput, GenerateBeamOutput]
 
+
+class AssistantToTargetTranslator:
+    """
+    Translates token ids and logits between assistant and target model vocabularies. This class is used to handle
+    vocabulary mismatches when using different tokenizers for the assistant and target models in speculative decoding,
+    as introduced in the paper "Lossless Speculative Decoding Algorithms for Heterogeneous Vocabularies"
+    (https://www.arxiv.org/abs/2502.05202).
+    It maintains mappings between the two vocabularies and handles token/logit conversion.
+
+    Args:
+        target_tokenizer (`PreTrainedTokenizerBase`):
+            The tokenizer used by the target (main) model.
+        assistant_tokenizer (`PreTrainedTokenizerBase`):
+            The tokenizer used by the assistant model.
+        target_vocab_size (`int`):
+            The size of the target model's vocabulary. Passed explicitly because it can differ from
+            `len(target_tokenizer.get_vocab())`.
+        assistant_model_device (`str`, defaults to `"cpu"`):
+            The device where the assistant model is located. Used for placing tensors.
+    """
+
+    FILTER_VALUE: float = -float("Inf")  # The value used to filter out unmapped tokens in the logits.
+    SUPPRESS_TOKEN_ID: int = -1  # The ID used to mark suppressed tokens in the mapping.
+
+    def __init__(
+        self,
+        target_tokenizer: "PreTrainedTokenizerBase",
+        assistant_tokenizer: "PreTrainedTokenizerBase",
+        target_vocab_size: int,  # required since target_vocab_size can be different from the length of target_tokenizer.get_vocab()
+        assistant_model_device: str = "cpu",
+    ):
+        self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer
+        self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer
+        self._assistant_model_device: str = assistant_model_device
+        self.target_vocab_size: int = target_vocab_size
+        self._assistant_to_target_input_ids, self.target_to_assistant_input_ids = (
+            self._get_assistant_to_target_input_ids()
+        )
+        self._suppress_input_ids: list[int] = self._get_suppress_input_ids()
+        self.logits_processors: Optional[LogitsProcessorList] = None
+        if len(self._suppress_input_ids) > 0:
+            # len(self._suppress_input_ids) = 0 if the assistant vocab is a subset of the target vocab
+            self.logits_processors = LogitsProcessorList(
+                [SuppressTokensLogitsProcessor(self._get_suppress_input_ids(), self._assistant_model_device)]
+            )
+
+    def _get_assistant_to_target_input_ids(self):
+        target_vocab = self._target_tokenizer.get_vocab()
+        assistant_vocab = self._assistant_tokenizer.get_vocab()
+
+        space_str = " "
+        target_space_ids = self._target_tokenizer(space_str, add_special_tokens=False)["input_ids"]
+        if len(target_space_ids) > 0:
+            target_space_sign = self._target_tokenizer.convert_ids_to_tokens(target_space_ids)[0][0]
+
+            assistant_space_ids = self._assistant_tokenizer(space_str, add_special_tokens=False)["input_ids"]
+            if len(assistant_space_ids) > 0:
+                assistant_space_sign = self._assistant_tokenizer.convert_ids_to_tokens(assistant_space_ids)[0][0]
+
+                if target_space_sign != assistant_space_sign:
+                    # If the assistant tokenizer has a different space sign than the target tokenizer,
+                    # we need to replace the assistant space sign with the target space sign in the assistant_vocab.
+                    assistant_vocab = {
+                        (
+                            tok.replace(assistant_space_sign, target_space_sign, 1)
+                            if tok.startswith(assistant_space_sign)
+                            else tok
+                        ): idx
+                        for tok, idx in assistant_vocab.items()
+                    }
+
+        max_assistant_index = max(assistant_vocab.values())
+        assistant_to_target_input_ids = torch.full((max_assistant_index + 1,), self.SUPPRESS_TOKEN_ID, dtype=int)
+        target_to_assistant_input_ids: Dict[int, int] = {}
+        for tok, assistant_id in assistant_vocab.items():
+            target_id = target_vocab.get(tok)
+            if target_id is not None:
+                assistant_to_target_input_ids[assistant_id] = target_id
+                target_to_assistant_input_ids[target_id] = assistant_id
+        return assistant_to_target_input_ids.to(self._assistant_model_device), target_to_assistant_input_ids
+
+    def _get_suppress_input_ids(self) -> list[int]:
+        """
+        Get the input ids that are in the assistant vocab but not in the target vocab.
+        """
+        return torch.where(self._assistant_to_target_input_ids == self.SUPPRESS_TOKEN_ID)[0]
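Note (not part of the patch): to make the mapping concrete, the constructor above produces a dense assistant-to-target lookup tensor, with `SUPPRESS_TOKEN_ID` (-1) in every slot whose assistant token has no target counterpart; those slots become the suppressed ids fed to `SuppressTokensLogitsProcessor`. A standalone sketch of that logic with invented toy vocabularies (plain dicts stand in for tokenizers):

    import torch

    SUPPRESS_TOKEN_ID = -1
    target_vocab = {"hello": 0, "world": 1, "foo": 2, "bar": 3}  # toy target vocab
    assistant_vocab = {"hello": 0, "world": 1, "baz": 2}         # toy assistant vocab

    # Dense lookup: index = assistant id, value = target id (or -1 if unmapped)
    assistant_to_target = torch.full((max(assistant_vocab.values()) + 1,), SUPPRESS_TOKEN_ID, dtype=torch.long)
    for tok, assistant_id in assistant_vocab.items():
        target_id = target_vocab.get(tok)
        if target_id is not None:
            assistant_to_target[assistant_id] = target_id

    print(assistant_to_target)  # tensor([ 0,  1, -1])
    suppress_ids = torch.where(assistant_to_target == SUPPRESS_TOKEN_ID)[0]
    print(suppress_ids)         # tensor([2]) -> "baz" has no target id, so the assistant may never propose it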
+ """ + return torch.where(self._assistant_to_target_input_ids == self.SUPPRESS_TOKEN_ID)[0] + + def get_target_ids( + self, assistant_input_ids, target_input_ids, assistant_candidate_ids: torch.LongTensor + ) -> torch.LongTensor: + """ + Return the target candidate ids that correspond to the assistant candidate ids. + Note that we have already the target ids for the prompt and we only need to find the target ids for the new tokens. + Moreover, assistant ids of the original prompt does not necessarily appear in _assistant_to_target_input_ids. + """ + + num_new_tokens = len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1] + if num_new_tokens == 0: + return target_input_ids + else: + transformed_slice = self._assistant_to_target_input_ids[assistant_candidate_ids[0, -num_new_tokens:]] + return torch.cat((target_input_ids, transformed_slice.unsqueeze(0)), dim=1) + + def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatTensor: + """ + Return the target logits that correspond to the assistant logits. + """ + + target_shape: tuple[int, ...] = (*assistant_logits.shape[:-1], self.target_vocab_size) + target_logits: torch.FloatTensor = torch.full(target_shape, self.FILTER_VALUE).to(self._assistant_model_device) + # Mask for valid indices + assistant_indices_mask = self._assistant_to_target_input_ids != self.SUPPRESS_TOKEN_ID + # Exclude invalid indices + target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_indices_mask] + valid_assistant_logits = assistant_logits[..., : self._assistant_to_target_input_ids.shape[0]] + + target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask] + + return target_logits + + +class AssistantVocabTranslatorCache: + """ + Cache for `AssistantToTargetTranslator` instances. The instances are computed at + pre-processing time, and this cache allows us to avoid recomputing them. + """ + + _cache = weakref.WeakKeyDictionary() + + @classmethod + def get_translator( + cls, + target_tokenizer: "PreTrainedTokenizerBase", + assistant_tokenizer: "PreTrainedTokenizerBase", + assistant_model_device: str = "cpu", + target_vocab_size: Optional[int] = None, + ) -> AssistantToTargetTranslator: + assistant_dict = cls._cache.get(target_tokenizer) + if assistant_dict is None: + assistant_dict = weakref.WeakKeyDictionary() + cls._cache[target_tokenizer] = assistant_dict + + mapping = assistant_dict.get(assistant_tokenizer) + if mapping is None: + mapping = AssistantToTargetTranslator( + target_tokenizer, assistant_tokenizer, assistant_model_device, target_vocab_size + ) + assistant_dict[assistant_tokenizer] = mapping + + return mapping + + @classmethod + def cleanup(cls): + """ + Clean up dead references in the cache. + This removes entries where either the target_tokenizer or assistant_tokenizer + has been garbage collected. 
+ """ + # Remove entries from the outer cache where the target_tokenizer is no longer alive + dead_keys = [key for key in cls._cache if key is None] + for key in dead_keys: + del cls._cache[key] + + # For each assistant_dict, remove entries where assistant_tokenizer is no longer alive + for assistant_dict in cls._cache.values(): + dead_keys = [key for key in assistant_dict if key is None] + for key in dead_keys: + del assistant_dict[key] + class GenerationMixin: """ @@ -860,6 +1028,7 @@ def _get_candidate_generator( ) elif different_tokenizers: if generation_config.do_sample is True: + atm_translator = AssistantVocabTranslatorCache.get_translator(target_tokenizer, assistant_tokenizer, self.config.vocab_size, assistant_model.device) candidate_generator = UniversalSpeculativeDecodingGenerator( input_ids=input_ids, assistant_model=assistant_model, @@ -869,8 +1038,7 @@ def _get_candidate_generator( logits_processor=logits_processor, target_tokenizer=target_tokenizer, assistant_tokenizer=assistant_tokenizer, - # required in the case that self.config.vocab_size is different from the length of target_tokenizer.get_vocab() - target_vocab_size=self.config.vocab_size, + atm_translator=atm_translator, ) elif generation_config.do_sample is False: candidate_generator = AssistedCandidateGeneratorDifferentTokenizers( diff --git a/tests/generation/test_candidate_generator.py b/tests/generation/test_candidate_generator.py index 54ce3b3ee1e2..699b009d3314 100644 --- a/tests/generation/test_candidate_generator.py +++ b/tests/generation/test_candidate_generator.py @@ -12,6 +12,7 @@ AssistantVocabTranslatorCache, UniversalSpeculativeDecodingGenerator, ) +from transformers.testing_utils import torch_device class TestAssistantToTargetTranslator(unittest.TestCase): @@ -26,7 +27,7 @@ def setUp(self): self.target_tokenizer.get_vocab.return_value = self.target_vocab self.assistant_tokenizer.get_vocab.return_value = self.assistant_vocab - self.assistant_model_device = "cpu" + self.assistant_model_device = torch_device self.target_vocab_size = 6 # Instantiate the class under test @@ -105,7 +106,7 @@ def setUp(self): self.assistant_tokenizer = MockTokenizer({"hello": 0, "world": 1, "foo": 2}) self.other_target_tokenizer = MockTokenizer({"foo": 2, "bar": 3}) self.other_assistant_tokenizer = MockTokenizer({"baz": 4, "qux": 5}) - self.assistant_model_device = "cpu" + self.assistant_model_device = torch_device self.target_vocab_size = 6 def test_same_instance_for_same_tokenizers(self): @@ -227,12 +228,11 @@ def get_translator(): class TestUniversalSpeculativeDecoding(unittest.TestCase): - device = "cuda" if torch.cuda.is_available() else "cpu" @classmethod def setUpClass(cls): cls.assistant_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to( - cls.device + torch_device ) cls.main_tokenizer = AutoTokenizer.from_pretrained("allenai/Llama-3.1-Tulu-3-8B-SFT") cls.assistant_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")