3333from ..modeling_all_models import KVCache , Matmul , apply_customized_rope_module
3434from .configuration_llama import LlamaConfig
3535
36-
3736try :
3837 from habana_frameworks .torch .hpex .kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE # noqa
3938
5857
5958import habana_frameworks .torch .core as htcore
6059
61-
6260def gaudi_llama_rmsnorm_forward (self , hidden_states ):
6361 """
6462 Copied from LlamaRMSNorm.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
@@ -384,6 +382,22 @@ def forward(
384382 padding_side ,
385383 )
386384
class LlamaKVCache(KVCache):
    @staticmethod
    def update(prev, cur, dim, idx, inp_seq_len):
        """Write `cur` into the preallocated cache tensor `prev` in place.

        Handles three regimes, checked in order:
          1. full overwrite — `cur` has exactly the cache's shape;
          2. prefill — a multi-token `cur` that fits inside `prev` is copied
             into the first `inp_seq_len` positions along the sequence axis;
          3. decode — a single new token is scattered at `idx - 1` along `dim`
             (idx presumably a 1-based token count — TODO confirm with caller),
             or, with no index available, appended via concatenation.

        Returns the tensor callers should read from: the incoming `cur` when
        the cache was (re)initialized, otherwise the updated `prev`.
        """
        incoming = cur
        # Regime 1: cache buffer and new states are the same size — overwrite.
        if prev.shape == cur.shape:
            prev.copy_(cur)
            return incoming
        # Regime 2: prefill — multi-token input that fits in the cache.
        # Sequence length lives at axis 2, as the slice below shows.
        is_prefill = cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]
        if idx is not None and is_prefill:
            prev[:, :, :inp_seq_len, :].copy_(cur)
            return incoming
        # Regime 3a: single-token decode with a static cache index.
        if idx is not None:
            prev.index_copy_(dim, idx - 1, cur)
            return prev
        # Regime 3b: no index — fall back to a growing (dynamic) cache.
        return torch.cat((prev, cur), dim=dim)
387401
388402def GaudiDistributedAttention (fused_scaled_dot_product_attention , fused_scaled_dot_product_attention_distributed ):
389403 if parallel_state .sequence_parallel_is_initialized () and parallel_state .get_sequence_parallel_world_size () > 1 :
@@ -398,8 +412,8 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
398412
399413 self .matmul_qk = Matmul ()
400414 self .matmul_av = Matmul ()
401- self .k_cache = KVCache ()
402- self .v_cache = KVCache ()
415+ self .k_cache = LlamaKVCache ()
416+ self .v_cache = LlamaKVCache ()
403417
404418 if hasattr (config , "fused_qkv" ) and config .fused_qkv :
405419 self .num_heads = config .num_attention_heads
0 commit comments