From ca8376f936481e3798e6d5b02e68fd5700f2b809 Mon Sep 17 00:00:00 2001
From: Jimin
Date: Wed, 11 Dec 2024 23:02:29 +0000
Subject: [PATCH 1/5] Fix common KVCache not to check token_idx

---
 .../models/llama/modeling_llama.py | 41 ++++++++++++++++++-
 .../models/modeling_all_models.py  |  3 +-
 2 files changed, 42 insertions(+), 2 deletions(-)

diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py
index 285cb14952..87aa5cffb8 100755
--- a/optimum/habana/transformers/models/llama/modeling_llama.py
+++ b/optimum/habana/transformers/models/llama/modeling_llama.py
@@ -30,9 +30,48 @@
 from ...modeling_attn_mask_utils import (
     _gaudi_prepare_4d_causal_attention_mask,
 )
-from ..modeling_all_models import KVCache, Matmul, apply_customized_rope_module
+from ..modeling_all_models import Matmul, apply_customized_rope_module
 from .configuration_llama import LlamaConfig
 
+class KVCache(torch.nn.Module):
+    def __init__(self):
+        super(KVCache, self).__init__()
+        self.cache = None
+        self.inp_seq_len = -1
+
+    def allocate(self, inp_seq_len, dtype, device, shape):
+        if self.cache is None or self.cache.shape != shape:
+            self.inp_seq_len = inp_seq_len
+            self.cache = torch.zeros(shape, dtype=dtype, device=device)
+        else:
+            assert (
+                self.inp_seq_len == inp_seq_len
+            ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
+            self.cache.fill_(0)
+
+    @staticmethod
+    def update(prev, cur, dim, idx, inp_seq_len, training=False):
+        orig_cur = cur
+        if prev.shape == cur.shape:
+            prev.copy_(cur)
+            return orig_cur
+        if training is False and cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
+            # Initialize
+            prev[:, :, :inp_seq_len, :].copy_(cur)
+            return orig_cur
+        if idx is not None:
+            prev.index_copy_(dim, idx - 1, cur)
+            return prev
+        else:
+            return torch.cat((prev, cur), dim=dim)
+
+    def get_shape(self):
+        if self.cache is None:
+            return None
+        return self.cache.shape
+
+    def forward(self, cur, dim, idx):
+        return self.update(self.cache, cur, dim, idx, self.inp_seq_len, self.training)
 
 try:
     from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE  # noqa
diff --git a/optimum/habana/transformers/models/modeling_all_models.py b/optimum/habana/transformers/models/modeling_all_models.py
index 5c3beb8d28..5a78359e3a 100644
--- a/optimum/habana/transformers/models/modeling_all_models.py
+++ b/optimum/habana/transformers/models/modeling_all_models.py
@@ -59,10 +59,11 @@ def update(prev, cur, dim, idx, inp_seq_len):
         if prev.shape == cur.shape:
             prev.copy_(cur)
             return orig_cur
-        if idx is not None and cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
+        if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
             # Initialize
             prev[:, :, :inp_seq_len, :].copy_(cur)
             return orig_cur
+        assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}"
         if idx is not None:
             prev.index_copy_(dim, idx - 1, cur)
             return prev

From 2aade5f00dbe55d2cb20eb5ae175cf9e593f7abb Mon Sep 17 00:00:00 2001
From: Jimin Ha
Date: Wed, 11 Dec 2024 18:41:44 -0800
Subject: [PATCH 2/5] Fix for passing training

---
 .../models/llama/modeling_llama.py | 23 ++++++++++---------
 1 file changed, 12 insertions(+), 11 deletions(-)

diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py
index 87aa5cffb8..8369e19bae 100755
--- a/optimum/habana/transformers/models/llama/modeling_llama.py
+++ b/optimum/habana/transformers/models/llama/modeling_llama.py
@@ -50,15 +50,16 @@ def allocate(self, inp_seq_len, dtype, device, shape):
             self.cache.fill_(0)
 
     @staticmethod
