From 84101116d57b359839198ed7cd13ec149ca894a7 Mon Sep 17 00:00:00 2001 From: Gaurav Jain Date: Fri, 20 Dec 2024 06:23:53 +0000 Subject: [PATCH] Remove unused imports and fix style using `make style` --- tests/generation/test_candidate_generator.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/generation/test_candidate_generator.py b/tests/generation/test_candidate_generator.py index f4e7bba1dd28..80065c160804 100644 --- a/tests/generation/test_candidate_generator.py +++ b/tests/generation/test_candidate_generator.py @@ -1,24 +1,20 @@ import gc -import logging import threading import unittest import weakref from unittest.mock import MagicMock -from zmq import device - import numpy as np import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig from transformers.generation.candidate_generator import ( AssistantToTargetTranslator, AssistantVocabTranslatorCache, AssistedCandidateGeneratorDifferentTokenizers, - UniversalSpeculativeDecodingGenerator + UniversalSpeculativeDecodingGenerator, ) -from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig - class TestAssistedCandidateGeneratorDifferentTokenizers(unittest.TestCase): def test_no_intersection(self): @@ -313,7 +309,7 @@ def test_basic_generation(self): def test_mismatched_vocabularies(self): """Test handling of mismatched vocabularies between models""" # Create input with tokens present in main but not assistant vocab - # Find a token that is not in the assistant tokenizer but in + # Find a token that is not in the assistant tokenizer but in # the main tokenizer. missing_token = next( token for token in self.main_tokenizer.get_vocab() @@ -321,7 +317,7 @@ def test_mismatched_vocabularies(self): token not in self.main_tokenizer.all_special_tokens and "reserved_" not in token ) - input_ids = torch.tensor([[self.main_tokenizer.convert_tokens_to_ids(missing_token)]]) + input_ids = torch.tensor([[self.main_tokenizer.convert_tokens_to_ids(missing_token)]]) self.generator.input_ids = input_ids candidates, scores = self.generator.get_candidates(input_ids) self.assertIsNotNone(candidates)