tests/test_universal_assisted_generation.py (96 additions, 0 deletions)
@@ -0,0 +1,96 @@
import unittest

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from transformers.generation.candidate_generator import UniversalSpeculativeDecodingGenerator


device = "cuda" if torch.cuda.is_available() else "cpu"

class TestUniversalSpeculativeDecoding(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Set up main and assistant models
        cls.main_model = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-3.2-1B-Instruct").to(device)

keyboardAnt (Owner): Does it take <5s to load this 1B model? (Please see @gante's comment)

Collaborator (Author): No, it takes about 33 seconds on a T4 machine. I think we can just add the tag @slow as mentioned in the comment. Wdyt?

keyboardAnt (Owner), Dec 15, 2024: What about using smaller models? There are a few examples of fast models used in existing Hugging Face tests. @slow tests run less frequently, so I suggest striving for faster tests.

keyboardAnt (Owner), Dec 16, 2024: @gauravj14 @gauravjain14 Models for testing: https://huggingface.co/hf-internal-testing. For example, hf-internal-testing/tiny-random-gpt2 as used here.

(A minimal sketch of both options, the @slow marker and a tiny test checkpoint, follows the end of the test file below.)

        cls.assistant_model = AutoModelForCausalLM.from_pretrained(
            "hf-internal-testing/tiny-random-gpt2").to(device)
        cls.main_tokenizer = AutoTokenizer.from_pretrained(
            "meta-llama/Llama-3.2-1B-Instruct")
        cls.assistant_tokenizer = AutoTokenizer.from_pretrained(
            "hf-internal-testing/tiny-random-gpt2")
        cls.generation_config = GenerationConfig()

        # Ensure required tokens exist
        if cls.main_tokenizer.pad_token_id is None:
            cls.main_tokenizer.pad_token_id = cls.main_tokenizer.eos_token_id
        if cls.main_tokenizer.bos_token_id is None:
            cls.main_tokenizer.bos_token_id = cls.main_tokenizer.eos_token_id

    def setUp(self):
        self.input_ids = torch.tensor([[1, 2, 3]]).to(device)
        self.model_kwargs = {
            "attention_mask": torch.ones_like(self.input_ids).to(device),
        }
        self.generator = UniversalSpeculativeDecodingGenerator(
            input_ids=self.input_ids,
            assistant_model=self.assistant_model,
            target_tokenizer=self.main_tokenizer,
            assistant_tokenizer=self.assistant_tokenizer,
            generation_config=self.generation_config,
            model_kwargs=self.model_kwargs,
            target_vocab_size=self.main_tokenizer.vocab_size,
        )

    def test_basic_generation(self):
        """Test basic speculative decoding works"""
        input_text = "The quick brown fox"
        input_ids = self.main_tokenizer.encode(input_text, return_tensors="pt")
        self.generator.input_ids = input_ids
        candidates, scores = self.generator.get_candidates(input_ids)

        self.assertIsNotNone(candidates)
        self.assertIsNotNone(scores)
        self.assertTrue(torch.is_tensor(candidates))
        self.assertTrue(torch.is_tensor(scores))

    def test_mismatched_vocabularies(self):
        """Test handling of mismatched vocabularies between models"""
        # Find a token that exists in the main tokenizer's vocab but not in
        # the assistant tokenizer's vocab (skipping special/reserved tokens).
        missing_token = next(
            token for token in self.main_tokenizer.get_vocab()
            if token not in self.assistant_tokenizer.get_vocab() and
            token not in self.main_tokenizer.all_special_tokens and
            "reserved_" not in token
        )
        input_ids = torch.tensor(
            [[self.main_tokenizer.convert_tokens_to_ids(missing_token)]])
        self.generator.input_ids = input_ids
        candidates, scores = self.generator.get_candidates(input_ids)
        self.assertIsNotNone(candidates)

    def test_speculation_depth(self):
        """Test different speculation depths"""
        input_ids = self.main_tokenizer.encode("Test text", return_tensors="pt")
        self.generator.input_ids = input_ids

        for depth in [1, 8, 17]:
            self.generator.num_assistant_tokens = depth
            candidates, scores = self.generator.get_candidates(input_ids)
            self.assertLessEqual(
                candidates.shape[1] - input_ids.shape[1], depth
            )

    def test_device_consistency(self):
        """Test handling of inputs on different devices"""
        if torch.cuda.is_available():
            input_ids = torch.tensor([[1, 2, 3]]).to(
                self.generator.assistant_model.device)
            self.generator.input_ids = input_ids
            candidates, scores = self.generator.get_candidates(input_ids)
            self.assertEqual(candidates.device, input_ids.device)


if __name__ == '__main__':
    unittest.main()
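
To make the trade-off in the review thread above concrete, here is a minimal sketch of the two options discussed, assuming the standard transformers test utilities. The tiny Llama checkpoint name is an illustrative assumption; the assistant model in the test already uses hf-internal-testing/tiny-random-gpt2.

import unittest

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.testing_utils import slow


# Option A: keep the 1B target model but gate the suite behind the slow marker.
# `slow` wraps unittest.skipUnless, so it can decorate the whole class; the
# tests then run only when RUN_SLOW=1 (e.g. in scheduled CI jobs).
@slow
class TestUniversalSpeculativeDecoding(unittest.TestCase):
    ...


# Option B: swap the 1B target model for a tiny random checkpoint so setUpClass
# stays fast and the tests can run on every push. The checkpoint name below is
# an assumption; any tiny causal LM from https://huggingface.co/hf-internal-testing
# with a compatible tokenizer would work.
TINY_TARGET = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # assumed checkpoint
main_model = AutoModelForCausalLM.from_pretrained(TINY_TARGET)
main_tokenizer = AutoTokenizer.from_pretrained(TINY_TARGET)

Option A moves the loading cost to the nightly runs, while Option B keeps the tests in every run, which is what the reviewer suggests.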