-
Notifications
You must be signed in to change notification settings - Fork 31.2k
Closed
Labels
Description
System Info
Name: torch
Version: 2.5.0.dev20240716
Name: transformers
Version: 4.44.0.dev0
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
I'm trying to run the Phi-3 model on an edge device via ExecuTorch, where I can only use StaticCache. However, the current Phi-3 model fails to work with StaticCache.
To reproduce this issue, please run the following script:
import torch
from transformers import Phi3ForCausalLM, StaticCache, AutoTokenizer
# Token id that stops generation: the decode loop in
# _generate_token_with_kv_cache exits as soon as this id is produced.
# NOTE(review): presumably the Phi-3 tokenizer's end-of-text id — confirm
# against the tokenizer config.
end_of_text_token = 32000
class Phi3Mini(torch.nn.Module):
    """Pair a Phi-3 causal LM with a ``StaticCache`` built at construction.

    The cache is created once in ``__init__`` from the model's config, device
    and dtype, and the same cache object is handed to the model on every
    ``forward`` call.
    """

    def __init__(self, model: Phi3ForCausalLM, max_batch_size: int, max_seq_len: int):
        """Store ``model`` and build its static key/value cache.

        Args:
            model: The causal LM to wrap.
            max_batch_size: Passed to ``StaticCache`` as ``max_batch_size``.
            max_seq_len: Passed to ``StaticCache`` as ``max_cache_len``.
        """
        super().__init__()
        self.model = model
        cache_options = {
            "config": model.config,
            "max_batch_size": max_batch_size,
            "max_cache_len": max_seq_len,
            "device": model.device,
            "dtype": model.dtype,
        }
        self.cache = StaticCache(**cache_options)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        cache_position: torch.LongTensor = None,
    ) -> torch.FloatTensor:
        """Run the wrapped model against the shared cache and return its logits.

        Args:
            input_ids: Token ids forwarded to the model unchanged.
            cache_position: Cache positions forwarded to the model unchanged.

        Returns:
            The ``logits`` attribute of the model's output.
        """
        outputs = self.model.forward(
            input_ids=input_ids,
            cache_position=cache_position,
            past_key_values=self.cache,
            use_cache=True,
            return_dict=True,
        )
        return outputs.logits
def _generate_token_with_kv_cache(seq_len, model, prompt_tokens):
    """Greedily generate up to ``seq_len`` tokens using a static KV cache.

    The prompt is fed to the model one token at a time to fill the cache.
    Tokens are then decoded one at a time, each picked by ``argmax`` over the
    logits of the last position, until ``end_of_text_token`` is produced or
    ``seq_len`` tokens have been generated. Every token id is printed as soon
    as it is chosen.

    Args:
        seq_len: Maximum number of tokens to generate.
        model: The causal LM; it is wrapped in ``Phi3Mini`` with a cache of
            length ``seq_len + prompt_len`` and batch size 1.
        prompt_tokens: Tensor of prompt token ids whose last dimension is the
            prompt length.

    Returns:
        The list of generated token ids as Python ints (the end-of-text id is
        included when it was produced).

    Raises:
        ValueError: If ``prompt_tokens`` contains no tokens.
    """
    prompt_len = prompt_tokens.shape[-1]
    if prompt_len == 0:
        # Without at least one prompt token there are no logits to sample from.
        raise ValueError("prompt_tokens must contain at least one token")

    print("Generating tokens:", end="", flush=True)
    # Use a distinct name so the `model` argument is not shadowed.
    wrapper = Phi3Mini(model, 1, seq_len + prompt_len)
    device = wrapper.model.device
    prompt_tokens = prompt_tokens.to(device)

    # Prefill: one token per step. `cache_position` must hold the position of
    # each token in `input_ids`, so for a single token it is the single index
    # `input_pos` — not the range of all earlier positions.
    for input_pos in range(prompt_len):
        result = wrapper.forward(
            input_ids=prompt_tokens[:, input_pos : input_pos + 1],
            cache_position=torch.tensor(
                [input_pos], dtype=torch.long, device=device
            ),
        )
    current_token = torch.argmax(result[:, -1, :], dim=-1).item()
    print(f" {current_token}", end="", flush=True)
    generated_tokens = [current_token]

    while current_token != end_of_text_token and len(generated_tokens) < seq_len:
        # The token being fed is the most recently generated one; it sits
        # right after the prompt and the previously generated tokens.
        position = prompt_len + len(generated_tokens) - 1
        result = wrapper.forward(
            input_ids=torch.tensor(
                [[current_token]], dtype=torch.long, device=device
            ),
            cache_position=torch.tensor(
                [position], dtype=torch.long, device=device
            ),
        )
        current_token = torch.argmax(result[:, -1, :], dim=-1).item()
        print(f" {current_token}", end="", flush=True)
        generated_tokens.append(current_token)

    print("", flush=True)
    return generated_tokens
