diff --git a/sharktank/sharktank/kernels/wave/extend_attention.py b/sharktank/sharktank/kernels/wave/extend_attention.py index 4f5abff6666..659ac548279 100644 --- a/sharktank/sharktank/kernels/wave/extend_attention.py +++ b/sharktank/sharktank/kernels/wave/extend_attention.py @@ -46,8 +46,8 @@ def get_wave_extend_attention_asm( q_extend_shape: tuple[int], k_extend_shape: tuple[int], v_extend_shape: tuple[int], - k_cache_shape: tuple[int], - v_cache_shape: tuple[int], + k_cache: tuple[int], + v_cache: tuple[int], o_shape: tuple[int], input_dtype: torch.dtype = torch.float16, output_dtype: torch.dtype = torch.float32, @@ -65,8 +65,8 @@ def get_wave_extend_attention_asm( q_extend_shape, k_extend_shape, v_extend_shape, - k_cache_shape, - v_cache_shape, + k_cache, + v_cache, o_shape, input_dtype=input_dtype, output_dtype=output_dtype, @@ -119,6 +119,7 @@ def get_wave_extend_attention_asm( MLIRTensor[S, I32], MLIRTensor[S, I32], MLIRTensor[N_KV, I32], + MLIRTensor[N_KV, I32], MLIRTensor[N_Q, H, D_KV, F16], MLIRTensor[I32], ), @@ -128,11 +129,12 @@ def wave_extend_attention( q_extend, k_extend, v_extend, - k_buffer, - v_buffer, + k_cache, + v_cache, qo_indptr, kv_indptr, - kv_indices, + k_indices, + v_indices, out, max_seq_len, result=None, @@ -182,8 +184,8 @@ def wave_extend_attention( q_extend.type.shape, k_extend.type.shape, v_extend.type.shape, - k_buffer.type.shape, - v_buffer.type.shape, + k_cache.type.shape, + v_cache.type.shape, out.type.shape, torch.float16, torch.float16, @@ -198,9 +200,9 @@ def wave_extend_attention( + wave_asm_body + "\n{% endraw %}\n" + f""" - util.func private @{{{{kernel_name}}}}(%q_extend : !q_extend, %k_extend : !k_extend, %v_extend : !v_extend, %k_buffer : !k_buffer, %v_buffer : !v_buffer, %qo_indptr : !qo_indptr, %kv_indptr : !kv_indptr, %kv_indices : !kv_indices, %out : !out, %max_seq_len : !max_seq_len) -> !result {{ + util.func private @{{{{kernel_name}}}}(%q_extend : !q_extend, %k_extend : !k_extend, %v_extend : !v_extend, %k_cache : !k_cache, %v_cache : !v_cache, %qo_indptr : !qo_indptr, %kv_indptr : !kv_indptr, %k_indices : !k_indices, %v_indices : !v_indices, %out : !out, %max_seq_len : !max_seq_len) -> !result {{ %max_seq_len_i32 = tensor.extract %max_seq_len[] : tensor - %result = func.call @{wave_kernel_fn_name}(%q_extend, %k_extend, %v_extend, %k_buffer, %v_buffer, %qo_indptr, %kv_indptr, %kv_indices, %out, %max_seq_len_i32) : (!q_extend, !k_extend, !v_extend, !k_buffer, !v_buffer, !qo_indptr, !kv_indptr, !kv_indices, !out, i32) -> !result + %result = func.call @{wave_kernel_fn_name}(%q_extend, %k_extend, %v_extend, %k_cache, %v_cache, %qo_indptr, %kv_indptr, %k_indices, %v_indices, %out, %max_seq_len_i32) : (!q_extend, !k_extend, !v_extend, !k_cache, !v_cache, !qo_indptr, !kv_indptr, !k_indices, !v_indices, !out, i32) -> !result util.return %result : !result }} """ diff --git a/sharktank/sharktank/kernels/wave/templates/extend_attention_kernel.py b/sharktank/sharktank/kernels/wave/templates/extend_attention_kernel.py index 353af8e74fe..cd2765d4285 100644 --- a/sharktank/sharktank/kernels/wave/templates/extend_attention_kernel.py +++ b/sharktank/sharktank/kernels/wave/templates/extend_attention_kernel.py @@ -155,7 +155,8 @@ def get_extend_attention_kernel( k_cache_layout = tkl.MemoryLayout(shape=set_dynamic_dim(k_cache_shape)) v_cache_layout = tkl.MemoryLayout(shape=set_dynamic_dim(v_cache_shape)) num_seqs_layout = tkl.MemoryLayout(shape=[None]) - kv_indices_layout = tkl.MemoryLayout(shape=[None]) + k_indices_layout = tkl.MemoryLayout(shape=[None]) + v_indices_layout = tkl.MemoryLayout(shape=[None]) def extend_attention_core( q, @@ -165,7 +166,8 @@ def extend_attention_core( v_cache, qo_indptr, kv_indptr, - kv_indices, + k_indices, + v_indices, custom_mask, mask_offsets, c, @@ -212,13 +214,13 @@ def first_loop( target=(h, n_q, d_q), ) block_indices_v = tkw.read( - kv_indices, + v_indices, elements_per_thread=LOAD_ELEMS_PER_THREAD_PV, source=(n_kv + KV_START_IDX,), target=(n_kv,), ) block_indices_k = tkw.read( - kv_indices, + k_indices, elements_per_thread=1, source=(n_kv + KV_START_IDX,), target=(n_kv,), @@ -378,7 +380,8 @@ def extend_attention( ], qo_indptr: tkl.Memory[S, GLOBAL_ADDRESS_SPACE, tkl.i32, num_seqs_layout], kv_indptr: tkl.Memory[S, GLOBAL_ADDRESS_SPACE, tkl.i32, num_seqs_layout], - kv_indices: tkl.Memory[N_KV, GLOBAL_ADDRESS_SPACE, tkl.i32, kv_indices_layout], + k_indices: tkl.Memory[N_KV, GLOBAL_ADDRESS_SPACE, tkl.i32, k_indices_layout], + v_indices: tkl.Memory[N_KV, GLOBAL_ADDRESS_SPACE, tkl.i32, v_indices_layout], MAX_EXTEND_SEQ_LEN: tkl.SymbolBind[tkl.i32], c: tkl.Memory[N_Q, H, D_KV, GLOBAL_ADDRESS_SPACE, wave_output_dtype, o_layout], ): @@ -390,7 +393,8 @@ def extend_attention( v_cache, qo_indptr, kv_indptr, - kv_indices, + k_indices, + v_indices, None, None, c, @@ -409,7 +413,8 @@ def extend_attention_custom_mask( ], qo_indptr: tkl.Memory[S, GLOBAL_ADDRESS_SPACE, tkl.i32, num_seqs_layout], kv_indptr: tkl.Memory[S, GLOBAL_ADDRESS_SPACE, tkl.i32, num_seqs_layout], - kv_indices: tkl.Memory[N_KV, GLOBAL_ADDRESS_SPACE, tkl.i32, kv_indices_layout], + k_indices: tkl.Memory[N_KV, GLOBAL_ADDRESS_SPACE, tkl.i32, k_indices_layout], + v_indices: tkl.Memory[N_KV, GLOBAL_ADDRESS_SPACE, tkl.i32, v_indices_layout], custom_mask: tkl.Memory[ MASK_LEN, GLOBAL_ADDRESS_SPACE, tkl.i8, num_seqs_layout ], @@ -425,7 +430,8 @@ def extend_attention_custom_mask( v_cache, qo_indptr, kv_indptr, - kv_indices, + k_indices, + v_indices, custom_mask, mask_offsets, c, diff --git a/sharktank/sharktank/kernels/wave/utils.py b/sharktank/sharktank/kernels/wave/utils.py index 87db8bf98bd..29c9a3b5b2a 100644 --- a/sharktank/sharktank/kernels/wave/utils.py +++ b/sharktank/sharktank/kernels/wave/utils.py @@ -23,6 +23,8 @@ device_full, ) from enum import Enum +from typing import List +from sharktank.types.tensors import QuantizedTensor class ScoreMod(Enum): @@ -100,10 +102,14 @@ def create_extend_attention_inputs( kv_indptr = device_zeros((B + 1,), dtype=torch.int32) kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0) - kv_indices = device_zeros((b_seq_len_prefix.sum().item(),), dtype=torch.int32) + k_indices = device_zeros((b_seq_len_prefix.sum().item(),), dtype=torch.int32) + v_indices = device_zeros((b_seq_len_prefix.sum().item(),), dtype=torch.int32) for i in range(B): - kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = torch.arange( + k_indices[kv_indptr[i] : kv_indptr[i + 1]] = torch.arange( + b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i] + ) + v_indices[kv_indptr[i] : kv_indptr[i + 1]] = torch.arange( b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i] ) total_token_num = torch.sum(b_seq_len).item() @@ -182,7 +188,8 @@ def create_extend_attention_inputs( b_seq_len, qo_indptr, kv_indptr, - kv_indices, + k_indices, + v_indices, custom_mask, mask_offsets, b_start_loc, @@ -311,10 +318,28 @@ def ref_extend_attn( return o_extend -def create_causal_mask(seq_len: int, dtype: torch.dtype, device: str): - # Create a simple attention mask with shape [1, 1, seq_len, seq_len] - # This broadcasts across all batches and heads - mask = torch.triu(torch.ones(seq_len, seq_len) * float("-inf"), diagonal=1) - mask = mask.unsqueeze(0).unsqueeze(0) - mask = mask.to(dtype).to(device=device) - return mask +def create_kv_indices( + page_ids: torch.Tensor, + transformer_block_count: int, + transformer_block_index: int, + block_seq_stride: int, + cache_partitions: List[torch.Tensor | QuantizedTensor], + dtype: torch.dtype, + device: str, +): + all_indices = [] + + for cache_partition_id, cache_partition in enumerate(cache_partitions): + indices = page_ids + indices = indices * transformer_block_count + transformer_block_index + indices = indices * len(cache_partitions) + cache_partition_id + indices = indices[:, :, None] + indices = ( + indices * block_seq_stride + + torch.arange(block_seq_stride, dtype=dtype, device=device)[None, None, :] + ) + indices = indices.flatten(1, 2).to(dtype) + all_indices.append(indices) + + k_indices, v_indices = all_indices + return k_indices, v_indices diff --git a/sharktank/sharktank/layers/paged_attention.py b/sharktank/sharktank/layers/paged_attention.py index d7817ebbe5f..7bfb897048a 100644 --- a/sharktank/sharktank/layers/paged_attention.py +++ b/sharktank/sharktank/layers/paged_attention.py @@ -132,10 +132,80 @@ def paged_attention_kv_cache_gather( """ return MLIRSpec(mlir) - return paged_attention_kv_cache_gather + @mlir_kernel( + inputs=( + MLIRTensor[ + CACHE_SIZE, + T_BLOCK, + PART, + BLOCK_SEQ_STRIDE, + HEAD_COUNT_KV, + ATTN_HEAD_DIM, + CACHE_TY, + ], + MLIRTensor[BATCH, PAGES, I64], + MLIRTensor[I64], + MLIRTensor[I64], + ), + results=( + MLIRTensor[ + BATCH, PAGES, BLOCK_SEQ_STRIDE, HEAD_COUNT_KV, ATTN_HEAD_DIM, CACHE_TY + ], + ), + ) + def paged_attention_kv_cache_gather_extend_attention( + cache, page_ids, transformer_idx, partition_idx, result + ): + mlir = """ + !cache_slice = tensor<{{[CACHE_SIZE, BLOCK_SEQ_STRIDE, HEAD_COUNT_KV, ATTN_HEAD_DIM]|join('x')}}x!cache_dtype> + + module { + util.func private @{{kernel_name}}(%cache: !cache, + %page_ids: !page_ids, + %transformer_idx: !transformer_idx, + %partition_idx: !partition_idx) -> !result { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + + // Get transformer/partition ids. + %t_id64 = tensor.extract %transformer_idx[] : !transformer_idx + %p_id64 = tensor.extract %partition_idx[] : !partition_idx + %t_id = arith.index_cast %t_id64 : !transformer_idx_dtype to index + %p_id = arith.index_cast %p_id64 : !partition_idx_dtype to index + // Get dynamic dimensions. + %cache_size = tensor.dim %cache, %c0 : !cache + %batches = tensor.dim %page_ids, %c0 : !page_ids + %pages = tensor.dim %page_ids, %c1 : !page_ids + + // Extract a the current transformer block and partition from cache. + %cache_slice = tensor.extract_slice %cache + [0, %t_id, %p_id, 0, 0, 0] + [%cache_size, 1, 1, {{BLOCK_SEQ_STRIDE}}, {{HEAD_COUNT_KV}}, {{ATTN_HEAD_DIM}}] + [1, 1, 1, 1, 1, 1] + : !cache to !cache_slice -kv_cache_gather = KVCacheGatherKernel() + %empty = tensor.empty(%batches, %pages) : !result + + // Gather from cache_slice using page_ids. + %result = iree_linalg_ext.gather + dimension_map = [0] + ins(%cache_slice, %page_ids : !cache_slice, !page_ids) + outs(%empty : !result) -> !result + + util.return %result : !result + } + } + """ + return MLIRSpec(mlir) + + return ( + paged_attention_kv_cache_gather, + paged_attention_kv_cache_gather_extend_attention, + ) + + +kv_cache_gather, kv_cache_gather_extend_attention = KVCacheGatherKernel() class PagedKVCache(KVCache, ABC): @@ -156,6 +226,7 @@ def __init__( block_seq_stride: int = 16, cache_dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, + extend_attention: bool = False, ): self.transformer_block_count = transformer_block_count self.attn_head_count = attn_head_count @@ -164,17 +235,27 @@ def __init__( self.block_seq_stride = block_seq_stride self._cache_dtype = cache_dtype self.device = device + self.extend_attention = extend_attention assert cache_partition_count == 2 # Some derived values based on attributes. - self.sub_page_dims = [ - self.transformer_block_count, - self.cache_partition_count, - self.attn_head_count, - self.block_seq_stride, - self.attn_head_dim, - ] + if self.extend_attention: + self.sub_page_dims = [ + self.transformer_block_count, + self.cache_partition_count, + self.block_seq_stride, + self.attn_head_count, + self.attn_head_dim, + ] + else: + self.sub_page_dims = [ + self.transformer_block_count, + self.cache_partition_count, + self.attn_head_count, + self.block_seq_stride, + self.attn_head_dim, + ] self.page_slab_flat_dims = math.prod(self.sub_page_dims) @@ -231,11 +312,23 @@ def unwrap_args(*ts): new_ts.append(t) return new_ts - key = kv_cache_gather(*unwrap_args(page_table, page_ids, t_id, key_p_id)) - value = kv_cache_gather(*unwrap_args(page_table, page_ids, t_id, value_p_id)) + if self.extend_attention: + key = kv_cache_gather_extend_attention( + *unwrap_args(page_table, page_ids, t_id, key_p_id) + ) + value = kv_cache_gather_extend_attention( + *unwrap_args(page_table, page_ids, t_id, value_p_id) + ) + else: + key = kv_cache_gather(*unwrap_args(page_table, page_ids, t_id, key_p_id)) + value = kv_cache_gather( + *unwrap_args(page_table, page_ids, t_id, value_p_id) + ) + key = key.transpose(2, 3) + value = value.transpose(2, 3) - key = key.transpose(2, 3).flatten(1, 2) - value = value.transpose(2, 3).flatten(1, 2) + key = key.flatten(1, 2) + value = value.flatten(1, 2) key = pack_raw_tensor(key, k_quantizer, dtype=torch.float16) value = pack_raw_tensor(value, v_quantizer, dtype=torch.float16) @@ -282,7 +375,8 @@ def write( 1, (block_seq_len, self.block_seq_stride) ) cache_partition = cache_partition.flatten(0, 1) - cache_partition = cache_partition.transpose(1, 2) + if not self.extend_attention: + cache_partition = cache_partition.transpose(1, 2) part_block = ops.to(cache_partition, dtype=page_table.dtype) ops.index_copy_(page_table, 0, index, part_block) @@ -321,10 +415,14 @@ def write_timestep( index = page_id index = index * self.transformer_block_count + transformer_block_index index = index * self.cache_partition_count + partitions - index = index * self.attn_head_count + head_offset - index = index * self.block_seq_stride + page_offset + if self.extend_attention: + index = index * self.block_seq_stride + page_offset + index = index * self.attn_head_count + head_offset + else: + index = index * self.attn_head_count + head_offset + index = index * self.block_seq_stride + page_offset + cache_partition.transpose(1, 2) - cache_partition.transpose(1, 2) values = ops.to(cache_partition, dtype=page_table.dtype) ops.index_put_(page_table, indices=(index,), values=values) @@ -480,6 +578,7 @@ def build_cache( cache_dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, parallelism_config: ParallelismConfig | None = None, + extend_attention: bool = False, ) -> PagedKVCache: kwargs = dict( attn_head_count=attn_head_count, @@ -488,6 +587,7 @@ def build_cache( block_seq_stride=block_seq_stride, cache_dtype=cache_dtype, device=device, + extend_attention=extend_attention, ) if parallelism_config is None or parallelism_config.pipeline_size == 1: diff --git a/sharktank/sharktank/ops/attention_impls.py b/sharktank/sharktank/ops/attention_impls.py index 4ce06135a85..3302fac710d 100644 --- a/sharktank/sharktank/ops/attention_impls.py +++ b/sharktank/sharktank/ops/attention_impls.py @@ -251,11 +251,9 @@ def scaled_dot_product_attention_torch( @extend_attention.override(AnyTensor, AnyTensor, AnyTensor, impl_name="wave") -def extend_attention_wave(q, k, v, kv_cache, page_ids, start_positions, seq_lens, impl): - if kv_cache is not None: - return NotImplemented - if page_ids is not None: - return NotImplemented +def extend_attention_wave( + q, k, v, kv_cache, k_indices, v_indices, page_ids, start_positions, seq_lens, impl +): q = unbox_tensor(q) k = unbox_tensor(k) v = unbox_tensor(v) @@ -267,12 +265,15 @@ def extend_attention_wave(q, k, v, kv_cache, page_ids, start_positions, seq_lens q_flat = q.flatten(0, 1).to(torch.float16).to(device) # [B=1*extend_len, H_q, D] k_flat = k.flatten(0, 1).to(torch.float16).to(device) # [B=1*extend_len, H_kv, D] v_flat = v.flatten(0, 1).to(torch.float16).to(device) - k_cache = torch.zeros_like(k) - v_cache = torch.zeros_like(v) - k_cache_flat = ( - k_cache.flatten(0, 1).to(torch.float16).to(device) - ) # [B*prefix_len, H_kv, D] - v_cache_flat = v_cache.flatten(0, 1).to(torch.float16).to(device) + # TODO: don't require passing 2 copies of kv_cache - current kv_cache implementation uses only + # a single allocation, but supporting separate k_cache and v_cache buffers is important for + # future separate k/v allocations [num_pages * t_block_count * cache_partition_count * block_seq_stride, H_kv, D] + if kv_cache is None: + k_cache = torch.zeros_like(k_flat) + v_cache = torch.zeros_like(v_flat) + else: + k_cache = kv_cache.to(device) + v_cache = kv_cache.to(device) extend_len = seq_lens - start_positions extend_len = extend_len.squeeze().to(dtype=torch.int32) b_seq_len_extend = torch.full( @@ -280,8 +281,11 @@ def extend_attention_wave(q, k, v, kv_cache, page_ids, start_positions, seq_lens ) qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device) qo_indptr[1:] = torch.cumsum(b_seq_len_extend, dim=0) - kv_indptr = torch.zeros(q.shape[0] + 1, dtype=torch.int32) - kv_indices = torch.zeros(q.shape[0], dtype=torch.int32) + b_seq_len_prefix = torch.full( + (B,), start_positions.item(), dtype=torch.int32, device=device + ) + kv_indptr = torch.zeros(B + 1, dtype=torch.int32) + kv_indptr[1:] = torch.cumsum(b_seq_len_prefix, dim=0) N_q = q_flat.shape[0] output_buffer = torch.zeros((N_q, H_q, D_kv), dtype=torch.float16, device=device) @@ -289,11 +293,12 @@ def extend_attention_wave(q, k, v, kv_cache, page_ids, start_positions, seq_lens q_flat, k_flat, v_flat, - k_cache_flat, - v_cache_flat, + k_cache, + v_cache, qo_indptr, kv_indptr, - kv_indices, + k_indices, + v_indices, output_buffer, extend_len, ) diff --git a/sharktank/sharktank/ops/signatures.py b/sharktank/sharktank/ops/signatures.py index 613363941cc..7589e2ce966 100644 --- a/sharktank/sharktank/ops/signatures.py +++ b/sharktank/sharktank/ops/signatures.py @@ -966,6 +966,8 @@ def extend_attention( k: AnyTensor, v: AnyTensor, kv_cache: Optional[AnyTensor] = None, + k_indices: Optional[AnyTensor] = None, + v_indices: Optional[AnyTensor] = None, page_ids: Optional[AnyTensor] = None, start_positions: Optional[AnyTensor] = None, seq_lens: Optional[AnyTensor] = None, diff --git a/sharktank/tests/kernels/wave/extend_attention_test.py b/sharktank/tests/kernels/wave/extend_attention_test.py index c9c9ff18695..0805371b7ad 100644 --- a/sharktank/tests/kernels/wave/extend_attention_test.py +++ b/sharktank/tests/kernels/wave/extend_attention_test.py @@ -25,13 +25,16 @@ from sharktank.kernels.wave.utils import ( create_extend_attention_inputs, ref_extend_attn, - create_causal_mask, + create_kv_indices, ) from wave_lang.kernel.wave.templates.attention_common import AttentionShape from dataclasses import replace from torch.testing import assert_close from sharktank import ops from sharktank.ops import attention_impls +from sharktank.layers.paged_attention import PagedGQAttention +from sharktank.layers.paged_attention import build_cache +from sharktank.utils.testing import assert_tensor_close @is_mi300x @@ -83,11 +86,12 @@ def forward( q_extend, k_extend, v_extend, - k_buffer, - v_buffer, + k_cache, + v_cache, qo_indptr, kv_indptr, - kv_indices, + k_indices, + v_indices, output, max_len_extend_tensor, ): @@ -95,11 +99,12 @@ def forward( q_extend, k_extend, v_extend, - k_buffer, - v_buffer, + k_cache, + v_cache, qo_indptr, kv_indptr, - kv_indices, + k_indices, + v_indices, output, max_len_extend_tensor, ) @@ -119,13 +124,14 @@ def forward( q_extend, k_extend, v_extend, - k_buffer, - v_buffer, + k_cache, + v_cache, b_req_idx, b_seq_len, qo_indptr, kv_indptr, - kv_indices, + k_indices, + v_indices, custom_mask, mask_offsets, b_start_loc, @@ -149,11 +155,12 @@ def forward( q_extend, k_extend, v_extend, - k_buffer, - v_buffer, + k_cache, + v_cache, qo_indptr, kv_indptr, - kv_indices, + k_indices, + v_indices, output, torch.tensor( max_len_extend_wave, dtype=torch.int32, device=q_extend.device @@ -193,8 +200,8 @@ def forward( ) ref_output = ref_extend_attn( q_extend=q_extend, - k_buffer=k_buffer, - v_buffer=v_buffer, + k_buffer=k_cache, + v_buffer=v_cache, b_req_idx=b_req_idx, b_start_loc=b_start_loc, b_seq_len=b_seq_len, @@ -215,7 +222,7 @@ class TestOpsExtendAttention: @pytest.mark.skipif(not torch.cuda.is_available(), reason="Needs CUDA/HIP device.") @pytest.mark.parametrize( - "batch, heads, seq_len, head_dim, dtype, device", + "batch, heads, seq_len, head_dim, attn_dtype, device", [ (1, 8, 128, 32, torch.float16, "cuda"), (1, 32, 13, 128, torch.float16, "cuda"), @@ -228,31 +235,58 @@ def test_no_cache( heads, seq_len, head_dim, - dtype, + attn_dtype, device, ): """Test extend attention with various configurations.""" torch.manual_seed(42) - q = torch.randn(batch, seq_len, heads, head_dim, dtype=dtype, device=device) - k = torch.randn(batch, seq_len, heads, head_dim, dtype=dtype, device=device) - v = torch.randn(batch, seq_len, heads, head_dim, dtype=dtype, device=device) + q = torch.randn( + batch, seq_len, heads, head_dim, dtype=attn_dtype, device=device + ) + k = torch.randn( + batch, seq_len, heads, head_dim, dtype=attn_dtype, device=device + ) + v = torch.randn( + batch, seq_len, heads, head_dim, dtype=attn_dtype, device=device + ) q_sdpa = q.transpose(1, 2) k_sdpa = k.transpose(1, 2) v_sdpa = v.transpose(1, 2) - a = create_causal_mask(seq_len, dtype, device) + input_mask = ops.input_mask(torch.tensor([seq_len]), seq_len) + a = ops.attention_mask( + input_mask, + source_len=seq_len, + target_len=seq_len, + attention_dtype=attn_dtype, + ).to(device) sdpa = ops.scaled_dot_product_attention(q=q_sdpa, k=k_sdpa, v=v_sdpa, a=a) seq_lens = torch.tensor([seq_len], dtype=torch.int32) start_positions = torch.tensor([0], dtype=torch.int32) + indices_no_cache = torch.zeros(q.shape[0], dtype=torch.int32) extend_attention = ops.extend_attention( - q=q, k=k, v=v, start_positions=start_positions, seq_lens=seq_lens + q=q, + k=k, + v=v, + kv_cache=None, + k_indices=indices_no_cache, + v_indices=indices_no_cache, + start_positions=start_positions, + seq_lens=seq_lens, ) torch.testing.assert_close(sdpa, extend_attention, atol=1e-3, rtol=1e-3) k_noise = k * 0.05 extend_attention_k_noise = ops.extend_attention( - q=q, k=k_noise, v=v, start_positions=start_positions, seq_lens=seq_lens + q=q, + k=k_noise, + v=v, + kv_cache=None, + k_indices=indices_no_cache, + v_indices=indices_no_cache, + start_positions=start_positions, + seq_lens=seq_lens, ) with pytest.raises(AssertionError): torch.testing.assert_close( @@ -261,9 +295,167 @@ def test_no_cache( v_noise = v * 0.05 extend_attention_v_noise = ops.extend_attention( - q=q, k=k, v=v_noise, start_positions=start_positions, seq_lens=seq_lens + q=q, + k=k, + v=v_noise, + kv_cache=None, + k_indices=indices_no_cache, + v_indices=indices_no_cache, + start_positions=start_positions, + seq_lens=seq_lens, ) with pytest.raises(AssertionError): torch.testing.assert_close( sdpa, extend_attention_v_noise, atol=1e-3, rtol=1e-3 ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="Needs CUDA/HIP device.") + @pytest.mark.parametrize( + "batch, heads, seq_len, head_dim, attn_dtype, device", + [ + (1, 4, 32, 16, torch.float16, "cuda"), + ], + ) + def test_extend_kv_cache_single_request_two_chunks( + self, + batch, + heads, + seq_len, + head_dim, + attn_dtype, + device, + ): + """Test extend attention over two sequential chunks for a single request.""" + torch.manual_seed(42) + + transformer_block_count = 8 + transformer_block_index = 3 + block_seq_stride = 4 + chunk_size = 16 # tokens per chunk + num_chunks = seq_len // chunk_size + + # Full QKV for reference + q_full = torch.randn( + batch, seq_len, heads, head_dim, dtype=attn_dtype, device=device + ) + k_full = torch.randn( + batch, seq_len, heads, head_dim, dtype=attn_dtype, device=device + ) + v_full = torch.randn( + batch, seq_len, heads, head_dim, dtype=attn_dtype, device=device + ) + + # Full SDPA reference (no KV cache) + q_sdpa = q_full.transpose(1, 2) + k_sdpa = k_full.transpose(1, 2) + v_sdpa = v_full.transpose(1, 2) + input_mask = ops.input_mask(torch.tensor([seq_len]), seq_len) + a = ops.attention_mask( + input_mask, + source_len=seq_len, + target_len=seq_len, + attention_dtype=attn_dtype, + ).to(device) + sdpa_ref = ops.scaled_dot_product_attention(q=q_sdpa, k=k_sdpa, v=v_sdpa, a=a) + + # Build KV cache + page_count = batch * seq_len // block_seq_stride + kv_cache_extend = build_cache( + transformer_block_count=transformer_block_count, + attn_head_count=heads, + attn_head_dim=head_dim, + block_seq_stride=block_seq_stride, + cache_dtype=attn_dtype, + extend_attention=True, + ) + + cache = PagedGQAttention( + kv_cache=kv_cache_extend, + transformer_block_index=transformer_block_index, + attn_dtype=attn_dtype, + use_rope=True, + attention_chunk_size=None, + ) + + allocation = cache.allocate(page_count=page_count) + page_ids = torch.arange(page_count, dtype=torch.int64).view( + batch, seq_len // block_seq_stride + ) + + wave_kv_cache = kv_cache_extend.unflatten_page_table(allocation).flatten(0, 3) + + # Loop through chunks simulating progressive prefill + all_outputs = [] + for chunk_id in range(num_chunks): + start = chunk_id * chunk_size + end = (chunk_id + 1) * chunk_size + + q = q_full[:, start:end, :, :] + k = k_full[:, start:end, :, :] + v = v_full[:, start:end, :, :] + + write_page_ids = page_ids[ + :, start // block_seq_stride : end // block_seq_stride + ] + cache_partitions = [k.cpu(), v.cpu()] + + # Write chunk to KV cache + cache.write( + allocation, + cache_partitions=cache_partitions, + transformer_block_index=transformer_block_index, + page_ids=write_page_ids, + ) + + # Create indices for extend_attention kernel + if start == 0: + page_ids_prefix = torch.full((page_ids.size(0), 1), 0) + else: + page_ids_prefix = page_ids[:, : (start // block_seq_stride)] + + k_indices, v_indices = create_kv_indices( + page_ids=page_ids_prefix.to(device), + transformer_block_count=transformer_block_count, + transformer_block_index=transformer_block_index, + block_seq_stride=block_seq_stride, + cache_partitions=cache_partitions, + dtype=torch.int32, + device=device, + ) + + seq_lens = torch.tensor([end], dtype=torch.int32, device=device) + start_positions = torch.tensor([start], dtype=torch.int32, device=device) + + # Call extend_attention on current chunk + extend_attention_out = ops.extend_attention( + q=q, + k=k, + v=v, + kv_cache=wave_kv_cache, + k_indices=k_indices.flatten(), + v_indices=v_indices.flatten(), + page_ids=write_page_ids.to(device), + start_positions=start_positions, + seq_lens=seq_lens, + ) + + all_outputs.append(extend_attention_out) + + # Compare new chunk output to reference SDPA output for that range + torch.testing.assert_close( + extend_attention_out, + sdpa_ref[:, :, start:end, :], + atol=1e-3, + rtol=1e-3, + msg=f"Mismatch in chunk {chunk_id}", + ) + + # Combine outputs for completeness check + combined_output = torch.cat(all_outputs, dim=2) + torch.testing.assert_close( + combined_output, + sdpa_ref, + atol=1e-3, + rtol=1e-3, + msg="Combined output does not match full SDPA reference", + )