From 736f08992743769a643e6afb254d119a2f29704e Mon Sep 17 00:00:00 2001 From: Mircea Mironenco Date: Thu, 7 Nov 2024 00:37:47 +0200 Subject: [PATCH 1/6] Update KV Cache --- tests/torchtune/modules/test_attention.py | 8 ++---- torchtune/modules/attention.py | 34 +++++++++++------------ torchtune/modules/kv_cache.py | 12 ++++---- 3 files changed, 25 insertions(+), 29 deletions(-) diff --git a/tests/torchtune/modules/test_attention.py b/tests/torchtune/modules/test_attention.py index 872f6684de..0d9dcb5434 100644 --- a/tests/torchtune/modules/test_attention.py +++ b/tests/torchtune/modules/test_attention.py @@ -123,7 +123,7 @@ def gqa_kv_cache( kv_cache = KVCache( batch_size=4, max_seq_len=max_seq_len, - num_heads=num_heads, + num_kv_heads=num_kv_heads, head_dim=head_dim, dtype=torch.float32, ) @@ -178,7 +178,7 @@ def mha_kv_cache( kv_cache = KVCache( batch_size=4, max_seq_len=max_seq_len, - num_heads=num_heads, + num_kv_heads=num_kv_heads, head_dim=head_dim, dtype=torch.float32, ) @@ -233,7 +233,7 @@ def mqa_kv_cache( kv_cache = KVCache( batch_size=4, max_seq_len=max_seq_len, - num_heads=num_heads, + num_kv_heads=num_kv_heads, head_dim=head_dim, dtype=torch.float32, ) @@ -267,7 +267,6 @@ def test_forward_gqa(self, input: torch.Tensor, gqa: MultiHeadAttention) -> None def test_forward_gqa_kv_cache( self, input: torch.Tensor, gqa_kv_cache: MultiHeadAttention, attn_params_gqa ) -> None: - _, _, _, max_seq_len = attn_params_gqa _, seq_len, _ = input.shape @@ -293,7 +292,6 @@ def test_forward_mha(self, input: torch.Tensor, mha: MultiHeadAttention) -> None def test_forward_mha_kv_cache( self, input: torch.Tensor, mha_kv_cache: MultiHeadAttention, attn_params_mha ) -> None: - _, _, _, max_seq_len = attn_params_mha _, seq_len, _ = input.shape diff --git a/torchtune/modules/attention.py b/torchtune/modules/attention.py index 879f0679cf..c09cc0c44d 100644 --- a/torchtune/modules/attention.py +++ b/torchtune/modules/attention.py @@ -164,7 +164,7 @@ def setup_cache( self.kv_cache = KVCache( batch_size=batch_size, max_seq_len=max_seq_len, - num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, head_dim=self.head_dim, dtype=dtype, ) @@ -269,36 +269,36 @@ def forward( if self.pos_embeddings is not None: k = self.pos_embeddings(k, input_pos=input_pos) + # [b, n_h, s, h_d] + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # Update key-value cache + if self.kv_cache is not None and self.cache_enabled: + k, v = self.kv_cache.update(k, v) + # View + expand + reshape bring num_kv_heads to num_heads for k and v # to match q. 
- # k: [b, s_y, n_kv, 1, h_d] - # v: [b, s_y, n_kv, 1, h_d] - k = k.view(b, s_y, self.num_kv_heads, 1, self.head_dim) - v = v.view(b, s_y, self.num_kv_heads, 1, self.head_dim) + # k: [b, n_kv, 1, s_y, h_d] + # v: [b, n_kv, 1, s_y, h_d] + k = k.view(b, self.num_kv_heads, 1, s_y, self.head_dim) + v = v.view(b, self.num_kv_heads, 1, s_y, self.head_dim) # If needed, expand the key and value tensors to have the same shape # as the query tensor by copying values across the relevant dim if self.num_heads != self.num_kv_heads: - k = k.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim) - v = v.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim) + k = k.expand(b, self.num_kv_heads, q_per_kv, s_y, self.head_dim) + v = v.expand(b, self.num_kv_heads, q_per_kv, s_y, self.head_dim) # [b, s, n_h, h_d] - k = k.reshape(b, s_y, -1, self.head_dim) - v = v.reshape(b, s_y, -1, self.head_dim) - - # [b, n_h, s, h_d] - k = k.transpose(1, 2) - v = v.transpose(1, 2) + k = k.reshape(b, -1, s_y, self.head_dim) + v = v.reshape(b, -1, s_y, self.head_dim) # Normalize k if self.k_norm is not None: k = self.k_norm(k) - # Update key-value cache - if self.kv_cache is not None and self.cache_enabled: - k, v = self.kv_cache.update(k, v) - output = self._attention_call( q, k, diff --git a/torchtune/modules/kv_cache.py b/torchtune/modules/kv_cache.py index facd9703ca..6d6eac3266 100644 --- a/torchtune/modules/kv_cache.py +++ b/torchtune/modules/kv_cache.py @@ -17,9 +17,7 @@ class KVCache(nn.Module): Args: batch_size (int): batch size model will be run with max_seq_len (int): maximum sequence length model will be run with - num_heads (int): number of heads. We take num_heads instead of num_kv_heads because - the cache is created after we've expanded the key and value tensors to have the - same shape as the query tensor. See attention.py for more details + num_kv_heads (int): number key/value heads. head_dim (int): per-attention head embedding dimension dtype (torch.dtype): dtype for the caches """ @@ -28,12 +26,12 @@ def __init__( self, batch_size: int, max_seq_len: int, - num_heads: int, + num_kv_heads: int, head_dim: int, dtype: torch.dtype, ) -> None: super().__init__() - cache_shape = (batch_size, num_heads, max_seq_len, head_dim) + cache_shape = (batch_size, num_kv_heads, max_seq_len, head_dim) self.register_buffer( "k_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False ) @@ -53,7 +51,7 @@ def reset(self) -> None: @property def size(self) -> int: - return self.cache_pos[0].item() + return int(self.cache_pos[0].item()) def update( self, k_val: torch.Tensor, v_val: torch.Tensor @@ -66,7 +64,7 @@ def update( already been filled, use ``.reset()``, which will reset the cache to the zero-th position. 
Example: - >>> cache = KVCache(batch_size=2, max_seq_len=16, num_heads=4, head_dim=32, dtype=torch.bfloat16) + >>> cache = KVCache(batch_size=2, max_seq_len=16, num_kv_heads=4, head_dim=32, dtype=torch.bfloat16) >>> keys, values = torch.ones((2, 4, 8, 32)), torch.ones((2, 4, 8, 32)) >>> cache.update(keys, values) >>> # now positions 0 through 7 are filled From dbb471f8b8877ecf0fbedbc0722a7c7da37a2f8b Mon Sep 17 00:00:00 2001 From: Mircea Mironenco Date: Thu, 7 Nov 2024 17:47:53 +0200 Subject: [PATCH 2/6] Fix attention fwd pass shapes --- .../llama2/scripts/compare_fused_attention.py | 5 ++- .../llama2/scripts/compare_lora_attention.py | 4 +-- torchtune/modules/attention.py | 34 +++++++------------ torchtune/modules/attention_utils.py | 13 ++++++- 4 files changed, 29 insertions(+), 27 deletions(-) diff --git a/tests/torchtune/models/llama2/scripts/compare_fused_attention.py b/tests/torchtune/models/llama2/scripts/compare_fused_attention.py index 328d1c528f..0c6c3e938a 100644 --- a/tests/torchtune/models/llama2/scripts/compare_fused_attention.py +++ b/tests/torchtune/models/llama2/scripts/compare_fused_attention.py @@ -256,7 +256,6 @@ def compare_attn( max_seq_len: int, use_kv_cache: bool, ): - torch.manual_seed(16) inputs = torch.randn(4, 2048, 4096) @@ -269,8 +268,9 @@ def compare_attn( kv_cache = KVCache( batch_size=4, max_seq_len=max_seq_len, - n_kv_heads=num_heads, + num_kv_heads=num_kv_heads, head_dim=head_dim, + dtype=inputs.dtype, ) else: kv_cache = None @@ -330,7 +330,6 @@ def compare_attn( if __name__ == "__main__": - # compare mha mha = { "num_heads": 32, diff --git a/tests/torchtune/models/llama2/scripts/compare_lora_attention.py b/tests/torchtune/models/llama2/scripts/compare_lora_attention.py index c6073297da..fb70c5b464 100644 --- a/tests/torchtune/models/llama2/scripts/compare_lora_attention.py +++ b/tests/torchtune/models/llama2/scripts/compare_lora_attention.py @@ -33,7 +33,6 @@ def compare_lora_attention( lora_rank: int, lora_alpha: float, ) -> None: - # make sure we have the right seed for generating outputs # this should match up the seed value set in the corresponding # unit test @@ -68,8 +67,9 @@ def compare_lora_attention( KVCache( batch_size=batch_size, max_seq_len=max_seq_len, - n_kv_heads=num_heads, + num_kv_heads=num_kv_heads, head_dim=head_dim, + dtype=x.dtype, ) if batch_size is not None else None diff --git a/torchtune/modules/attention.py b/torchtune/modules/attention.py index c09cc0c44d..15a6a4f864 100644 --- a/torchtune/modules/attention.py +++ b/torchtune/modules/attention.py @@ -9,7 +9,11 @@ import torch from torch import nn -from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention +from torchtune.modules.attention_utils import ( + _MaskType, + _sdpa_or_flex_attention, + repeat_interleave, +) from torchtune.modules.kv_cache import KVCache logger = logging.getLogger(__name__) @@ -258,42 +262,30 @@ def forward( else: # Update k and v shape, positional embeddings, and normalization - # k has shape [b, s_y, num_kv_heads * head_dim] - # v has shape [b, s_y, num_kv_heads * head_dim] + # k,v shape [b, s_y, num_kv_heads * head_dim] k = self.k_proj(y) v = self.v_proj(y) # Apply positional embeddings - # k: [b, s_y, n_kv, h_d] + # k,v shape: [b, s_y, n_kv, h_d] k = k.view(b, s_y, -1, self.head_dim) + v = v.view(b, s_y, -1, self.head_dim) if self.pos_embeddings is not None: k = self.pos_embeddings(k, input_pos=input_pos) - # [b, n_h, s, h_d] - k = k.transpose(1, 2) - v = v.transpose(1, 2) + # k,v shape: [b, n_kv, s_y, h_d] + k, v = 
k.transpose(1, 2), v.transpose(1, 2) # Update key-value cache if self.kv_cache is not None and self.cache_enabled: k, v = self.kv_cache.update(k, v) - # View + expand + reshape bring num_kv_heads to num_heads for k and v - # to match q. - - # k: [b, n_kv, 1, s_y, h_d] - # v: [b, n_kv, 1, s_y, h_d] - k = k.view(b, self.num_kv_heads, 1, s_y, self.head_dim) - v = v.view(b, self.num_kv_heads, 1, s_y, self.head_dim) - # If needed, expand the key and value tensors to have the same shape # as the query tensor by copying values across the relevant dim + # k,v shape: [b, n_h, s, h_d] if self.num_heads != self.num_kv_heads: - k = k.expand(b, self.num_kv_heads, q_per_kv, s_y, self.head_dim) - v = v.expand(b, self.num_kv_heads, q_per_kv, s_y, self.head_dim) - - # [b, s, n_h, h_d] - k = k.reshape(b, -1, s_y, self.head_dim) - v = v.reshape(b, -1, s_y, self.head_dim) + k = repeat_interleave(k, dim=1, repeat=q_per_kv) + v = repeat_interleave(v, dim=1, repeat=q_per_kv) # Normalize k if self.k_norm is not None: diff --git a/torchtune/modules/attention_utils.py b/torchtune/modules/attention_utils.py index 8afd4eba71..6190a9d6cf 100644 --- a/torchtune/modules/attention_utils.py +++ b/torchtune/modules/attention_utils.py @@ -183,7 +183,6 @@ def _attention_call( dropout_p: float, is_causal: bool, ) -> torch.Tensor: - # Flex attention uses the BlockMask # (https://github.com/pytorch/pytorch/blob/main/torch/nn/attention/flex_attention.py#L168) # instead of a traditional boolean tensor mask. If this is passed in, @@ -247,3 +246,15 @@ def _attention_call( ) return _attention_call + + +def repeat_interleave(x: torch.Tensor, *, dim: int, repeat: int) -> torch.Tensor: + if repeat == 1: + return x + + dim = dim + x.ndim if dim < 0 else dim + + shape = [-1] * (x.ndim + 1) + shape[dim + 1] = repeat + + return x.unsqueeze(dim + 1).expand(shape).flatten(dim, dim + 1) From bb842e0402fd62d8985749b770f9f434628b9858 Mon Sep 17 00:00:00 2001 From: Mircea Mironenco Date: Thu, 7 Nov 2024 18:51:17 +0200 Subject: [PATCH 3/6] fix typo --- torchtune/modules/kv_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtune/modules/kv_cache.py b/torchtune/modules/kv_cache.py index 6d6eac3266..be05a88433 100644 --- a/torchtune/modules/kv_cache.py +++ b/torchtune/modules/kv_cache.py @@ -17,7 +17,7 @@ class KVCache(nn.Module): Args: batch_size (int): batch size model will be run with max_seq_len (int): maximum sequence length model will be run with - num_kv_heads (int): number key/value heads. + num_kv_heads (int): number of key/value heads. 
head_dim (int): per-attention head embedding dimension dtype (torch.dtype): dtype for the caches """ From 91031f47c86ff6f3d0cd83ef67d367f0bf1f1780 Mon Sep 17 00:00:00 2001 From: Mircea Mironenco Date: Fri, 8 Nov 2024 20:41:36 +0200 Subject: [PATCH 4/6] simplify expand for num_kv_heads --- torchtune/modules/attention.py | 11 ++++------- torchtune/modules/attention_utils.py | 12 ------------ 2 files changed, 4 insertions(+), 19 deletions(-) diff --git a/torchtune/modules/attention.py b/torchtune/modules/attention.py index 15a6a4f864..741a0d0043 100644 --- a/torchtune/modules/attention.py +++ b/torchtune/modules/attention.py @@ -9,11 +9,7 @@ import torch from torch import nn -from torchtune.modules.attention_utils import ( - _MaskType, - _sdpa_or_flex_attention, - repeat_interleave, -) +from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention from torchtune.modules.kv_cache import KVCache logger = logging.getLogger(__name__) @@ -284,8 +280,9 @@ def forward( # as the query tensor by copying values across the relevant dim # k,v shape: [b, n_h, s, h_d] if self.num_heads != self.num_kv_heads: - k = repeat_interleave(k, dim=1, repeat=q_per_kv) - v = repeat_interleave(v, dim=1, repeat=q_per_kv) + expand_shape = (-1, -1, q_per_kv, -1, -1) + k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2) + v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2) # Normalize k if self.k_norm is not None: diff --git a/torchtune/modules/attention_utils.py b/torchtune/modules/attention_utils.py index 6190a9d6cf..f2752757eb 100644 --- a/torchtune/modules/attention_utils.py +++ b/torchtune/modules/attention_utils.py @@ -246,15 +246,3 @@ def _attention_call( ) return _attention_call - - -def repeat_interleave(x: torch.Tensor, *, dim: int, repeat: int) -> torch.Tensor: - if repeat == 1: - return x - - dim = dim + x.ndim if dim < 0 else dim - - shape = [-1] * (x.ndim + 1) - shape[dim + 1] = repeat - - return x.unsqueeze(dim + 1).expand(shape).flatten(dim, dim + 1) From 39b9801d0b3c40922785adc12a2fb0f2ea7a14b9 Mon Sep 17 00:00:00 2001 From: Mircea Mironenco Date: Fri, 8 Nov 2024 21:27:59 +0200 Subject: [PATCH 5/6] cleanup nits --- torchtune/modules/attention.py | 3 ++- torchtune/modules/attention_utils.py | 1 + torchtune/modules/kv_cache.py | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/torchtune/modules/attention.py b/torchtune/modules/attention.py index 741a0d0043..b74c70113e 100644 --- a/torchtune/modules/attention.py +++ b/torchtune/modules/attention.py @@ -270,7 +270,8 @@ def forward( k = self.pos_embeddings(k, input_pos=input_pos) # k,v shape: [b, n_kv, s_y, h_d] - k, v = k.transpose(1, 2), v.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) # Update key-value cache if self.kv_cache is not None and self.cache_enabled: diff --git a/torchtune/modules/attention_utils.py b/torchtune/modules/attention_utils.py index f2752757eb..8afd4eba71 100644 --- a/torchtune/modules/attention_utils.py +++ b/torchtune/modules/attention_utils.py @@ -183,6 +183,7 @@ def _attention_call( dropout_p: float, is_causal: bool, ) -> torch.Tensor: + # Flex attention uses the BlockMask # (https://github.com/pytorch/pytorch/blob/main/torch/nn/attention/flex_attention.py#L168) # instead of a traditional boolean tensor mask. 
If this is passed in, diff --git a/torchtune/modules/kv_cache.py b/torchtune/modules/kv_cache.py index be05a88433..e96491c22a 100644 --- a/torchtune/modules/kv_cache.py +++ b/torchtune/modules/kv_cache.py @@ -51,7 +51,7 @@ def reset(self) -> None: @property def size(self) -> int: - return int(self.cache_pos[0].item()) + return self.cache_pos[0].item() def update( self, k_val: torch.Tensor, v_val: torch.Tensor From 5f0d2b70ef6d4e4b23ee3bf6daa0c6e733fa490b Mon Sep 17 00:00:00 2001 From: Mircea Mironenco Date: Sat, 9 Nov 2024 20:56:32 +0200 Subject: [PATCH 6/6] Update gemma-2 kvcache constructor and fix mask type check. --- torchtune/models/gemma2/_attention.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchtune/models/gemma2/_attention.py b/torchtune/models/gemma2/_attention.py index b00612d032..1b7bf38447 100644 --- a/torchtune/models/gemma2/_attention.py +++ b/torchtune/models/gemma2/_attention.py @@ -149,7 +149,7 @@ def setup_cache( self.kv_cache = KVCache( batch_size=batch_size, max_seq_len=max_seq_len, - num_heads=self.num_heads, + num_kv_heads=self.num_heads, head_dim=self.head_dim, dtype=dtype, ) @@ -211,9 +211,9 @@ def forward( - h_d: head dim """ # until flex attention implementation exists, we do not accept block masks - if (mask is not None) and (type(mask) != torch.Tensor()): + if mask is not None and (not isinstance(mask, torch.Tensor)): raise NotImplementedError( - "Block masks are not implemeted yet, use packed=False" + "Block masks are not implemeted yet, use packed=False." ) # x has shape [b, s_x, d]
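A minimal, self-contained sketch of the head expansion that patches 2 and 4 converge on (plain PyTorch; the sizes below are illustrative and not taken from the patches). Because the cache now holds only num_kv_heads heads, k and v reach the expansion as [b, n_kv, s, h_d] and are broadcast up to [b, n_h, s, h_d] with unsqueeze + expand + flatten, which gives the same result as torch.repeat_interleave along the head dimension:

import torch

b, s, h_d = 2, 16, 32           # batch, cached sequence length, head dim (illustrative)
num_heads, num_kv_heads = 8, 2  # grouped-query setup: 4 query heads share each kv head
q_per_kv = num_heads // num_kv_heads

# Shape of k (and v) just before expansion under this change: [b, n_kv, s, h_d]
k = torch.randn(b, num_kv_heads, s, h_d)

# unsqueeze -> [b, n_kv, 1, s, h_d]; expand -> [b, n_kv, q_per_kv, s, h_d];
# flatten  -> [b, n_h, s, h_d]
expand_shape = (-1, -1, q_per_kv, -1, -1)
k_expanded = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)

# Same values and ordering as materializing the repeats explicitly.
assert torch.equal(k_expanded, k.repeat_interleave(q_per_kv, dim=1))
print(k_expanded.shape)  # torch.Size([2, 8, 16, 32])

Patch 2 first factors these ops into a repeat_interleave helper in attention_utils.py; patch 4 inlines them directly in MultiHeadAttention.forward, which is why the helper is added and then removed within this series.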
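The reason the cache is keyed on num_kv_heads (patch 1) is that kv_cache.update() now runs before the expansion above, so the buffers only need num_kv_heads heads. A rough sizing sketch, using the (batch_size, heads, max_seq_len, head_dim) buffer shape from kv_cache.py and an otherwise hypothetical GQA configuration:

import torch

# Hypothetical GQA configuration; the numbers are illustrative, not from the patches.
batch_size, max_seq_len, head_dim = 4, 4096, 128
num_heads, num_kv_heads = 32, 8
dtype = torch.bfloat16

def cache_bytes(heads: int) -> int:
    # KVCache registers two buffers (k_cache and v_cache), each of shape
    # (batch_size, heads, max_seq_len, head_dim).
    elems = batch_size * heads * max_seq_len * head_dim
    return 2 * elems * torch.tensor([], dtype=dtype).element_size()

old = cache_bytes(num_heads)     # cache sized for num_heads (pre-patch behaviour)
new = cache_bytes(num_kv_heads)  # cache sized for num_kv_heads (post-patch behaviour)
print(f"{old / 2**20:.0f} MiB -> {new / 2**20:.0f} MiB per layer "
      f"({num_heads // num_kv_heads}x smaller)")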