
Conversation

@liangfu
Contributor
@liangfu liangfu commented Mar 7, 2025

Add reshape_and_cache function for Neuron KV cache updates

Implements a helper function to write key-value pairs into block-based KV cache tensors. Handles the layout mismatch between:

  • Input tensors: (num_tokens, n_kv_head, d_head)
  • Cache tensors: (num_blocks, n_kv_head, block_size, d_head)

Uses block-index calculations and torch.index_put_ to efficiently map and write the inputs into the correct cache positions, matching Neuron's memory-layout requirements.
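The mapping described above can be sketched in plain PyTorch (a minimal illustration of the approach, not the PR's exact code; the key steps are splitting each flat slot index into a block index plus an in-block offset, then writing via advanced indexing, which dispatches to index_put_ under the hood):

```python
import torch

def reshape_and_cache(
    key: torch.Tensor,           # (num_tokens, n_kv_head, d_head)
    value: torch.Tensor,         # (num_tokens, n_kv_head, d_head)
    key_cache: torch.Tensor,     # (num_blocks, n_kv_head, block_size, d_head)
    value_cache: torch.Tensor,   # (num_blocks, n_kv_head, block_size, d_head)
    slot_mapping: torch.Tensor,  # (num_tokens,) flat slot index per token
) -> None:
    block_size = key_cache.size(2)
    # Split each flat slot index into a block index and an offset in the block.
    block_idx = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_off = slot_mapping % block_size
    # Advanced indexing with a slice over the head dim writes every head of
    # each token into its (block, offset) position in one in-place scatter.
    key_cache[block_idx, :, block_off] = key
    value_cache[block_idx, :, block_off] = value
```

Because the two integer index tensors are separated by a slice, PyTorch moves the token dimension to the front, so the left-hand side has shape (num_tokens, n_kv_head, d_head) and lines up with the inputs directly.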

@github-actions
github-actions bot commented Mar 7, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs will not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Contributor

@lingfanyu left a comment

LGTM! Thanks!

Returns:
None: Updates the key_cache and value_cache tensors in-place
"""

I would add an out-of-bounds check for slot_mapping, so that its values are < num_blocks * block_size.

Contributor Author

Not sure which of the two cases you mean:
1/ a bound check on the slot_mapping shape
2/ a bound check on the slot_mapping values

For 1, I think the slot_mapping shape (aka num_batched_tokens) could be greater than num_blocks * block_size (e.g. with DMA skipping), although that would be a rare case.
For 2, the values are checked at execution time, not controlled by the kernel/compilation.


Yeah, I meant the values. What will happen if some value of slot_mapping goes beyond num_blocks * block_size (how do we protect against it), or could that happen at all?

Contributor Author

@liangfu Mar 10, 2025

If that were the case, there would be a bug in the scheduler, which allocates the slots. If a value exceeds the boundary, the Neuron runtime will raise an out-of-bounds (OOB) error.
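For reference, a host-side guard of the kind discussed here could look like the following (an illustrative sketch; check_slot_mapping is a hypothetical helper, not part of this PR, and the Neuron runtime's own OOB error remains the backstop):

```python
import torch

def check_slot_mapping(slot_mapping: torch.Tensor,
                       num_blocks: int, block_size: int) -> None:
    # Hypothetical guard: validate slot values before the kernel writes.
    num_slots = num_blocks * block_size
    lo, hi = int(slot_mapping.min()), int(slot_mapping.max())
    if lo < 0 or hi >= num_slots:
        raise IndexError(
            f"slot_mapping values must lie in [0, {num_slots}), "
            f"got range [{lo}, {hi}]")
```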

Comment on lines +875 to +879
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,

I would add a check to make sure they are all on the same device.
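Such a check could be a one-liner along these lines (an illustrative sketch; check_same_device is a hypothetical helper, not part of this PR):

```python
import torch

def check_same_device(*tensors: torch.Tensor) -> None:
    # Hypothetical helper: reject mixed-device inputs early, before the
    # index_put_ write fails with a less readable runtime error.
    devices = {t.device for t in tensors}
    if len(devices) > 1:
        raise ValueError(f"expected all tensors on one device, got {devices}")
```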

@aarondou left a comment

Approved, with some non-blocking questions.

value (torch.Tensor): Value tensor with shape
(num_tokens, n_kv_head, d_head)
key_cache (torch.Tensor): Key cache tensor with shape
(num_blocks, n_kv_head, block_size, d_head)


Non-blocking question: why this layout rather than (num_blocks, block_size, n_kv_head, d_head)?

Contributor Author

The KV cache layout was changed from (num_blocks, block_size, n_kv_head, d_head) to (num_blocks, n_kv_head, block_size, d_head) in #13245, in order to speed up KV cache loading and eliminate an unnecessary transpose.
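The benefit can be seen in a small sketch (illustrative shapes, not from the PR): with the head-major block layout, slicing one block out of the cache already has the (n_kv_head, block_size, d_head) shape that attention consumes, and stays contiguous; the old layout needs a transpose at load time.

```python
import torch

num_blocks, n_kv_head, block_size, d_head = 4, 2, 8, 16

# New layout (after #13245): head-major within each block.
cache = torch.zeros(num_blocks, n_kv_head, block_size, d_head)
block = cache[0]                 # -> (n_kv_head, block_size, d_head)
assert block.shape == (n_kv_head, block_size, d_head)
assert block.is_contiguous()     # per-head rows are contiguous as stored

# Old layout: the same slice needs a transpose to reach that shape,
# and the result is no longer contiguous (a copy would be required).
old_cache = torch.zeros(num_blocks, block_size, n_kv_head, d_head)
old_block = old_cache[0].transpose(0, 1)
assert old_block.shape == (n_kv_head, block_size, d_head)
assert not old_block.is_contiguous()
```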

Returns:
None: Updates the key_cache and value_cache tensors in-place
"""


@simon-mo simon-mo merged commit c91b64f into vllm-project:main Mar 11, 2025
21 checks passed
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025