Commit 9de898e

ydshieh authored and ArthurZucker committed

Update after #41007 (#41014)

* fix

* fix

---------

Co-authored-by: ydshieh <[email protected]>
1 parent e5a9a1d commit 9de898e

File tree

1 file changed: +2 -18 lines changed


tests/models/phimoe/test_modeling_phimoe.py

Lines changed: 2 additions & 18 deletions
@@ -14,7 +14,6 @@
 
 """Testing suite for the PyTorch PhiMoE model."""
 
-import copy
 import unittest
 
 from parameterized import parameterized
@@ -59,6 +58,7 @@ def forward(
             past_key_values=self.cache,
         ).logits
 
+    @torch.no_grad()
     @staticmethod
     def generate(model: PhimoeForCausalLM, prompt_tokens: torch.LongTensor, max_seq_len: int) -> list[int]:
         model = PhimoeMiniWithStaticCache(model, 1, max_seq_len + prompt_tokens.shape[-1])
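
The new @torch.no_grad() decorator disables gradient tracking for the whole static-cache generate helper. A minimal sketch of what this buys (illustrative only; greedy_step is a hypothetical helper, not part of the test file): without it, each forward pass in the decoding loop would record an autograd graph that is never backpropagated, holding activations in memory for no reason.

import torch

def greedy_step(model, input_ids: torch.LongTensor) -> torch.LongTensor:
    # Same effect as decorating the helper with @torch.no_grad(): no autograd
    # graph is recorded for the forward pass, so activations are freed eagerly.
    with torch.no_grad():
        logits = model(input_ids).logits
    # Greedy decoding: pick the highest-scoring token at the last position.
    return logits[:, -1].argmax(dim=-1, keepdim=True)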
@@ -194,19 +194,6 @@ def test_phimoe_instruct_generation(self):
 
     def test_phimoe_instruct_with_static_cache(self):
         model = self.get_model()
-        # Can't run with the real checkpoint, even if offloaded. Let's just use a tiny dummy one
-        config = copy.deepcopy(model.config)
-        config.num_hidden_layers = 2
-        # make `head_dim = 128`
-        config.hidden_size = 512
-        config.num_attention_heads = 4
-        config.num_key_value_heads = 1
-        config.intermediate_size = 512
-        config.max_position_embeddings = 64
-        config.num_local_experts = 4
-        torch.manual_seed(42)
-        model = PhimoeForCausalLM(config).to(torch_device)
-        model.eval()
         tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-MoE-instruct")
 
         messages = [
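
For context, the deleted lines followed a common smoke-test pattern: shrink the real config so a randomly initialized model is cheap to run. The deleted "# make `head_dim = 128`" comment refers to the derived value head_dim = hidden_size / num_attention_heads = 512 / 4 = 128. A minimal standalone sketch of that pattern (hypothetical; the updated test instead takes its model from self.get_model()):

import torch

from transformers import PhimoeConfig, PhimoeForCausalLM

config = PhimoeConfig(
    num_hidden_layers=2,
    hidden_size=512,
    num_attention_heads=4,  # head_dim = 512 / 4 = 128
    num_key_value_heads=1,
    intermediate_size=512,
    max_position_embeddings=64,
    num_local_experts=4,
)
torch.manual_seed(42)  # deterministic random weights
tiny_model = PhimoeForCausalLM(config).eval()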
@@ -221,12 +208,9 @@ def test_phimoe_instruct_with_static_cache(self):
         )
 
         response_tokens = PhimoeMiniWithStaticCache.generate(model, inputs, max_seq_len=30)
-
         output_text = tokenizer.batch_decode(torch.tensor([response_tokens], dtype=torch.long, device=torch_device))
 
-        # This is dummy outputs. We actually check if it could run with static cache, not the output quality.
         EXPECTED_OUTPUT = [
-            "<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> awards"
+            "<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> C"
         ]
-
         self.assertListEqual(output_text, EXPECTED_OUTPUT)
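
As the removed comment noted, EXPECTED_OUTPUT is a dummy string: the test checks that static-cache generation runs, not output quality. A rough equivalent through the public API might look like the sketch below (assumptions: the cache_implementation="static" generate() option, and that the real checkpoint fits on the available device, which is exactly what the test avoids relying on):

import torch

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-MoE-instruct")
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3.5-MoE-instruct", torch_dtype=torch.bfloat16, device_map="auto"
)

messages = [
    {"role": "system", "content": "You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user."},
    {"role": "user", "content": "Can you provide ways to eat combinations of bananas and dragonfruits?"},
]
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

# cache_implementation="static" makes generate() allocate a StaticCache
# (fixed-size KV buffers) instead of the default dynamically growing cache.
with torch.no_grad():
    output = model.generate(inputs, max_new_tokens=30, cache_implementation="static")
print(tokenizer.batch_decode(output))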