-    def update(prev, cur, dim, idx, inp_seq_len, training=False):
-        orig_cur = cur
-        if prev.shape == cur.shape:
-            prev.copy_(cur)
-            return orig_cur
-        if training is False and cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
-            # Initialize
-            prev[:, :, :inp_seq_len, :].copy_(cur)
-            return orig_cur
+    def update(prev, cur, dim, idx, inp_seq_len, training):
+        if training is False:
+            orig_cur = cur
+            if prev.shape == cur.shape:
+                prev.copy_(cur)
+                return orig_cur
+            if idx is None and cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
+                # Initialize
+                prev[:, :, :inp_seq_len, :].copy_(cur)
+                return orig_cur
         if idx is not None:
             prev.index_copy_(dim, idx - 1, cur)
             return prev
@@ -667,8 +668,8 @@ def pre_attn_forward(
                 value_states = torch.cat((past_key_value[1], value_states), -2)
                 past_key_value = (key_states, value_states)
             else:
-                key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len)
-                value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len)
+                key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len, self.training)
+                value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len, self.training)
 
             if token_idx is None:
                 past_key_value = (key_states, value_states)

From 4232c0d7740f9c37630492d72cfbd585a1399257 Mon Sep 17 00:00:00 2001
From: Jimin Ha
Date: Wed, 11 Dec 2024 21:40:47 -0800
Subject: [PATCH 3/5] Change llama KVCache back to 1.14

---
 .../models/llama/modeling_llama.py | 85 +++++++++----------
 1 file changed, 41 insertions(+), 44 deletions(-)

diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py
index 8369e19bae..4fc17b6569 100755
--- a/optimum/habana/transformers/models/llama/modeling_llama.py
+++ b/optimum/habana/transformers/models/llama/modeling_llama.py
@@ -33,47 +33,6 @@
 from ..modeling_all_models import Matmul, apply_customized_rope_module
 from .configuration_llama import LlamaConfig
 
