Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions tests/generation/test_candidate_generator.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,20 @@
import gc
import logging
import threading
import unittest
import weakref
from unittest.mock import MagicMock

from zmq import device

import numpy as np
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.generation.candidate_generator import (
AssistantToTargetTranslator,
AssistantVocabTranslatorCache,
AssistedCandidateGeneratorDifferentTokenizers,
UniversalSpeculativeDecodingGenerator
UniversalSpeculativeDecodingGenerator,
)

from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig


class TestAssistedCandidateGeneratorDifferentTokenizers(unittest.TestCase):
def test_no_intersection(self):
Expand Down Expand Up @@ -313,15 +309,15 @@ def test_basic_generation(self):
def test_mismatched_vocabularies(self):
"""Test handling of mismatched vocabularies between models"""
# Create input with tokens present in main but not assistant vocab
# Find a token that is not in the assistant tokenizer but in
# Find a token that is not in the assistant tokenizer but in
# the main tokenizer.
missing_token = next(
token for token in self.main_tokenizer.get_vocab()
if token not in self.assistant_tokenizer.get_vocab() and
token not in self.main_tokenizer.all_special_tokens and
"reserved_" not in token
)
input_ids = torch.tensor([[self.main_tokenizer.convert_tokens_to_ids(missing_token)]])
input_ids = torch.tensor([[self.main_tokenizer.convert_tokens_to_ids(missing_token)]])
self.generator.input_ids = input_ids
candidates, scores = self.generator.get_candidates(input_ids)
self.assertIsNotNone(candidates)
Expand Down