-
Notifications
You must be signed in to change notification settings - Fork 69
Modify sharktank KVCache and local Wave kernel copy for extend attention #2534
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
aviator19941
wants to merge
12
commits into
main
Choose a base branch
from
extend_attention_kv_cache
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 8 commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
3621f8b
Change layout of kv cache for extend_attention
aviator19941 ec2ea99
Add kv cache test
aviator19941 b93e1e9
Test with kv cache
aviator19941 816a381
Modify wave kernel to read from sharktank kv cache using k_indices/v_…
aviator19941 6835f41
Update kernel and customop args to use new buffers/k_indices/v_indices
aviator19941 24fedd3
Finish test for single request, 2 chunks
aviator19941 ce9e0cb
Fix test to use k_indices and v_indices
aviator19941 64337b6
Merge branch 'main' into extend_attention_kv_cache
aviator19941 a38a5d5
Fix comments
aviator19941 d46e8fc
Uncomment is_mi300x
aviator19941 28ff254
Merge branch 'main' into extend_attention_kv_cache
aviator19941 e4a8b23
Don't use ambiguous kv_cache_# params
aviator19941 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,8 +31,8 @@ def get_extend_attention_kernel( | |
| q_shape: tuple[int], | ||
| k_shape: tuple[int], | ||
| v_shape: tuple[int], | ||
| k_cache_shape: tuple[int], | ||
| v_cache_shape: tuple[int], | ||
| kv_cache_1_shape: tuple[int], | ||
| kv_cache_2_shape: tuple[int], | ||
| o_shape: tuple[int], | ||
| input_dtype: torch.dtype = torch.float16, | ||
| output_dtype: torch.dtype = torch.float32, | ||
|
|
@@ -152,20 +152,22 @@ def get_extend_attention_kernel( | |
| k_layout = tkl.MemoryLayout(shape=set_dynamic_dim(k_shape)) | ||
| v_layout = tkl.MemoryLayout(shape=set_dynamic_dim(v_shape)) | ||
| o_layout = tkl.MemoryLayout(shape=set_dynamic_dim(o_shape)) | ||
| k_cache_layout = tkl.MemoryLayout(shape=set_dynamic_dim(k_cache_shape)) | ||
| v_cache_layout = tkl.MemoryLayout(shape=set_dynamic_dim(v_cache_shape)) | ||
| kv_cache_1_layout = tkl.MemoryLayout(shape=set_dynamic_dim(kv_cache_1_shape)) | ||
|
||
| kv_cache_2_layout = tkl.MemoryLayout(shape=set_dynamic_dim(kv_cache_2_shape)) | ||
| num_seqs_layout = tkl.MemoryLayout(shape=[None]) | ||
| kv_indices_layout = tkl.MemoryLayout(shape=[None]) | ||
| k_indices_layout = tkl.MemoryLayout(shape=[None]) | ||
| v_indices_layout = tkl.MemoryLayout(shape=[None]) | ||
|
|
||
| def extend_attention_core( | ||
| q, | ||
| k, | ||
| v, | ||
| k_cache, | ||
| v_cache, | ||
| kv_cache_1, | ||
| kv_cache_2, | ||
| qo_indptr, | ||
| kv_indptr, | ||
| kv_indices, | ||
| k_indices, | ||
| v_indices, | ||
| custom_mask, | ||
| mask_offsets, | ||
| c, | ||
|
|
@@ -212,19 +214,19 @@ def first_loop( | |
| target=(h, n_q, d_q), | ||
| ) | ||
| block_indices_v = tkw.read( | ||
| kv_indices, | ||
| v_indices, | ||
| elements_per_thread=LOAD_ELEMS_PER_THREAD_PV, | ||
| source=(n_kv + KV_START_IDX,), | ||
| target=(n_kv,), | ||
| ) | ||
| block_indices_k = tkw.read( | ||
| kv_indices, | ||
| k_indices, | ||
| elements_per_thread=1, | ||
| source=(n_kv + KV_START_IDX,), | ||
| target=(n_kv,), | ||
| ) | ||
| k_reg = tkw.read( | ||
| k_cache, | ||
| kv_cache_1, | ||
| elements_per_thread=LOAD_ELEMS_PER_THREAD_QK, | ||
| source=(block_indices_k, h_kv // head_ratio, d_q), | ||
| target=(h_kv, n_kv, d_q), | ||
|
|
@@ -265,7 +267,7 @@ def first_loop( | |
| d_j = tkw.sum(e_delta, e_init, dim=N_KV) | ||
| imm_f16 = tkw.cast(e_delta, wave_input_dtype) | ||
| v_reg = tkw.read( | ||
| v_cache, | ||
| kv_cache_2, | ||
| elements_per_thread=LOAD_ELEMS_PER_THREAD_PV, | ||
| source=(block_indices_v, h_kv // head_ratio, d_kv), | ||
| target=(h_kv, d_kv, n_kv), | ||
|
|
@@ -370,27 +372,29 @@ def extend_attention( | |
| q: tkl.Memory[N_Q, H, D_Q, GLOBAL_ADDRESS_SPACE, wave_input_dtype, q_layout], | ||
| k: tkl.Memory[N_KV, H_KV, D_Q, ADDRESS_SPACE, wave_input_dtype, k_layout], | ||
| v: tkl.Memory[N_KV, H_KV, D_KV, ADDRESS_SPACE, wave_input_dtype, v_layout], | ||
| k_cache: tkl.Memory[ | ||
| N_KV, H_KV, D_Q, ADDRESS_SPACE, wave_input_dtype, k_cache_layout | ||
| kv_cache_1: tkl.Memory[ | ||
| N_KV, H_KV, D_Q, ADDRESS_SPACE, wave_input_dtype, kv_cache_1_layout | ||
| ], | ||
| v_cache: tkl.Memory[ | ||
| N_KV, H_KV, D_KV, ADDRESS_SPACE, wave_input_dtype, v_cache_layout | ||
| kv_cache_2: tkl.Memory[ | ||
| N_KV, H_KV, D_KV, ADDRESS_SPACE, wave_input_dtype, kv_cache_2_layout | ||
| ], | ||
| qo_indptr: tkl.Memory[S, GLOBAL_ADDRESS_SPACE, tkl.i32, num_seqs_layout], | ||
| kv_indptr: tkl.Memory[S, GLOBAL_ADDRESS_SPACE, tkl.i32, num_seqs_layout], | ||
| kv_indices: tkl.Memory[N_KV, GLOBAL_ADDRESS_SPACE, tkl.i32, kv_indices_layout], | ||
| k_indices: tkl.Memory[N_KV, GLOBAL_ADDRESS_SPACE, tkl.i32, k_indices_layout], | ||
| v_indices: tkl.Memory[N_KV, GLOBAL_ADDRESS_SPACE, tkl.i32, v_indices_layout], | ||
| MAX_EXTEND_SEQ_LEN: tkl.SymbolBind[tkl.i32], | ||
| c: tkl.Memory[N_Q, H, D_KV, GLOBAL_ADDRESS_SPACE, wave_output_dtype, o_layout], | ||
| ): | ||
| extend_attention_core( | ||
| q, | ||
| k, | ||
| v, | ||
| k_cache, | ||
| v_cache, | ||
| kv_cache_1, | ||
| kv_cache_2, | ||
| qo_indptr, | ||
| kv_indptr, | ||
| kv_indices, | ||
| k_indices, | ||
| v_indices, | ||
| None, | ||
| None, | ||
| c, | ||
|
|
@@ -401,15 +405,16 @@ def extend_attention_custom_mask( | |
| q: tkl.Memory[N_Q, H, D_Q, GLOBAL_ADDRESS_SPACE, wave_input_dtype, q_layout], | ||
| k: tkl.Memory[N_KV, H_KV, D_Q, ADDRESS_SPACE, wave_input_dtype, k_layout], | ||
| v: tkl.Memory[N_KV, H_KV, D_KV, ADDRESS_SPACE, wave_input_dtype, v_layout], | ||
| k_cache: tkl.Memory[ | ||
| N_KV, H_KV, D_Q, ADDRESS_SPACE, wave_input_dtype, k_cache_layout | ||
| kv_cache_1: tkl.Memory[ | ||
| N_KV, H_KV, D_Q, ADDRESS_SPACE, wave_input_dtype, kv_cache_1_layout | ||
| ], | ||
| v_cache: tkl.Memory[ | ||
| N_KV, H_KV, D_KV, ADDRESS_SPACE, wave_input_dtype, v_cache_layout | ||
| kv_cache_2: tkl.Memory[ | ||
| N_KV, H_KV, D_KV, ADDRESS_SPACE, wave_input_dtype, kv_cache_2_layout | ||
| ], | ||
| qo_indptr: tkl.Memory[S, GLOBAL_ADDRESS_SPACE, tkl.i32, num_seqs_layout], | ||
| kv_indptr: tkl.Memory[S, GLOBAL_ADDRESS_SPACE, tkl.i32, num_seqs_layout], | ||
| kv_indices: tkl.Memory[N_KV, GLOBAL_ADDRESS_SPACE, tkl.i32, kv_indices_layout], | ||
| k_indices: tkl.Memory[N_KV, GLOBAL_ADDRESS_SPACE, tkl.i32, k_indices_layout], | ||
| v_indices: tkl.Memory[N_KV, GLOBAL_ADDRESS_SPACE, tkl.i32, v_indices_layout], | ||
| custom_mask: tkl.Memory[ | ||
| MASK_LEN, GLOBAL_ADDRESS_SPACE, tkl.i8, num_seqs_layout | ||
| ], | ||
|
|
@@ -421,11 +426,12 @@ def extend_attention_custom_mask( | |
| q, | ||
| k, | ||
| v, | ||
| k_cache, | ||
| v_cache, | ||
| kv_cache_1, | ||
| kv_cache_2, | ||
| qo_indptr, | ||
| kv_indptr, | ||
| kv_indices, | ||
| k_indices, | ||
| v_indices, | ||
| custom_mask, | ||
| mask_offsets, | ||
| c, | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't do kv_cache_#. This is ambiguous. It should be clear what these buffer / shape represents (e.g. the k or v component). Its especially important when invoking for wave.