Fix PhimoeIntegrationTest #41007
Changes from 1 commit
```diff
@@ -14,12 +14,14 @@
 """Testing suite for the PyTorch PhiMoE model."""

+import copy
 import unittest

 from parameterized import parameterized

 from transformers import PhimoeConfig, StaticCache, is_torch_available
 from transformers.testing_utils import (
+    cleanup,
     require_torch,
     slow,
     torch_device,
 )
```
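The newly imported `cleanup` helper from `transformers.testing_utils` is what the test class calls between tests to release accelerator memory. As a rough illustration (a sketch of the observable behavior, not the library's actual implementation), it amounts to something like:

```python
import gc

import torch


def cleanup_sketch(device: str, gc_collect: bool = False) -> None:
    # Run Python garbage collection first so tensors with no remaining
    # references can actually be freed on the accelerator.
    if gc_collect:
        gc.collect()
    # Release cached allocator blocks back to the device.
    if device.startswith("cuda") and torch.cuda.is_available():
        torch.cuda.empty_cache()
```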
```diff
@@ -130,31 +132,47 @@ def test_model_rope_scaling_from_config(self, scaling_type):
 @slow
 @require_torch
 class PhimoeIntegrationTest(unittest.TestCase):
-    def test_model_phimoe_instruct_logits(self):
-        input_ids = {
-            "input_ids": torch.tensor(
-                [[1212, 318, 281, 1672, 2643, 290, 428, 318, 257, 1332]], dtype=torch.long, device=torch_device
-            )
-        }
+    model = None
+
+    @classmethod
+    def get_model(cls):
+        if cls.model is None:
+            cls.model = PhimoeForCausalLM.from_pretrained(
+                "microsoft/Phi-3.5-MoE-instruct", dtype="auto", device_map="auto"
+            )
+        return cls.model
+
+    @classmethod
+    def tearDownClass(cls):
+        del cls.model
+        cleanup(torch_device, gc_collect=True)
+
+    def setUp(self):
+        cleanup(torch_device, gc_collect=True)
+
+    def tearDown(self):
+        cleanup(torch_device, gc_collect=True)
+
+    def test_model_phimoe_instruct_logits(self):
+        input_ids = {"input_ids": torch.tensor([[1212, 318, 281, 1672]], dtype=torch.long, device=torch_device)}

-        model = PhimoeForCausalLM.from_pretrained("microsoft/Phi-3.5-MoE-instruct").to(torch_device)
+        model = self.get_model()
         model.eval()

-        output = model(**input_ids).logits
+        with torch.no_grad():
+            output = model(**input_ids).logits

-        EXPECTED_OUTPUT = torch.tensor([[-3.5312, -2.5000, -1.2734,  0.3555, -0.7578, -0.4727,  0.5977, -0.4316,
-                                          0.2256, -1.2188, -1.6797,  0.9961,  3.7656, 11.3125, -1.3828, -4.8438,
-                                         -5.7500, -1.9375,  0.7227, -0.3438, -0.2100, -0.4277, -0.0444, -0.5352,
-                                         -0.6406, -0.1016, -0.4258, -1.0234,  0.4297, -0.6250],
-                                        [-0.9883,  0.1455, -0.4902,  2.3594,  0.7031,  3.1406,  0.4375,  0.2559,
-                                          0.6172, -2.1094, -1.3359,  2.5938,  4.9062, 10.8125, -0.1094,  1.5781,
-                                         -4.9375,  0.7148, -0.0972,  1.7656, -0.0801,  0.2217,  0.1875, -0.4629,
-                                          1.5781,  0.3535,  0.0874,  0.6836, -0.0518, -1.2969]]).to(torch_device)  # fmt: skip
+        EXPECTED_OUTPUT = torch.tensor(
+            [
+                [-3.4844, -2.4531, -1.1719, 0.6055, -0.4922, -0.1001, 0.8086, -0.2422, 0.3477, -1.0078],
+                [-0.9766, 0.1631, -0.5508, 2.3594, 0.7031, 3.1719, 0.4141, 0.2305, 0.6055, -2.1250],
+            ]
+        ).to(device=torch_device, dtype=output.dtype)  # fmt: skip

-        torch.testing.assert_close(EXPECTED_OUTPUT, output[0, :2, :30], rtol=1e-4, atol=1e-4)
+        torch.testing.assert_close(output[0, :2, :10], EXPECTED_OUTPUT, rtol=1e-4, atol=1e-4)

     def test_phimoe_instruct_generation(self):
-        model = PhimoeForCausalLM.from_pretrained("microsoft/Phi-3.5-MoE-instruct")
+        model = self.get_model()
         tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-MoE-instruct")

         messages = [
```
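The heart of the fix: the Phi-3.5-MoE-instruct checkpoint is far too expensive to reload once per test, so the class loads it lazily into a class attribute and every test fetches the shared instance through `get_model`. A minimal, self-contained sketch of the same pattern with a cheap stand-in object (the names here are illustrative):

```python
import unittest


class ExpensiveResource:
    """Stand-in for a large model that is slow to construct."""

    def __init__(self):
        print("loading once")


class CachedResourceTest(unittest.TestCase):
    resource = None  # class attribute, shared by all tests in the class

    @classmethod
    def get_resource(cls):
        # Lazy: built on first use, then reused by every test method.
        if cls.resource is None:
            cls.resource = ExpensiveResource()
        return cls.resource

    @classmethod
    def tearDownClass(cls):
        # Drop the shared reference once the whole class has run.
        cls.resource = None

    def test_first_use_builds(self):
        self.assertIsNotNone(self.get_resource())

    def test_second_use_reuses(self):
        self.assertIs(self.get_resource(), self.get_resource())


if __name__ == "__main__":
    unittest.main()
```

`tearDownClass` drops the class-level reference so memory is reclaimed after the last test in the class, while the per-test `setUp`/`tearDown` cleanups catch anything an individual test leaked.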
```diff
@@ -166,17 +184,22 @@ def test_phimoe_instruct_generation(self):
         ]
         inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")

-        outputs = model.generate(inputs, max_new_tokens=32)
+        outputs = model.generate(inputs, max_new_tokens=10)
         output_text = tokenizer.batch_decode(outputs)

         EXPECTED_OUTPUT = [
-            "<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits are both delicious and nutritious fruits that can be combined in various ways to create tast"
+            "<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonf",
         ]

         self.assertListEqual(output_text, EXPECTED_OUTPUT)

     def test_phimoe_instruct_with_static_cache(self):
-        model = PhimoeForCausalLM.from_pretrained("microsoft/Phi-3.5-MoE-instruct")
+        model = self.get_model()
+        # Can't run with the real checkpoint, even if offloaded. Let's just use a tiny dummy one
+        config = copy.deepcopy(model.config)
+        config.num_hidden_layers = 2
+        torch.manual_seed(42)
+        model = type(model)(config)

         tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-MoE-instruct")

         messages = [
```
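For the static-cache test the real checkpoint cannot run on CI hardware even when offloaded, so the test keeps only the config, shrinks it to 2 layers, and re-instantiates the model class with random weights. A sketch of the same trick, building the config directly instead of pulling it off a pretrained model (so nothing is downloaded):

```python
import copy

import torch
from transformers import PhimoeConfig, PhimoeForCausalLM

# A default config stands in for model.config from the pretrained checkpoint.
full_config = PhimoeConfig()

tiny_config = copy.deepcopy(full_config)
tiny_config.num_hidden_layers = 2  # keep the architecture, drop most of the depth
# (for an even smaller sketch, hidden sizes and expert counts could be shrunk too)

torch.manual_seed(42)  # reproducible random weights across CI runs
tiny_model = PhimoeForCausalLM(tiny_config)  # random init, no checkpoint download
```

`type(model)(config)` in the diff performs the same re-instantiation without naming the class, which keeps the test correct if the model class is ever swapped out.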
```diff
@@ -188,12 +211,13 @@ def test_phimoe_instruct_with_static_cache(self):
         ]
         inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")

-        response_tokens = PhimoeMiniWithStaticCache.generate(model, inputs, 64)
+        response_tokens = PhimoeMiniWithStaticCache.generate(model, inputs, max_seq_len=10)

         output_text = tokenizer.batch_decode(torch.tensor([response_tokens], dtype=torch.long, device=torch_device))

+        # These are dummy outputs. We only check that the model runs with a static cache, not the output quality.
         EXPECTED_OUTPUT = [
-            "<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits are both delicious and nutritious fruits that can"
+            "<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|>ington"
         ]

         self.assertListEqual(output_text, EXPECTED_OUTPUT)
```
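`PhimoeMiniWithStaticCache` is a helper defined earlier in this test file that pre-allocates a fixed-size `StaticCache` and steps the model token by token; with a random 2-layer model the decoded text is gibberish, which is why the assertion only pins down that generation runs. Outside the test harness, `generate()` can allocate a static cache itself, as in this sketch (the checkpoint download is large):

```python
from transformers import AutoTokenizer, PhimoeForCausalLM

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-MoE-instruct")
model = PhimoeForCausalLM.from_pretrained(
    "microsoft/Phi-3.5-MoE-instruct", dtype="auto", device_map="auto"
)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
# cache_implementation="static" makes generate() pre-allocate the KV cache
# to a fixed maximum length instead of growing it dynamically.
outputs = model.generate(**inputs, max_new_tokens=10, cache_implementation="static")
print(tokenizer.batch_decode(outputs))
```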