
Cannot use StaticCache with Phi3 #32338

@helunwencser

Description


System Info

Name: torch
Version: 2.5.0.dev20240716

Name: transformers
Version: 4.44.0.dev0

Who can help?

@ArthurZucker
@zucchini-nlp

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I'm trying to run the Phi-3 model on an edge device via ExecuTorch, where only StaticCache can be used. However, the current Phi-3 model fails when used with StaticCache.

To reproduce this issue, please run the following script:

import torch
from transformers import Phi3ForCausalLM, StaticCache, AutoTokenizer

end_of_text_token = 32000

class Phi3Mini(torch.nn.Module):

    def __init__(self, model: Phi3ForCausalLM, max_batch_size: int, max_seq_len: int):
        super().__init__()
        self.model = model
        self.cache = StaticCache(
            config=model.config,
            max_batch_size=max_batch_size,
            max_cache_len=max_seq_len,
            device=self.model.device,
            dtype=self.model.dtype,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        cache_position: torch.LongTensor = None,
    ) -> torch.FloatTensor:
        return self.model.forward(
            input_ids=input_ids,
            use_cache=True,
            return_dict=True,
            past_key_values=self.cache,
            cache_position=cache_position,
        ).logits

def _generate_token_with_kv_cache(seq_len, model, prompt_tokens):
    print("Generating tokens:", end="", flush=True)

    model = Phi3Mini(model, 1, seq_len + prompt_tokens.shape[-1])

    for input_pos in range(prompt_tokens.shape[-1]):
        # Feed the prompt one token at a time; cache_position holds that token's absolute slot.
        result = model.forward(
            input_ids=prompt_tokens[:, input_pos : input_pos + 1],
            cache_position=torch.tensor([input_pos], device=model.model.device),
        )

    current_token = torch.argmax(result[:, -1, :], dim=-1).item()
    print(f" {current_token}", end="", flush=True)
    generated_tokens = [current_token]

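    while current_token != end_of_text_token and len(generated_tokens) < seq_len:
        result = model.forward(
            input_ids=torch.tensor(
                [[current_token]], dtype=torch.long, device=model.model.device
            ),
            # The token fed back here occupies the next free slot after the prompt
            # and the previously fed-back tokens.
            cache_position=torch.tensor(
                [prompt_tokens.shape[-1] + len(generated_tokens) - 1],
                device=model.model.device,
            ),
        )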
        current_token = torch.argmax(result[:, -1, :], dim=-1).item()
        print(f" {current_token}", end="", flush=True)
        generated_tokens.append(current_token)

    print("", flush=True)

    return generated_tokens


def main(
        prompt,
        seq_len,
):
    seed = 42
    torch.manual_seed(seed)
    model_name = "microsoft/Phi-3-mini-4k-instruct"
    model = Phi3ForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    tokens = tokenizer.encode(prompt, return_tensors="pt")

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        generated_tokens = _generate_token_with_kv_cache(seq_len, model, tokens)

    print(
        "Generated response: \n {}".format(
            tokenizer.decode(
                generated_tokens,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
        ),
        flush=True,
    )


if __name__ == "__main__":
    main(
        prompt="Tell me a story",
        seq_len=128
    )

It fails with the following error:

/opt/anaconda3/envs/executorch/bin/python /Users/lunwenh/executorch/examples/models/phi-3-mini/test.py 
Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.03it/s]
Generating tokens:You are not running the flash-attention implementation, expect numerical differences.
Traceback (most recent call last):
  File "/Users/lunwenh/executorch/examples/models/phi-3-mini/test.py", line 92, in <module>
    main(
  File "/Users/lunwenh/executorch/examples/models/phi-3-mini/test.py", line 77, in main
    generated_tokens = _generate_token_with_kv_cache(seq_len, model, tokens)
  File "/Users/lunwenh/executorch/examples/models/phi-3-mini/test.py", line 38, in _generate_token_with_kv_cache
    result = model.forward(
  File "/Users/lunwenh/executorch/examples/models/phi-3-mini/test.py", line 24, in forward
    return self.model.forward(
  File "/opt/anaconda3/envs/executorch/lib/python3.10/site-packages/transformers/models/phi3/modeling_phi3.py", line 1207, in forward
    outputs = self.model(
  File "/opt/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1727, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/anaconda3/envs/executorch/lib/python3.10/site-packages/transformers/models/phi3/modeling_phi3.py", line 1002, in forward
    layer_outputs = decoder_layer(
  File "/opt/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1727, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/anaconda3/envs/executorch/lib/python3.10/site-packages/transformers/models/phi3/modeling_phi3.py", line 739, in forward
    attn_outputs, self_attn_weights, present_key_value = self.self_attn(
  File "/opt/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1727, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/anaconda3/envs/executorch/lib/python3.10/site-packages/transformers/models/phi3/modeling_phi3.py", line 405, in forward
    raise ValueError(
ValueError: Attention weights should be of size (1, 32, 1, 1), but is torch.Size([1, 32, 1, 132])

Process finished with exit code 1

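This happens because the current StaticCache implementation does not slice k_out and v_out on update: it returns the whole preallocated cache of length max_cache_len (here 132 = 128 + the 4 prompt tokens). The attention weights therefore have a last dimension of 132, while the size check in Phi3Attention expects it to equal the number of key/value positions seen so far (1 on the first forward call).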
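The mismatch can be reproduced in isolation. Below is a simplified sketch in plain PyTorch, not the StaticCache or Phi-3 code: it assumes a per-layer key buffer preallocated to max_cache_len and written in place at cache_position (using `index_copy_` for that write is an assumption about the cache internals). The batch size, head count, and cache length follow the traceback; `head_dim = 96` is illustrative.

```python
import torch

bsz, num_heads, head_dim = 1, 32, 96
max_cache_len = 132  # 128 new tokens + 4 prompt tokens, as in the traceback

# Fixed-size key buffer, preallocated up front as a static cache does per layer.
k_cache = torch.zeros(bsz, num_heads, max_cache_len, head_dim)

# First forward call: one query and one key, written to cache slot 0.
q = torch.randn(bsz, num_heads, 1, head_dim)
k = torch.randn(bsz, num_heads, 1, head_dim)
cache_position = torch.tensor([0])
k_cache.index_copy_(2, cache_position, k)  # in-place write of the new key

# The whole buffer is handed back, not k_cache[:, :, :1].
k_out = k_cache
attn_weights = q @ k_out.transpose(2, 3) / head_dim**0.5

kv_seq_len = 1  # positions actually seen so far
print(attn_weights.shape)  # torch.Size([1, 32, 1, 132])
print(attn_weights.shape == (bsz, num_heads, 1, kv_seq_len))  # False
```

The last dimension of the attention weights tracks the buffer length, not the number of tokens processed, which is exactly what the size check rejects.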

In the long term, #31421 and #30862 should resolve this problem by supporting StaticCache with dynamic sequence lengths.

For now, removing this size check in Phi3Attention should make Phi-3 work with StaticCache.
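Removing the check is only safe if the unwritten slots at the end of the static buffer never contribute to the output. A minimal sketch in plain PyTorch with toy sizes, assuming the model adds an attention mask that hides every slot not yet written (the sketch uses `-inf`; a large negative value behaves the same way here). `attend` is a hypothetical helper, not the Phi-3 attention code:

```python
import torch

torch.manual_seed(0)
num_heads, head_dim, max_cache_len = 4, 8, 16
seen = 5  # slots 0..4 hold real keys/values; slots 5..15 are still zero

q = torch.randn(1, num_heads, 1, head_dim)
k_cache = torch.zeros(1, num_heads, max_cache_len, head_dim)
v_cache = torch.zeros(1, num_heads, max_cache_len, head_dim)
k_cache[:, :, :seen] = torch.randn(1, num_heads, seen, head_dim)
v_cache[:, :, :seen] = torch.randn(1, num_heads, seen, head_dim)


def attend(q, k, v, mask=None):
    """Scaled dot-product attention with an optional additive mask."""
    scores = q @ k.transpose(2, 3) / head_dim**0.5
    if mask is not None:
        scores = scores + mask
    return torch.softmax(scores, dim=-1) @ v


# Additive mask: 0 for written slots, -inf for slots not written yet.
mask = torch.full((1, 1, 1, max_cache_len), float("-inf"))
mask[..., :seen] = 0.0

full = attend(q, k_cache, v_cache, mask)  # over the whole static buffer
sliced = attend(q, k_cache[:, :, :seen], v_cache[:, :, :seen])  # seen slots only
print(torch.allclose(full, sliced, atol=1e-6))  # True
```

If the mask did not cover those slots, the zero-filled keys would score 0 and still receive softmax weight, so `full` and `sliced` would differ; the fix therefore relies on the mask, not on the removed check.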

Expected behavior

After removing the size check, the script above runs successfully.
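In the script, `cache_position` should carry the absolute cache slot of each token in `input_ids`, one index per token. With the one-token-at-a-time loop used here, that is `[input_pos]` during prefill and `[prompt_len + k]` for the k-th token fed back during decoding. Passing a growing range of positions instead (0 up to input_pos - 1) does not give exactly one index for the single input token. A minimal pure-Python sketch; both helper names are hypothetical, not part of transformers:

```python
def one_token_cache_positions(prompt_len: int, num_decode_steps: int) -> list[list[int]]:
    """Cache positions for each forward call when every call feeds exactly one token.

    Prompt token i goes to slot i; the k-th generated token fed back to the
    model goes to slot prompt_len + k.
    """
    prefill = [[i] for i in range(prompt_len)]
    decode = [[prompt_len + k] for k in range(num_decode_steps)]
    return prefill + decode


def growing_range_positions(prompt_len: int) -> list[list[int]]:
    """For contrast: a growing range 0..i-1 per prefill call, which is empty on
    the first call and too long on every later call."""
    return [list(range(0, i)) for i in range(prompt_len)]


print(one_token_cache_positions(4, 3))  # [[0], [1], [2], [3], [4], [5], [6]]
print(growing_range_positions(4))       # [[], [0], [0, 1], [0, 1, 2]]
```

Each inner list would be wrapped as `torch.tensor(positions, device=...)` before being passed to the model, and max_cache_len must be larger than the last position used.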
