diff --git a/tests/generation/test_candidate_generator.py b/tests/generation/test_candidate_generator.py index f4e7bba1dd28..80065c160804 100644 --- a/tests/generation/test_candidate_generator.py +++ b/tests/generation/test_candidate_generator.py @@ -1,24 +1,20 @@ import gc -import logging import threading import unittest import weakref from unittest.mock import MagicMock -from zmq import device - import numpy as np import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig from transformers.generation.candidate_generator import ( AssistantToTargetTranslator, AssistantVocabTranslatorCache, AssistedCandidateGeneratorDifferentTokenizers, - UniversalSpeculativeDecodingGenerator + UniversalSpeculativeDecodingGenerator, ) -from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig - class TestAssistedCandidateGeneratorDifferentTokenizers(unittest.TestCase): def test_no_intersection(self): @@ -313,7 +309,7 @@ def test_basic_generation(self): def test_mismatched_vocabularies(self): """Test handling of mismatched vocabularies between models""" # Create input with tokens present in main but not assistant vocab - # Find a token that is not in the assistant tokenizer but in + # Find a token that is not in the assistant tokenizer but in # the main tokenizer. missing_token = next( token for token in self.main_tokenizer.get_vocab() @@ -321,7 +317,7 @@ def test_mismatched_vocabularies(self): token not in self.main_tokenizer.all_special_tokens and "reserved_" not in token ) - input_ids = torch.tensor([[self.main_tokenizer.convert_tokens_to_ids(missing_token)]]) + input_ids = torch.tensor([[self.main_tokenizer.convert_tokens_to_ids(missing_token)]]) self.generator.input_ids = input_ids candidates, scores = self.generator.get_candidates(input_ids) self.assertIsNotNone(candidates)