-class KVCache(torch.nn.Module):
-    def __init__(self):
-        super(KVCache, self).__init__()
-        self.cache = None
-        self.inp_seq_len = -1
-
-    def allocate(self, inp_seq_len, dtype, device, shape):
-        if self.cache is None or self.cache.shape != shape:
-            self.inp_seq_len = inp_seq_len
-            self.cache = torch.zeros(shape, dtype=dtype, device=device)
-        else:
-            assert (
-                self.inp_seq_len == inp_seq_len
-            ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
-            self.cache.fill_(0)
-
-    @staticmethod
-    def update(prev, cur, dim, idx, inp_seq_len, training):
-        if training is False:
-            orig_cur = cur
-            if prev.shape == cur.shape:
-                prev.copy_(cur)
-                return orig_cur
-            if idx is None and cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
-                # Initialize
-                prev[:, :, :inp_seq_len, :].copy_(cur)
-                return orig_cur
-        if idx is not None:
-            prev.index_copy_(dim, idx - 1, cur)
-            return prev
-        else:
-            return torch.cat((prev, cur), dim=dim)
-
-    def get_shape(self):
-        if self.cache is None:
-            return None
-        return self.cache.shape
-
-    def forward(self, cur, dim, idx):
-        return self.update(self.cache, cur, dim, idx, self.inp_seq_len, self.training)
-
 try:
     from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE  # noqa
@@ -98,7 +57,6 @@ def forward(self, cur, dim, idx):
 
 import habana_frameworks.torch.core as htcore
 
-
 def gaudi_llama_rmsnorm_forward(self, hidden_states):
     """
     Copied from LlamaRMSNorm.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
@@ -427,6 +385,45 @@ def forward(
             padding_side,
         )
 
+class KVCache(torch.nn.Module):
+    def __init__(self):
+        super(KVCache, self).__init__()
+        self.cache = None
+        self.inp_seq_len = -1
+
+    def allocate(self, inp_seq_len, dtype, device, shape):
+        if self.cache is None or self.cache.shape != shape:
+            self.inp_seq_len = inp_seq_len
+            self.cache = torch.zeros(shape, dtype=dtype, device=device)
+        else:
+            assert (
+                self.inp_seq_len == inp_seq_len
+            ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
+            self.cache.fill_(0)
+
+    @staticmethod
+    def update(prev, cur, dim, idx, inp_seq_len):
+        orig_cur = cur
+        if prev.shape == cur.shape:
+            prev.copy_(cur)
+            return orig_cur
+        if idx is not None and cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
+            # Initialize
+            prev[:, :, :inp_seq_len, :].copy_(cur)
+            return orig_cur
+        if idx is not None:
+            prev.index_copy_(dim, idx - 1, cur)
+            return prev
+        else:
+            return torch.cat((prev, cur), dim=dim)
+
+    def get_shape(self):
+        if self.cache is None:
+            return None
+        return self.cache.shape
+
+    def forward(self, cur, dim, idx):
+        return self.update(self.cache, cur, dim, idx, self.inp_seq_len)
 
 def GaudiDistributedAttention(fused_scaled_dot_product_attention, fused_scaled_dot_product_attention_distributed):
     if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1:
@@ -668,8 +665,8 @@ def pre_attn_forward(
                 value_states = torch.cat((past_key_value[1], value_states), -2)
                 past_key_value = (key_states, value_states)
             else:
-                key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len, self.training)
-                value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len, self.training)
+                key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len)
+                value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len)
 
             if token_idx is None:
                 past_key_value = (key_states, value_states)

From 0e24ccd9034126db719afbf40c54ee2f19502a22 Mon Sep 17 00:00:00 2001
From: Jimin Ha
Date: Thu, 12 Dec 2024 17:10:21 +0000
Subject: [PATCH 4/5] Added LlamaKVCache to just implement update()

---
 .../models/llama/modeling_llama.py | 31 +++----------
 1 file changed, 4 insertions(+), 27 deletions(-)

diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py
index 4fc17b6569..b0bc2cbb33 100755
--- a/optimum/habana/transformers/models/llama/modeling_llama.py
+++ b/optimum/habana/transformers/models/llama/modeling_llama.py
@@ -30,7 +30,7 @@
 from ...modeling_attn_mask_utils import (
     _gaudi_prepare_4d_causal_attention_mask,
 )
-from ..modeling_all_models import Matmul, apply_customized_rope_module
+from ..modeling_all_models import KVCache, Matmul, apply_customized_rope_module
 from .configuration_llama import LlamaConfig
 
 try:
@@ -385,22 +385,7 @@ def forward(
         padding_side,
     )
 
-class KVCache(torch.nn.Module):
-    def __init__(self):
-        super(KVCache, self).__init__()
-        self.cache = None
-        self.inp_seq_len = -1
-
-    def allocate(self, inp_seq_len, dtype, device, shape):
-        if self.cache is None or self.cache.shape != shape:
-            self.inp_seq_len = inp_seq_len
-            self.cache = torch.zeros(shape, dtype=dtype, device=device)
-        else:
-            assert (
-                self.inp_seq_len == inp_seq_len
-            ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
-            self.cache.fill_(0)
-
+class LlamaKVCache(KVCache):
     @staticmethod
     def update(prev, cur, dim, idx, inp_seq_len):
         orig_cur = cur
@@ -417,14 +402,6 @@ def update(prev, cur, dim, idx, inp_seq_len):
         else:
             return torch.cat((prev, cur), dim=dim)
 
-    def get_shape(self):
-        if self.cache is None:
-            return None
-        return self.cache.shape
-
-    def forward(self, cur, dim, idx):
-        return self.update(self.cache, cur, dim, idx, self.inp_seq_len)
-
 def GaudiDistributedAttention(fused_scaled_dot_product_attention, fused_scaled_dot_product_attention_distributed):
     if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1:
         return fused_scaled_dot_product_attention_distributed
@@ -438,8 +415,8 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
 
         self.matmul_qk = Matmul()
         self.matmul_av = Matmul()
-        self.k_cache = KVCache()
-        self.v_cache = KVCache()
+        self.k_cache = LlamaKVCache()
+        self.v_cache = LlamaKVCache()
 
         if hasattr(config, "fused_qkv") and config.fused_qkv:
             self.num_heads = config.num_attention_heads

From d0a13ce0228c8b667b779c09db4469bee3e61243 Mon Sep 17 00:00:00 2001
From: Jimin Ha
Date: Thu, 12 Dec 2024 17:54:05 +0000
Subject: [PATCH 5/5] Remove reuse_cache from Llama3-405B measure

---
 examples/text-generation/README.md | 1 -
 1 file changed, 1 deletion(-)

diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md
index d9c4ef2bca..0ff2605462 100755
--- a/examples/text-generation/README.md
+++ b/examples/text-generation/README.md
@@ -487,7 +487,6 @@ QUANT_CONFIG=./quantization_config/maxabs_measure_include_outputs.json python ..
 --trim_logits \
 --batch_size 1 \
 --bf16 \
---reuse_cache \
 --use_flash_attention \
 --flash_attention_recompute \
 --flash_attention_causal_mask