Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions sharktank/sharktank/kernels/wave/extend_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def get_wave_extend_attention_asm(
q_extend_shape: tuple[int],
k_extend_shape: tuple[int],
v_extend_shape: tuple[int],
k_cache_shape: tuple[int],
v_cache_shape: tuple[int],
k_cache: tuple[int],
v_cache: tuple[int],
o_shape: tuple[int],
input_dtype: torch.dtype = torch.float16,
output_dtype: torch.dtype = torch.float32,
Expand All @@ -65,8 +65,8 @@ def get_wave_extend_attention_asm(
q_extend_shape,
k_extend_shape,
v_extend_shape,
k_cache_shape,
v_cache_shape,
k_cache,
v_cache,
o_shape,
input_dtype=input_dtype,
output_dtype=output_dtype,
Expand Down Expand Up @@ -119,6 +119,7 @@ def get_wave_extend_attention_asm(
MLIRTensor[S, I32],
MLIRTensor[S, I32],
MLIRTensor[N_KV, I32],
MLIRTensor[N_KV, I32],
MLIRTensor[N_Q, H, D_KV, F16],
MLIRTensor[I32],
),
Expand All @@ -128,11 +129,12 @@ def wave_extend_attention(
q_extend,
k_extend,
v_extend,
k_buffer,
v_buffer,
k_cache,
v_cache,
qo_indptr,
kv_indptr,
kv_indices,
k_indices,
v_indices,
out,
max_seq_len,
result=None,
Expand Down Expand Up @@ -182,8 +184,8 @@ def wave_extend_attention(
q_extend.type.shape,
k_extend.type.shape,
v_extend.type.shape,
k_buffer.type.shape,
v_buffer.type.shape,
k_cache.type.shape,
v_cache.type.shape,
out.type.shape,
torch.float16,
torch.float16,
Expand All @@ -198,9 +200,9 @@ def wave_extend_attention(
+ wave_asm_body
+ "\n{% endraw %}\n"
+ f"""
util.func private @{{{{kernel_name}}}}(%q_extend : !q_extend, %k_extend : !k_extend, %v_extend : !v_extend, %k_buffer : !k_buffer, %v_buffer : !v_buffer, %qo_indptr : !qo_indptr, %kv_indptr : !kv_indptr, %kv_indices : !kv_indices, %out : !out, %max_seq_len : !max_seq_len) -> !result {{
util.func private @{{{{kernel_name}}}}(%q_extend : !q_extend, %k_extend : !k_extend, %v_extend : !v_extend, %k_cache : !k_cache, %v_cache : !v_cache, %qo_indptr : !qo_indptr, %kv_indptr : !kv_indptr, %k_indices : !k_indices, %v_indices : !v_indices, %out : !out, %max_seq_len : !max_seq_len) -> !result {{
%max_seq_len_i32 = tensor.extract %max_seq_len[] : tensor<i32>
%result = func.call @{wave_kernel_fn_name}(%q_extend, %k_extend, %v_extend, %k_buffer, %v_buffer, %qo_indptr, %kv_indptr, %kv_indices, %out, %max_seq_len_i32) : (!q_extend, !k_extend, !v_extend, !k_buffer, !v_buffer, !qo_indptr, !kv_indptr, !kv_indices, !out, i32) -> !result
%result = func.call @{wave_kernel_fn_name}(%q_extend, %k_extend, %v_extend, %k_cache, %v_cache, %qo_indptr, %kv_indptr, %k_indices, %v_indices, %out, %max_seq_len_i32) : (!q_extend, !k_extend, !v_extend, !k_cache, !v_cache, !qo_indptr, !kv_indptr, !k_indices, !v_indices, !out, i32) -> !result
util.return %result : !result
}}
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ def get_extend_attention_kernel(
k_cache_layout = tkl.MemoryLayout(shape=set_dynamic_dim(k_cache_shape))
v_cache_layout = tkl.MemoryLayout(shape=set_dynamic_dim(v_cache_shape))
num_seqs_layout = tkl.MemoryLayout(shape=[None])
kv_indices_layout = tkl.MemoryLayout(shape=[None])
k_indices_layout = tkl.MemoryLayout(shape=[None])
v_indices_layout = tkl.MemoryLayout(shape=[None])

def extend_attention_core(
q,
Expand All @@ -165,7 +166,8 @@ def extend_attention_core(
v_cache,
qo_indptr,
kv_indptr,
kv_indices,
k_indices,
v_indices,
custom_mask,
mask_offsets,
c,
Expand Down Expand Up @@ -212,13 +214,13 @@ def first_loop(
target=(h, n_q, d_q),
)
block_indices_v = tkw.read(
kv_indices,
v_indices,
elements_per_thread=LOAD_ELEMS_PER_THREAD_PV,
source=(n_kv + KV_START_IDX,),
target=(n_kv,),
)
block_indices_k = tkw.read(
kv_indices,
k_indices,
elements_per_thread=1,
source=(n_kv + KV_START_IDX,),
target=(n_kv,),
Expand Down Expand Up @@ -378,7 +380,8 @@ def extend_attention(
],
qo_indptr: tkl.Memory[S, GLOBAL_ADDRESS_SPACE, tkl.i32, num_seqs_layout],
kv_indptr: tkl.Memory[S, GLOBAL_ADDRESS_SPACE, tkl.i32, num_seqs_layout],
kv_indices: tkl.Memory[N_KV, GLOBAL_ADDRESS_SPACE, tkl.i32, kv_indices_layout],
k_indices: tkl.Memory[N_KV, GLOBAL_ADDRESS_SPACE, tkl.i32, k_indices_layout],
v_indices: tkl.Memory[N_KV, GLOBAL_ADDRESS_SPACE, tkl.i32, v_indices_layout],
MAX_EXTEND_SEQ_LEN: tkl.SymbolBind[tkl.i32],
c: tkl.Memory[N_Q, H, D_KV, GLOBAL_ADDRESS_SPACE, wave_output_dtype, o_layout],
):
Expand All @@ -390,7 +393,8 @@ def extend_attention(
v_cache,
qo_indptr,
kv_indptr,
kv_indices,
k_indices,
v_indices,
None,
None,
c,
Expand All @@ -409,7 +413,8 @@ def extend_attention_custom_mask(
],
qo_indptr: tkl.Memory[S, GLOBAL_ADDRESS_SPACE, tkl.i32, num_seqs_layout],
kv_indptr: tkl.Memory[S, GLOBAL_ADDRESS_SPACE, tkl.i32, num_seqs_layout],
kv_indices: tkl.Memory[N_KV, GLOBAL_ADDRESS_SPACE, tkl.i32, kv_indices_layout],
k_indices: tkl.Memory[N_KV, GLOBAL_ADDRESS_SPACE, tkl.i32, k_indices_layout],
v_indices: tkl.Memory[N_KV, GLOBAL_ADDRESS_SPACE, tkl.i32, v_indices_layout],
custom_mask: tkl.Memory[
MASK_LEN, GLOBAL_ADDRESS_SPACE, tkl.i8, num_seqs_layout
],
Expand All @@ -425,7 +430,8 @@ def extend_attention_custom_mask(
v_cache,
qo_indptr,
kv_indptr,
kv_indices,
k_indices,
v_indices,
custom_mask,
mask_offsets,
c,
Expand Down
45 changes: 35 additions & 10 deletions sharktank/sharktank/kernels/wave/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
device_full,
)
from enum import Enum
from typing import List
from sharktank.types.tensors import QuantizedTensor


class ScoreMod(Enum):
Expand Down Expand Up @@ -100,10 +102,14 @@ def create_extend_attention_inputs(

kv_indptr = device_zeros((B + 1,), dtype=torch.int32)
kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)
kv_indices = device_zeros((b_seq_len_prefix.sum().item(),), dtype=torch.int32)
k_indices = device_zeros((b_seq_len_prefix.sum().item(),), dtype=torch.int32)
v_indices = device_zeros((b_seq_len_prefix.sum().item(),), dtype=torch.int32)

for i in range(B):
kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = torch.arange(
k_indices[kv_indptr[i] : kv_indptr[i + 1]] = torch.arange(
b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i]
)
v_indices[kv_indptr[i] : kv_indptr[i + 1]] = torch.arange(
b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i]
)
total_token_num = torch.sum(b_seq_len).item()
Expand Down Expand Up @@ -182,7 +188,8 @@ def create_extend_attention_inputs(
b_seq_len,
qo_indptr,
kv_indptr,
kv_indices,
k_indices,
v_indices,
custom_mask,
mask_offsets,
b_start_loc,
Expand Down Expand Up @@ -311,10 +318,28 @@ def ref_extend_attn(
return o_extend


def create_causal_mask(seq_len: int, dtype: torch.dtype, device: str):
# Create a simple attention mask with shape [1, 1, seq_len, seq_len]
# This broadcasts across all batches and heads
mask = torch.triu(torch.ones(seq_len, seq_len) * float("-inf"), diagonal=1)
mask = mask.unsqueeze(0).unsqueeze(0)
mask = mask.to(dtype).to(device=device)
return mask
def create_kv_indices(
page_ids: torch.Tensor,
transformer_block_count: int,
transformer_block_index: int,
block_seq_stride: int,
cache_partitions: List[torch.Tensor | QuantizedTensor],
dtype: torch.dtype,
device: str,
):
all_indices = []

for cache_partition_id, cache_partition in enumerate(cache_partitions):
indices = page_ids
indices = indices * transformer_block_count + transformer_block_index
indices = indices * len(cache_partitions) + cache_partition_id
indices = indices[:, :, None]
indices = (
indices * block_seq_stride
+ torch.arange(block_seq_stride, dtype=dtype, device=device)[None, None, :]
)
indices = indices.flatten(1, 2).to(dtype)
all_indices.append(indices)

k_indices, v_indices = all_indices
return k_indices, v_indices
Loading
Loading