
Conversation

@aviator19941 (Collaborator) commented Oct 16, 2025

This PR reorders block_seq_stride to come before num_attn_heads in the sharktank KVCache, so that the first 4 dimensions of the cache can be flattened to match the Wave KVCache layout. The Wave kernel also expects separate k_cache and v_cache buffers, which sharktank does not currently support, so the local sharktank copy of the Wave extend attention kernel is updated to take a single KVCache block with separate k_indices and v_indices, instead of separate k_cache and v_cache blocks with a single kv_indices. In the future, sharktank will need to split the KVCache into k_cache and v_cache. The PR also adds an extend attention test for a single request with 2 chunks that exercises the updated KVCache and kernel.

@Groverkss (Contributor) left a comment


Putting a block on this right now. I see some things that need to be resolved before landing, but I still need to do a careful review.

Signed-off-by: aviator19941 <[email protected]>
@github-actions (bot) commented Oct 18, 2025

Coverage report

Files with statements missing coverage (new statements):

  sharktank/sharktank/kernels/wave/utils.py: 103-115, 318, 330-345
  sharktank/sharktank/kernels/wave/templates/extend_attention_kernel.py: 155-161, 216-222
  sharktank/sharktank/layers/paged_attention.py: 159-200, 244, 316-319, 419-420
  sharktank/sharktank/ops/attention_impls.py: 257-259, 265-290
  sharktank/sharktank/ops/signatures.py: 969
  sharktank/tests/kernels/wave/extend_attention_test.py: 201, 242-308, 329-455

This report was generated by python-coverage-comment-action

v_extend_shape,
k_cache_shape,
v_cache_shape,
kv_cache_1_shape,

Don't use kv_cache_#. It's ambiguous; it should be clear what each buffer/shape represents (e.g. the k or the v component). That's especially important when invoking Wave.
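As a purely illustrative sketch of the naming being asked for (all sizes and variable names here are hypothetical, not the kernel's actual code):

```python
# Ambiguous: a reader cannot tell whether this is the k or v component.
kv_cache_1_shape = (1024, 16, 8, 64)

# Explicit: the component each shape describes is spelled out.
k_cache_shape = (1024, 16, 8, 64)  # key component of the paged cache
v_cache_shape = (1024, 16, 8, 64)  # value component of the paged cache

assert k_cache_shape == v_cache_shape == kv_cache_1_shape
```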

o_layout = tkl.MemoryLayout(shape=set_dynamic_dim(o_shape))
k_cache_layout = tkl.MemoryLayout(shape=set_dynamic_dim(k_cache_shape))
v_cache_layout = tkl.MemoryLayout(shape=set_dynamic_dim(v_cache_shape))
kv_cache_1_layout = tkl.MemoryLayout(shape=set_dynamic_dim(kv_cache_1_shape))

Same issue with the layout names.

head_dim,
attn_dtype,
device,
):

When you are writing this much boilerplate for a test, it's a sign that helpers / cleanup should be included. The transposition, data setup, etc. inherently mean it's too complex to clearly describe the test.
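A hedged sketch of the kind of helper being suggested: factor the repeated q/k/v construction and transposition out of the test body. The name make_qkv and its signature are hypothetical, not existing sharktank code:

```python
import torch

def make_qkv(seq_len, num_heads, head_dim, dtype=torch.float32, device="cpu"):
    """Build query/key/value tensors in [num_heads, seq_len, head_dim] layout."""
    def one():
        t = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device=device)
        # Centralize the transposition so each test reads declaratively.
        return t.transpose(0, 1).contiguous()
    return one(), one(), one()

q, k, v = make_qkv(seq_len=32, num_heads=8, head_dim=64)
assert q.shape == k.shape == v.shape == (8, 32, 64)
```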

extend_attention=True,
)

cache = PagedGQAttention(

This is inverted: the subcomponent will be invoking the kernel, so the test should be using it externally.


# Loop through chunks simulating progressive prefill
all_outputs = []
for chunk_id in range(num_chunks):

This section again calls for a utility. We should try to clearly describe the setup / components rather than rely on a large monolithic setup for a test.
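One way the chunk loop could be factored is a small utility that computes each chunk's (start, end) token range, keeping the test body declarative. chunk_token_ranges is a hypothetical helper name, not existing code:

```python
def chunk_token_ranges(seq_len, num_chunks):
    """Split [0, seq_len) into num_chunks contiguous token ranges."""
    chunk = seq_len // num_chunks
    for i in range(num_chunks):
        start = i * chunk
        # The last chunk absorbs any remainder when seq_len isn't divisible.
        end = seq_len if i == num_chunks - 1 else start + chunk
        yield start, end

ranges = list(chunk_token_ranges(seq_len=64, num_chunks=2))
assert ranges == [(0, 32), (32, 64)]
```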

write_page_ids = page_ids[
:, start // block_seq_stride : end // block_seq_stride
]
cache_partitions = [k.cpu(), v.cpu()]

Is .cpu() required?

cache_partitions = [k.cpu(), v.cpu()]

# Write chunk to KV cache
cache.write(

This write updates the cache allocation with cache_partitions. Unless we are planning to verify the cache updates, this step is unnecessary.

)

# Combine outputs for completeness check
combined_output = torch.cat(all_outputs, dim=2)

Rename this to extend_attention_output.

target_len=seq_len,
attention_dtype=attn_dtype,
).to(device)
sdpa_ref = ops.scaled_dot_product_attention(q=q_sdpa, k=k_sdpa, v=v_sdpa, a=a)

The sdpa_ref and extend_attention_output calculations can be extracted into separate functions under sharktank/tests/utils/.
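A sketch of the proposed extraction: a plain-PyTorch SDPA reference that could live under sharktank/tests/utils/. sdpa_reference is a hypothetical name; the test currently calls ops.scaled_dot_product_attention inline:

```python
import torch

def sdpa_reference(q, k, v, mask=None):
    """Naive reference attention for checking kernel outputs."""
    scale = q.shape[-1] ** -0.5
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    if mask is not None:
        scores = scores + mask
    return torch.matmul(torch.softmax(scores, dim=-1), v)

q, k, v = (torch.randn(1, 8, 32, 64) for _ in range(3))
out = sdpa_reference(q, k, v)
# Should agree with PyTorch's built-in SDPA up to float32 rounding.
expected = torch.nn.functional.scaled_dot_product_attention(q, k, v)
assert torch.allclose(out, expected, atol=1e-5)
```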
