Skip to content
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions sharktank/sharktank/kernels/wave/extend_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def get_wave_extend_attention_asm(
q_extend_shape: tuple[int],
k_extend_shape: tuple[int],
v_extend_shape: tuple[int],
k_cache_shape: tuple[int],
v_cache_shape: tuple[int],
kv_cache_1_shape: tuple[int],
kv_cache_2_shape: tuple[int],
o_shape: tuple[int],
input_dtype: torch.dtype = torch.float16,
output_dtype: torch.dtype = torch.float32,
Expand All @@ -65,8 +65,8 @@ def get_wave_extend_attention_asm(
q_extend_shape,
k_extend_shape,
v_extend_shape,
k_cache_shape,
v_cache_shape,
kv_cache_1_shape,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't do kv_cache_#. This is ambiguous. It should be clear what these buffer / shape represents (e.g. the k or v component). Its especially important when invoking for wave.

kv_cache_2_shape,
o_shape,
input_dtype=input_dtype,
output_dtype=output_dtype,
Expand Down Expand Up @@ -119,6 +119,7 @@ def get_wave_extend_attention_asm(
MLIRTensor[S, I32],
MLIRTensor[S, I32],
MLIRTensor[N_KV, I32],
MLIRTensor[N_KV, I32],
MLIRTensor[N_Q, H, D_KV, F16],
MLIRTensor[I32],
),
Expand All @@ -128,11 +129,12 @@ def wave_extend_attention(
q_extend,
k_extend,
v_extend,
k_buffer,
v_buffer,
kv_cache_buffer_1,
kv_cache_buffer_2,
qo_indptr,
kv_indptr,
kv_indices,
k_indices,
v_indices,
out,
max_seq_len,
result=None,
Expand Down Expand Up @@ -182,8 +184,8 @@ def wave_extend_attention(
q_extend.type.shape,
k_extend.type.shape,
v_extend.type.shape,
k_buffer.type.shape,
v_buffer.type.shape,
kv_cache_buffer_1.type.shape,
kv_cache_buffer_2.type.shape,
out.type.shape,
torch.float16,
torch.float16,
Expand All @@ -198,9 +200,9 @@ def wave_extend_attention(
+ wave_asm_body
+ "\n{% endraw %}\n"
+ f"""
util.func private @{{{{kernel_name}}}}(%q_extend : !q_extend, %k_extend : !k_extend, %v_extend : !v_extend, %k_buffer : !k_buffer, %v_buffer : !v_buffer, %qo_indptr : !qo_indptr, %kv_indptr : !kv_indptr, %kv_indices : !kv_indices, %out : !out, %max_seq_len : !max_seq_len) -> !result {{
util.func private @{{{{kernel_name}}}}(%q_extend : !q_extend, %k_extend : !k_extend, %v_extend : !v_extend, %kv_cache_buffer_1 : !kv_cache_buffer_1, %kv_cache_buffer_2 : !kv_cache_buffer_2, %qo_indptr : !qo_indptr, %kv_indptr : !kv_indptr, %k_indices : !k_indices, %v_indices : !v_indices, %out : !out, %max_seq_len : !max_seq_len) -> !result {{
%max_seq_len_i32 = tensor.extract %max_seq_len[] : tensor<i32>
%result = func.call @{wave_kernel_fn_name}(%q_extend, %k_extend, %v_extend, %k_buffer, %v_buffer, %qo_indptr, %kv_indptr, %kv_indices, %out, %max_seq_len_i32) : (!q_extend, !k_extend, !v_extend, !k_buffer, !v_buffer, !qo_indptr, !kv_indptr, !kv_indices, !out, i32) -> !result
%result = func.call @{wave_kernel_fn_name}(%q_extend, %k_extend, %v_extend, %kv_cache_buffer_1, %kv_cache_buffer_2, %qo_indptr, %kv_indptr, %k_indices, %v_indices, %out, %max_seq_len_i32) : (!q_extend, !k_extend, !v_extend, !kv_cache_buffer_1, !kv_cache_buffer_2, !qo_indptr, !kv_indptr, !k_indices, !v_indices, !out, i32) -> !result
util.return %result : !result
}}
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def get_extend_attention_kernel(
q_shape: tuple[int],
k_shape: tuple[int],
v_shape: tuple[int],
k_cache_shape: tuple[int],
v_cache_shape: tuple[int],
kv_cache_1_shape: tuple[int],
kv_cache_2_shape: tuple[int],
o_shape: tuple[int],
input_dtype: torch.dtype = torch.float16,
output_dtype: torch.dtype = torch.float32,
Expand Down Expand Up @@ -152,20 +152,22 @@ def get_extend_attention_kernel(
k_layout = tkl.MemoryLayout(shape=set_dynamic_dim(k_shape))
v_layout = tkl.MemoryLayout(shape=set_dynamic_dim(v_shape))
o_layout = tkl.MemoryLayout(shape=set_dynamic_dim(o_shape))
k_cache_layout = tkl.MemoryLayout(shape=set_dynamic_dim(k_cache_shape))
v_cache_layout = tkl.MemoryLayout(shape=set_dynamic_dim(v_cache_shape))
kv_cache_1_layout = tkl.MemoryLayout(shape=set_dynamic_dim(kv_cache_1_shape))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same with layout.

kv_cache_2_layout = tkl.MemoryLayout(shape=set_dynamic_dim(kv_cache_2_shape))
num_seqs_layout = tkl.MemoryLayout(shape=[None])
kv_indices_layout = tkl.MemoryLayout(shape=[None])
k_indices_layout = tkl.MemoryLayout(shape=[None])
v_indices_layout = tkl.MemoryLayout(shape=[None])

def extend_attention_core(
q,
k,
v,
k_cache,
v_cache,
kv_cache_1,
kv_cache_2,
qo_indptr,
kv_indptr,
kv_indices,
k_indices,
v_indices,
custom_mask,
mask_offsets,
c,
Expand Down Expand Up @@ -212,19 +214,19 @@ def first_loop(
target=(h, n_q, d_q),
)
block_indices_v = tkw.read(
kv_indices,
v_indices,
elements_per_thread=LOAD_ELEMS_PER_THREAD_PV,
source=(n_kv + KV_START_IDX,),
target=(n_kv,),
)
block_indices_k = tkw.read(
kv_indices,
k_indices,
elements_per_thread=1,
source=(n_kv + KV_START_IDX,),
target=(n_kv,),
)
k_reg = tkw.read(
k_cache,
kv_cache_1,
elements_per_thread=LOAD_ELEMS_PER_THREAD_QK,
source=(block_indices_k, h_kv // head_ratio, d_q),
target=(h_kv, n_kv, d_q),
Expand Down Expand Up @@ -265,7 +267,7 @@ def first_loop(
d_j = tkw.sum(e_delta, e_init, dim=N_KV)
imm_f16 = tkw.cast(e_delta, wave_input_dtype)
v_reg = tkw.read(
v_cache,
kv_cache_2,
elements_per_thread=LOAD_ELEMS_PER_THREAD_PV,
source=(block_indices_v, h_kv // head_ratio, d_kv),
target=(h_kv, d_kv, n_kv),
Expand Down Expand Up @@ -370,27 +372,29 @@ def extend_attention(
q: tkl.Memory[N_Q, H, D_Q, GLOBAL_ADDRESS_SPACE, wave_input_dtype, q_layout],
k: tkl.Memory[N_KV, H_KV, D_Q, ADDRESS_SPACE, wave_input_dtype, k_layout],
v: tkl.Memory[N_KV, H_KV, D_KV, ADDRESS_SPACE, wave_input_dtype, v_layout],
k_cache: tkl.Memory[
N_KV, H_KV, D_Q, ADDRESS_SPACE, wave_input_dtype, k_cache_layout
kv_cache_1: tkl.Memory[
N_KV, H_KV, D_Q, ADDRESS_SPACE, wave_input_dtype, kv_cache_1_layout
],
v_cache: tkl.Memory[
N_KV, H_KV, D_KV, ADDRESS_SPACE, wave_input_dtype, v_cache_layout
kv_cache_2: tkl.Memory[
N_KV, H_KV, D_KV, ADDRESS_SPACE, wave_input_dtype, kv_cache_2_layout
],
qo_indptr: tkl.Memory[S, GLOBAL_ADDRESS_SPACE, tkl.i32, num_seqs_layout],
kv_indptr: tkl.Memory[S, GLOBAL_ADDRESS_SPACE, tkl.i32, num_seqs_layout],
kv_indices: tkl.Memory[N_KV, GLOBAL_ADDRESS_SPACE, tkl.i32, kv_indices_layout],
k_indices: tkl.Memory[N_KV, GLOBAL_ADDRESS_SPACE, tkl.i32, k_indices_layout],
v_indices: tkl.Memory[N_KV, GLOBAL_ADDRESS_SPACE, tkl.i32, v_indices_layout],
MAX_EXTEND_SEQ_LEN: tkl.SymbolBind[tkl.i32],
c: tkl.Memory[N_Q, H, D_KV, GLOBAL_ADDRESS_SPACE, wave_output_dtype, o_layout],
):
extend_attention_core(
q,
k,
v,
k_cache,
v_cache,
kv_cache_1,
kv_cache_2,
qo_indptr,
kv_indptr,
kv_indices,
k_indices,
v_indices,
None,
None,
c,
Expand All @@ -401,15 +405,16 @@ def extend_attention_custom_mask(
q: tkl.Memory[N_Q, H, D_Q, GLOBAL_ADDRESS_SPACE, wave_input_dtype, q_layout],
k: tkl.Memory[N_KV, H_KV, D_Q, ADDRESS_SPACE, wave_input_dtype, k_layout],
v: tkl.Memory[N_KV, H_KV, D_KV, ADDRESS_SPACE, wave_input_dtype, v_layout],
k_cache: tkl.Memory[
N_KV, H_KV, D_Q, ADDRESS_SPACE, wave_input_dtype, k_cache_layout
kv_cache_1: tkl.Memory[
N_KV, H_KV, D_Q, ADDRESS_SPACE, wave_input_dtype, kv_cache_1_layout
],
v_cache: tkl.Memory[
N_KV, H_KV, D_KV, ADDRESS_SPACE, wave_input_dtype, v_cache_layout
kv_cache_2: tkl.Memory[
N_KV, H_KV, D_KV, ADDRESS_SPACE, wave_input_dtype, kv_cache_2_layout
],
qo_indptr: tkl.Memory[S, GLOBAL_ADDRESS_SPACE, tkl.i32, num_seqs_layout],
kv_indptr: tkl.Memory[S, GLOBAL_ADDRESS_SPACE, tkl.i32, num_seqs_layout],
kv_indices: tkl.Memory[N_KV, GLOBAL_ADDRESS_SPACE, tkl.i32, kv_indices_layout],
k_indices: tkl.Memory[N_KV, GLOBAL_ADDRESS_SPACE, tkl.i32, k_indices_layout],
v_indices: tkl.Memory[N_KV, GLOBAL_ADDRESS_SPACE, tkl.i32, v_indices_layout],
custom_mask: tkl.Memory[
MASK_LEN, GLOBAL_ADDRESS_SPACE, tkl.i8, num_seqs_layout
],
Expand All @@ -421,11 +426,12 @@ def extend_attention_custom_mask(
q,
k,
v,
k_cache,
v_cache,
kv_cache_1,
kv_cache_2,
qo_indptr,
kv_indptr,
kv_indices,
k_indices,
v_indices,
custom_mask,
mask_offsets,
c,
Expand Down
57 changes: 41 additions & 16 deletions sharktank/sharktank/kernels/wave/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
device_full,
)
from enum import Enum
from typing import List
from sharktank.types.tensors import QuantizedTensor


class ScoreMod(Enum):
Expand Down Expand Up @@ -100,18 +102,22 @@ def create_extend_attention_inputs(

kv_indptr = device_zeros((B + 1,), dtype=torch.int32)
kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)
kv_indices = device_zeros((b_seq_len_prefix.sum().item(),), dtype=torch.int32)
k_indices = device_zeros((b_seq_len_prefix.sum().item(),), dtype=torch.int32)
v_indices = device_zeros((b_seq_len_prefix.sum().item(),), dtype=torch.int32)

for i in range(B):
kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = torch.arange(
k_indices[kv_indptr[i] : kv_indptr[i + 1]] = torch.arange(
b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i]
)
v_indices[kv_indptr[i] : kv_indptr[i + 1]] = torch.arange(
b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i]
)
total_token_num = torch.sum(b_seq_len).item()
extend_token_num = torch.sum(b_seq_len_extend).item()
k_buffer = device_empty((total_token_num, H_KV, D), dtype=dtype).normal_(
kv_buffer_1 = device_empty((total_token_num, H_KV, D), dtype=dtype).normal_(
mean=0.1, std=0.2
)
v_buffer = device_empty((total_token_num, H_KV, D), dtype=dtype).normal_(
kv_buffer_2 = device_empty((total_token_num, H_KV, D), dtype=dtype).normal_(
mean=0.1, std=0.2
)

Expand All @@ -123,10 +129,10 @@ def create_extend_attention_inputs(
extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
extend_start = b_start_loc_extend[i]
extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
k_extend[extend_start:extend_end] = k_buffer[
k_extend[extend_start:extend_end] = kv_buffer_1[
extend_start_in_buffer:extend_end_in_buffer
]
v_extend[extend_start:extend_end] = v_buffer[
v_extend[extend_start:extend_end] = kv_buffer_2[
extend_start_in_buffer:extend_end_in_buffer
]
q_extend[extend_start:extend_end] = device_empty(
Expand Down Expand Up @@ -176,13 +182,14 @@ def create_extend_attention_inputs(
q_extend,
k_extend,
v_extend,
k_buffer,
v_buffer,
kv_buffer_1,
kv_buffer_2,
b_req_idx,
b_seq_len,
qo_indptr,
kv_indptr,
kv_indices,
k_indices,
v_indices,
custom_mask,
mask_offsets,
b_start_loc,
Expand Down Expand Up @@ -311,10 +318,28 @@ def ref_extend_attn(
return o_extend


def create_causal_mask(seq_len: int, dtype: torch.dtype, device: str):
# Create a simple attention mask with shape [1, 1, seq_len, seq_len]
# This broadcasts across all batches and heads
mask = torch.triu(torch.ones(seq_len, seq_len) * float("-inf"), diagonal=1)
mask = mask.unsqueeze(0).unsqueeze(0)
mask = mask.to(dtype).to(device=device)
return mask
def create_kv_indices(
page_ids: torch.Tensor,
transformer_block_count: int,
transformer_block_index: int,
block_seq_stride: int,
cache_partitions: List[torch.Tensor | QuantizedTensor],
dtype: torch.dtype,
device: str,
):
all_indices = []

for cache_partition_id, cache_partition in enumerate(cache_partitions):
indices = page_ids
indices = indices * transformer_block_count + transformer_block_index
indices = indices * len(cache_partitions) + cache_partition_id
indices = indices[:, :, None]
indices = (
indices * block_seq_stride
+ torch.arange(block_seq_stride, dtype=dtype, device=device)[None, None, :]
)
indices = indices.flatten(1, 2).to(dtype)
all_indices.append(indices)

k_indices, v_indices = all_indices
return k_indices, v_indices
Loading
Loading