Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions optimum/habana/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from ...modeling_attn_mask_utils import (
_gaudi_prepare_4d_causal_attention_mask,
)
from ..modeling_all_models import KVCache, Matmul, apply_customized_rope_module
from ..modeling_all_models import Matmul, apply_customized_rope_module
from .configuration_llama import LlamaConfig


Expand Down Expand Up @@ -385,7 +385,22 @@ def forward(
)


class LlamaKVCache(KVCache):
class KVCache(torch.nn.Module):
def __init__(self):
super(KVCache, self).__init__()
self.cache = None
self.inp_seq_len = -1

def allocate(self, inp_seq_len, dtype, device, shape):
if self.cache is None or self.cache.shape != shape:
self.inp_seq_len = inp_seq_len
self.cache = torch.zeros(shape, dtype=dtype, device=device)
else:
assert (
self.inp_seq_len == inp_seq_len
), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
self.cache.fill_(0)

@staticmethod
def update(prev, cur, dim, idx, inp_seq_len):
orig_cur = cur
Expand All @@ -402,6 +417,13 @@ def update(prev, cur, dim, idx, inp_seq_len):
else:
return torch.cat((prev, cur), dim=dim)

def get_shape(self):
if self.cache is None:
return None
return self.cache.shape

def forward(self, cur, dim, idx):
return self.update(self.cache, cur, dim, idx, self.inp_seq_len)

def GaudiDistributedAttention(fused_scaled_dot_product_attention, fused_scaled_dot_product_attention_distributed):
if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1:
Expand All @@ -416,8 +438,8 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):

self.matmul_qk = Matmul()
self.matmul_av = Matmul()
self.k_cache = LlamaKVCache()
self.v_cache = LlamaKVCache()
self.k_cache = KVCache()
self.v_cache = KVCache()

if hasattr(config, "fused_qkv") and config.fused_qkv:
self.num_heads = config.num_attention_heads
Expand Down