def main(
    prompt,
    seq_len,
):
    """Load Phi-3 mini, generate a continuation of ``prompt`` and print it.

    Args:
        prompt: Text to tokenize and feed to the model.
        seq_len: Maximum number of tokens to generate.
    """
    torch.manual_seed(42)

    checkpoint = "microsoft/Phi-3-mini-4k-instruct"
    model = Phi3ForCausalLM.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    prompt_ids = tokenizer.encode(prompt, return_tensors="pt")
    output_ids = _generate_token_with_kv_cache(seq_len, model, prompt_ids)

    # Decode first, then print, so the formatting call stays on one line.
    response = tokenizer.decode(
        output_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    print("Generated response: \n {}".format(response), flush=True)
if __name__ == "__main__":
    # Script entry point: run the reproduction with a fixed prompt and budget.
    main(prompt="Tell me a story", seq_len=128)
It fails with the following error:
/opt/anaconda3/envs/executorch/bin/python /Users/lunwenh/executorch/examples/models/phi-3-mini/test.py
Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00, 1.03it/s]
Generating tokens:You are not running the flash-attention implementation, expect numerical differences.
Traceback (most recent call last):
File "/Users/lunwenh/executorch/examples/models/phi-3-mini/test.py", line 92, in <module>
main(
File "/Users/lunwenh/executorch/examples/models/phi-3-mini/test.py", line 77, in main
generated_tokens = _generate_token_with_kv_cache(seq_len, model, tokens)
File "/Users/lunwenh/executorch/examples/models/phi-3-mini/test.py", line 38, in _generate_token_with_kv_cache
result = model.forward(
File "/Users/lunwenh/executorch/examples/models/phi-3-mini/test.py", line 24, in forward
return self.model.forward(
File "/opt/anaconda3/envs/executorch/lib/python3.10/site-packages/transformers/models/phi3/modeling_phi3.py", line 1207, in forward
outputs = self.model(
File "/opt/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1727, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/anaconda3/envs/executorch/lib/python3.10/site-packages/transformers/models/phi3/modeling_phi3.py", line 1002, in forward
layer_outputs = decoder_layer(
File "/opt/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1727, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/anaconda3/envs/executorch/lib/python3.10/site-packages/transformers/models/phi3/modeling_phi3.py", line 739, in forward
attn_outputs, self_attn_weights, present_key_value = self.self_attn(
File "/opt/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1727, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/anaconda3/envs/executorch/lib/python3.10/site-packages/transformers/models/phi3/modeling_phi3.py", line 405, in forward
raise ValueError(
ValueError: Attention weights should be of size (1, 32, 1, 1), but is torch.Size([1, 32, 1, 132])
Process finished with exit code 1
This happens because the current StaticCache implementation does not slice k_out and v_out on update; instead, it returns the whole cache, up to max_cache_len.
In the long term, #31421 and #30862 should resolve this problem by supporting StaticCache and dynamic length.
For now, removing this size check should make phi3 work with StaticCache.
Expected behavior
After removing the size check, the above mentioned script works well.