Modify sharktank KVCache and local Wave kernel copy for extend attention #2534
base: main
Conversation
Signed-off-by: aviator19941 <[email protected]>
Groverkss
left a comment
Putting a block on this for now. I see some things that need to be resolved before landing, but I need to do a more careful review first.
v_extend_shape,
k_cache_shape,
v_cache_shape,
kv_cache_1_shape,
Don't use kv_cache_#. This is ambiguous. It should be clear what each buffer / shape represents (e.g. the k or v component). It's especially important when invoking the Wave kernel.
o_layout = tkl.MemoryLayout(shape=set_dynamic_dim(o_shape))
k_cache_layout = tkl.MemoryLayout(shape=set_dynamic_dim(k_cache_shape))
v_cache_layout = tkl.MemoryLayout(shape=set_dynamic_dim(v_cache_shape))
kv_cache_1_layout = tkl.MemoryLayout(shape=set_dynamic_dim(kv_cache_1_shape))
Same with layout.
head_dim,
attn_dtype,
device,
):
When you are writing this much boilerplate for a test, it's a sign that helpers / cleanup should be introduced. The transposition, data setup, etc. inherently mean the test is too complex to describe clearly.
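One way to act on this review comment is a small setup helper that owns the randomized tensor creation. The sketch below is illustrative only: `make_qkv` and its parameter names are hypothetical, not part of sharktank.

```python
import torch

def make_qkv(batch, seq_len, heads, head_dim, dtype=torch.float32, seed=0):
    # Hypothetical helper (name assumed): centralizes the randomized tensor
    # setup so each test body only states what it actually exercises.
    gen = torch.Generator().manual_seed(seed)
    shape = (batch, seq_len, heads, head_dim)
    q = torch.randn(shape, generator=gen).to(dtype)
    k = torch.randn(shape, generator=gen).to(dtype)
    v = torch.randn(shape, generator=gen).to(dtype)
    return q, k, v

# Usage: the test asks for exactly the shapes it needs, reproducibly.
q, k, v = make_qkv(batch=1, seq_len=16, heads=8, head_dim=64)
```

Seeding through an explicit `torch.Generator` keeps the fixture deterministic without touching global RNG state.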
extend_attention=True,
)

cache = PagedGQAttention(
This is inverted: the subcomponent will be invoking the kernel, so the test should be using it externally.
# Loop through chunks simulating progressive prefill
all_outputs = []
for chunk_id in range(num_chunks):
This section again screams for a utility. We should try to clearly describe the setup / components and not rely on a large monolithic setup for a test.
write_page_ids = page_ids[
    :, start // block_seq_stride : end // block_seq_stride
]
cache_partitions = [k.cpu(), v.cpu()]
Is .cpu() required?
cache_partitions = [k.cpu(), v.cpu()]

# Write chunk to KV cache
cache.write(
This write updates the cache allocation with cache_partitions. Unless we are planning to verify the cache updates, this step is unnecessary.
)

# Combine outputs for completeness check
combined_output = torch.cat(all_outputs, dim=2)
Rename as extend_attention_output
target_len=seq_len,
attention_dtype=attn_dtype,
).to(device)
sdpa_ref = ops.scaled_dot_product_attention(q=q_sdpa, k=k_sdpa, v=v_sdpa, a=a)
The sdpa_ref and extend_attention_output calculations can be extracted into separate functions under sharktank/tests/utils/.
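As a sketch of the suggested extraction, the reference computation could become a shared helper. The name `sdpa_reference` and its placement are assumptions for illustration, not sharktank API; the body simply wraps PyTorch's built-in SDPA.

```python
import math
import torch

def sdpa_reference(q, k, v, mask=None):
    # Hypothetical helper (name and location assumed) that could live under
    # sharktank/tests/utils/ so every test builds its reference the same way
    # instead of inlining the call.
    return torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=mask
    )

# Usage sketch with tiny tensors laid out as [batch, heads, seq, head_dim].
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)
ref = sdpa_reference(q, k, v)
```

A shared helper also gives one place to pin down dtype/tolerance conventions for comparisons against kernel output.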
This PR modifies the sharktank KVCache by reordering block_seq_stride to before num_attn_heads, which allows flattening the first 4 dimensions of the sharktank KVCache to match the Wave KVCache layout. The Wave kernel also expects separate k_cache and v_cache buffers, which sharktank does not currently support, so the local sharktank copy of the Wave extend attention kernel is updated to use a single KVCache block with separate k_indices and v_indices, instead of separate k_cache and v_cache blocks with a single kv_indices. In the future, sharktank will need to split the KVCache into k_cache and v_cache. The PR also adds an extend attention test for a single request with 2 chunks that exercises the updated KVCache and updated kernel.
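The dimension reordering described above can be illustrated with a toy tensor. All dimension names and sizes below are assumptions for the sketch (the real sharktank cache shape may differ); the point is only that once block_seq_stride precedes num_attn_heads, the leading dimensions collapse into the single flat slot index a Wave-style layout expects.

```python
import torch

# Assumed sizes, for illustration only.
num_pages, num_layers, num_partitions = 4, 2, 2
num_attn_heads, block_seq_stride, head_dim = 8, 16, 64

# Hypothetical "old" ordering: heads before block_seq_stride.
cache = torch.zeros(
    num_pages, num_layers, num_partitions,
    num_attn_heads, block_seq_stride, head_dim,
)

# Reorder so block_seq_stride precedes num_attn_heads ...
reordered = cache.permute(0, 1, 2, 4, 3, 5)

# ... then the first 4 dimensions flatten into one slot index,
# leaving [num_slots, num_attn_heads, head_dim].
flat = reordered.reshape(-1, num_attn_heads, head_dim)
```

With heads trailing the slot index, per-token head data stays contiguous per slot, which is what makes the flattened view line up with a [slots, heads, head_dim] cache layout.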