Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion examples/text-generation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,6 @@ QUANT_CONFIG=./quantization_config/maxabs_measure_include_outputs.json python ..
--trim_logits \
--batch_size 1 \
--bf16 \
--reuse_cache \
--use_flash_attention \
--flash_attention_recompute \
--flash_attention_causal_mask
Expand Down
22 changes: 18 additions & 4 deletions optimum/habana/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from ..modeling_all_models import KVCache, Matmul, apply_customized_rope_module
from .configuration_llama import LlamaConfig


try:
from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE # noqa

Expand All @@ -58,7 +57,6 @@

import habana_frameworks.torch.core as htcore


def gaudi_llama_rmsnorm_forward(self, hidden_states):
"""
Copied from LlamaRMSNorm.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
Expand Down Expand Up @@ -387,6 +385,22 @@ def forward(
padding_side,
)

class LlamaKVCache(KVCache):
    """Llama-specific KV-cache updater.

    Overrides the shared ``KVCache.update`` to keep the ``idx is not None``
    guard on the prefill (initialize) branch, i.e. the pre-existing update
    semantics for Llama models.
    """

    @staticmethod
    def update(prev, cur, dim, idx, inp_seq_len):
        """Merge ``cur`` into the cache tensor ``prev`` along ``dim``.

        Returns either ``cur`` (when the cache was filled in place) or the
        full cache tensor (when a token was scattered in or the cache grew).
        """
        new_entry = cur
        # Shapes match exactly: overwrite the whole cache slot in place.
        if prev.shape == cur.shape:
            prev.copy_(cur)
            return new_entry
        if idx is not None:
            seq_len = cur.shape[2]
            if 1 < seq_len <= prev.shape[2]:
                # Prefill: write the prompt tokens into the head of the cache.
                prev[:, :, :inp_seq_len, :].copy_(cur)
                return new_entry
            # Decode step: scatter the incoming slice at position idx - 1.
            prev.index_copy_(dim, idx - 1, cur)
            return prev
        # No target index supplied: grow the cache by concatenation.
        return torch.cat((prev, cur), dim=dim)

def GaudiDistributedAttention(fused_scaled_dot_product_attention, fused_scaled_dot_product_attention_distributed):
if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1:
Expand All @@ -401,8 +415,8 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):

self.matmul_qk = Matmul()
self.matmul_av = Matmul()
self.k_cache = KVCache()
self.v_cache = KVCache()
self.k_cache = LlamaKVCache()
self.v_cache = LlamaKVCache()

if hasattr(config, "fused_qkv") and config.fused_qkv:
self.num_heads = config.num_attention_heads
Expand Down
3 changes: 2 additions & 1 deletion optimum/habana/transformers/models/modeling_all_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,11 @@ def update(prev, cur, dim, idx, inp_seq_len):
if prev.shape == cur.shape:
prev.copy_(cur)
return orig_cur
if idx is not None and cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
# Initialize
prev[:, :, :inp_seq_len, :].copy_(cur)
return orig_cur
assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}"
if idx is not None:
prev.index_copy_(dim, idx - 1, cur)
return prev
Expand Down