188 changes: 4 additions & 184 deletions src/transformers/generation/candidate_generator.py
@@ -14,8 +14,6 @@
 # limitations under the License.
 
 import copy
-import threading
-import weakref
 from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
 
 import numpy as np
@@ -29,10 +27,10 @@

 from ..cache_utils import DynamicCache
 from ..pytorch_utils import isin_mps_friendly
+from ..utils import AssistantToTargetTranslator
 from .logits_process import (
     LogitsProcessorList,
     MinLengthLogitsProcessor,
-    SuppressTokensLogitsProcessor,
 )


@@ -616,181 +614,6 @@ def _process_assistant_outputs(

         return new_target_ids
 
-
-class AssistantToTargetTranslator:
-    """
-    Translates token ids and logits between assistant and target model vocabularies. This class is used to handle
-    vocabulary mismatches when using different tokenizers for the assistant and target models in speculative decoding,
-    as introduced in the paper "Lossless Speculative Decoding Algorithms for Heterogeneous Vocabularies"
-    (https://www.arxiv.org/abs/2502.05202).
-    It maintains mappings between the two vocabularies and handles token/logit conversion.
-
-    Args:
-        target_tokenizer (`PreTrainedTokenizerBase`):
-            The tokenizer used by the target (main) model.
-        assistant_tokenizer (`PreTrainedTokenizerBase`):
-            The tokenizer used by the assistant model.
-        assistant_model_device (`str`, defaults to "cpu"):
-            The device where the assistant model is located. Used for placing tensors.
-        target_vocab_size (`int`, *optional*):
-            The size of the target model's vocabulary. If not provided, will be inferred from the target tokenizer.
-    """
-
-    FILTER_VALUE: float = -float("Inf")  # The value used to filter out unmapped tokens in the logits.
-    SUPPRESS_TOKEN_ID: int = -1  # The ID used to mark suppressed tokens in the mapping.
-
-    def __init__(
-        self,
-        target_tokenizer: "PreTrainedTokenizerBase",
-        assistant_tokenizer: "PreTrainedTokenizerBase",
-        assistant_model_device: str = "cpu",
-        target_vocab_size: Optional[int] = None,
-    ):
-        self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer
-        self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer
-        self._assistant_model_device: str = assistant_model_device
-        if target_vocab_size is None:
-            self.target_vocab_size: int = len(self._target_tokenizer.get_vocab())
-        else:
-            self.target_vocab_size: int = target_vocab_size
-        self._assistant_to_target_input_ids, self.target_to_assistant_input_ids = (
-            self._get_assistant_to_target_input_ids()
-        )
-        self._suppress_input_ids: list[int] = self._get_suppress_input_ids()
-        self.logits_processors: Optional[LogitsProcessorList] = None
-        if len(self._suppress_input_ids) > 0:
-            # len(self._suppress_input_ids) = 0 if the assistant vocab is a subset of the target vocab
-            self.logits_processors = LogitsProcessorList(
-                [SuppressTokensLogitsProcessor(self._get_suppress_input_ids(), self._assistant_model_device)]
-            )
-
-    def _get_assistant_to_target_input_ids(self):
-        target_vocab = self._target_tokenizer.get_vocab()
-        assistant_vocab = self._assistant_tokenizer.get_vocab()
-
-        space_str = " "
-        target_space_ids = self._target_tokenizer(space_str, add_special_tokens=False)["input_ids"]
-        if len(target_space_ids) > 0:
-            target_space_sign = self._target_tokenizer.convert_ids_to_tokens(target_space_ids)[0][0]
-
-            assistant_space_ids = self._assistant_tokenizer(space_str, add_special_tokens=False)["input_ids"]
-            if len(assistant_space_ids) > 0:
-                assistant_space_sign = self._assistant_tokenizer.convert_ids_to_tokens(assistant_space_ids)[0][0]
-
-                if target_space_sign != assistant_space_sign:
-                    # If the assistant tokenizer has a different space sign than the target tokenizer,
-                    # we need to replace the assistant space sign with the target space sign in the assistant_vocab.
-                    assistant_vocab = {
-                        (
-                            tok.replace(assistant_space_sign, target_space_sign, 1)
-                            if tok.startswith(assistant_space_sign)
-                            else tok
-                        ): idx
-                        for tok, idx in assistant_vocab.items()
-                    }
-
-        max_assistant_index = max(assistant_vocab.values())
-        assistant_to_target_input_ids = torch.full((max_assistant_index + 1,), self.SUPPRESS_TOKEN_ID, dtype=int)
-        target_to_assistant_input_ids: Dict[int, int] = {}
-        for tok, assistant_id in assistant_vocab.items():
-            target_id = target_vocab.get(tok)
-            if target_id is not None:
-                assistant_to_target_input_ids[assistant_id] = target_id
-                target_to_assistant_input_ids[target_id] = assistant_id
-        return assistant_to_target_input_ids.to(self._assistant_model_device), target_to_assistant_input_ids
-
-    def _get_suppress_input_ids(self) -> list[int]:
-        """
-        Get the input ids that are in the assistant vocab but not in the target vocab.
-        """
-        return torch.where(self._assistant_to_target_input_ids == self.SUPPRESS_TOKEN_ID)[0]
-
-    def get_target_ids(
-        self, assistant_input_ids, target_input_ids, assistant_candidate_ids: torch.LongTensor
-    ) -> torch.LongTensor:
-        """
-        Return the target candidate ids that correspond to the assistant candidate ids.
-        Note that we have already the target ids for the prompt and we only need to find the target ids for the new tokens.
-        Moreover, assistant ids of the original prompt does not necessarily appear in _assistant_to_target_input_ids.
-        """
-
-        num_new_tokens = len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1]
-        if num_new_tokens == 0:
-            return target_input_ids
-        else:
-            transformed_slice = self._assistant_to_target_input_ids[assistant_candidate_ids[0, -num_new_tokens:]]
-            return torch.cat((target_input_ids, transformed_slice.unsqueeze(0)), dim=1)
-
-    def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatTensor:
-        """
-        Return the target logits that correspond to the assistant logits.
-        """
-
-        target_shape: tuple[int, ...] = (*assistant_logits.shape[:-1], self.target_vocab_size)
-        target_logits: torch.FloatTensor = torch.full(target_shape, self.FILTER_VALUE).to(self._assistant_model_device)
-        # Mask for valid indices
-        assistant_indices_mask = self._assistant_to_target_input_ids != self.SUPPRESS_TOKEN_ID
-        # Exclude invalid indices
-        target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_indices_mask]
-        valid_assistant_logits = assistant_logits[..., : self._assistant_to_target_input_ids.shape[0]]
-
-        target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask]
-
-        return target_logits
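The class removed above is, at its core, an integer lookup table plus a scatter into the target vocabulary. A minimal standalone sketch of that mechanism (toy vocabularies; every name below is illustrative, not transformers API):

```python
import torch

FILTER_VALUE = -float("Inf")
SUPPRESS_TOKEN_ID = -1

# Toy vocabularies (token -> id), assumed purely for illustration.
target_vocab = {"hello": 0, "world": 1, "foo": 2, "bar": 3}
assistant_vocab = {"hello": 0, "world": 1, "baz": 2}

# Assistant-id -> target-id table; assistant tokens with no target
# counterpart keep SUPPRESS_TOKEN_ID and must be suppressed.
a2t = torch.full((max(assistant_vocab.values()) + 1,), SUPPRESS_TOKEN_ID, dtype=torch.long)
for tok, a_id in assistant_vocab.items():
    t_id = target_vocab.get(tok)
    if t_id is not None:
        a2t[a_id] = t_id
print(a2t)  # tensor([ 0,  1, -1]) -- "baz" is unmapped

# Scatter assistant logits into a target-vocab-sized tensor, mirroring
# what get_target_logits does with FILTER_VALUE for unmapped slots.
assistant_logits = torch.tensor([[2.0, 1.0, 0.5]])
target_logits = torch.full((1, len(target_vocab)), FILTER_VALUE)
valid = a2t != SUPPRESS_TOKEN_ID
target_logits[..., a2t[valid]] = assistant_logits[..., valid]
print(target_logits)  # tensor([[2., 1., -inf, -inf]])
```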


-class AssistantVocabTranslatorCache:
-    """
-    Cache for `AssistantToTargetTranslator` instances. The instances are computed at
-    pre-processing time, and this cache allows us to avoid recomputing them.
-    """
-
-    _lock = threading.Lock()
-    _cache = weakref.WeakKeyDictionary()
-
-    @classmethod
-    def get_translator(
-        cls,
-        target_tokenizer: "PreTrainedTokenizerBase",
-        assistant_tokenizer: "PreTrainedTokenizerBase",
-        assistant_model_device: str = "cpu",
-        target_vocab_size: Optional[int] = None,
-    ) -> AssistantToTargetTranslator:
-        with cls._lock:
-            assistant_dict = cls._cache.get(target_tokenizer)
-            if assistant_dict is None:
-                assistant_dict = weakref.WeakKeyDictionary()
-                cls._cache[target_tokenizer] = assistant_dict
-
-            mapping = assistant_dict.get(assistant_tokenizer)
-            if mapping is None:
-                mapping = AssistantToTargetTranslator(
-                    target_tokenizer, assistant_tokenizer, assistant_model_device, target_vocab_size
-                )
-                assistant_dict[assistant_tokenizer] = mapping
-
-            return mapping
-
-    @classmethod
-    def cleanup(cls):
-        """
-        Clean up dead references in the cache.
-        This removes entries where either the target_tokenizer or assistant_tokenizer
-        has been garbage collected.
-        """
-        with cls._lock:
-            # Remove entries from the outer cache where the target_tokenizer is no longer alive
-            dead_keys = [key for key in cls._cache if key is None]
-            for key in dead_keys:
-                del cls._cache[key]
-
-            # For each assistant_dict, remove entries where assistant_tokenizer is no longer alive
-            for assistant_dict in cls._cache.values():
-                dead_keys = [key for key in assistant_dict if key is None]
-                for key in dead_keys:
-                    del assistant_dict[key]
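The removed cache keys on tokenizer objects through `weakref.WeakKeyDictionary`, so entries vanish on their own once a tokenizer is garbage collected; the `cleanup` pass above is a belt-and-braces sweep on top of that. A self-contained demonstration of the weak-key property (`Tok` is a stand-in object, not a real tokenizer):

```python
import gc
import weakref

class Tok:
    """Stand-in for a tokenizer; only needs to be weak-referenceable."""

cache = weakref.WeakKeyDictionary()
tok = Tok()
cache[tok] = "translator"
print(len(cache))  # 1

del tok       # drop the last strong reference
gc.collect()  # force collection (immediate under CPython refcounting anyway)
print(len(cache))  # 0 -- the entry disappeared with its key
```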


 class UniversalSpeculativeDecodingGenerator(AssistedCandidateGeneratorDifferentTokenizers):
     """
     `CandidateGenerator` class to be used for Universal Speculative Decoding (USD): speculative decoding with different tokenizers
@@ -805,14 +628,12 @@ def __init__(
         assistant_tokenizer: "PreTrainedTokenizerBase",
         generation_config: "GenerationConfig",
         model_kwargs: Dict,
-        target_vocab_size: int,
+        atm_translator: AssistantToTargetTranslator,
         inputs_tensor: Optional[torch.Tensor] = None,
         logits_processor: "LogitsProcessorList" = None,
     ):
         # Initialize translator before parent class
-        self._atm_translator = AssistantVocabTranslatorCache.get_translator(
-            target_tokenizer, assistant_tokenizer, assistant_model.device, target_vocab_size
-        )
+        self._atm_translator = atm_translator
         super().__init__(
             input_ids,
             assistant_model,
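With this signature, the generator receives a ready-made `atm_translator` instead of building one from `target_vocab_size`, so the caller now owns translator construction. A hedged before/after sketch of the call shape, kept in comments because none of the surrounding objects can be constructed here (and where `AssistantVocabTranslatorCache` lives after this PR is an assumption, not shown in this diff):

```python
# Before: the generator built its own translator internally.
#
#     generator = UniversalSpeculativeDecodingGenerator(
#         input_ids, assistant_model, target_tokenizer, assistant_tokenizer,
#         generation_config, model_kwargs, target_vocab_size=target_vocab_size,
#     )
#
# After: build (or fetch) the translator once, then inject it.
#
#     atm_translator = AssistantVocabTranslatorCache.get_translator(
#         target_tokenizer, assistant_tokenizer, assistant_model.device, target_vocab_size
#     )
#     generator = UniversalSpeculativeDecodingGenerator(
#         input_ids, assistant_model, target_tokenizer, assistant_tokenizer,
#         generation_config, model_kwargs, atm_translator=atm_translator,
#     )
```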
Expand All @@ -826,7 +647,6 @@ def __init__(
         # Track sequence lengths and previous assistant IDs
         self._target_seq_len_with_candidates: int = 0
         self._prev_assistant_ids: Optional[torch.LongTensor] = None
-        self.target_vocab_size = target_vocab_size
 
     def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
         """
@@ -848,7 +668,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,

         # Generate and process outputs using translator
         if self._atm_translator.logits_processors is not None:
-            generation_args["logits_processor"] = self._atm_translator.logits_processors
+            generation_args["logits_processor"].append(self._atm_translator.logits_processors)
         self._prev_assistant_ids, assistant_candidate_logits = self._generate_candidates(generation_args)
 
         # Use translator to convert tokens and logits
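The final one-line change is behavioral, not cosmetic: assignment used to overwrite whatever logits processors the caller had configured, while `append` keeps them and adds the translator's suppression processor on top. A toy illustration with plain lists standing in for `LogitsProcessorList`:

```python
existing = ["min_length_processor"]      # processors the caller already set up
suppress = "suppress_tokens_processor"   # processor supplied by the translator

# Old behavior: assignment drops the caller's processors.
args = {"logits_processor": list(existing)}
args["logits_processor"] = [suppress]
print(args["logits_processor"])  # ['suppress_tokens_processor']

# New behavior: append preserves both.
args = {"logits_processor": list(existing)}
args["logits_processor"].append(suppress)
print(args["logits_processor"])  # ['min_length_processor', 'suppress_tokens_processor']
```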