[neuron] add reshape_and_cache #14391
Conversation
Signed-off-by: Liangfu Chen <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
lingfanyu left a comment:
LGTM! Thanks!
```python
    Returns:
        None: Updates the key_cache and value_cache tensors in-place
    """
```
I would add an out-of-bounds check for slot_mapping, so that it is < num_blocks * block_size.
Not sure which of the two cases you mean:
1/ a bound check on slot_mapping's shape
2/ a bound check on slot_mapping's values
For 1, I think slot_mapping's shape (aka num_batched_tokens) could be greater than num_blocks * block_size (e.g. with DMA skipping), although that would be a rare case.
For 2, the values are checked at execution time, not controlled by the kernel/compilation.
Yeah, I meant the values. What would happen if some value of slot_mapping went beyond num_blocks * block_size (how do we protect against it?), or could that happen at all?
If that's the case, there must be a bug in the scheduler, which allocates the slots. If a value exceeds the boundary, the Neuron runtime would raise an out-of-bounds (OOB) error.
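Since out-of-range slot values only surface as a runtime OOB error, a cheap host-side guard could be run before launching the kernel. A minimal sketch (the helper name and signature are hypothetical, not part of this PR):

```python
import torch

def check_slot_mapping(slot_mapping: torch.Tensor,
                       num_blocks: int,
                       block_size: int) -> None:
    # Hypothetical helper: fail fast with a clear message instead of
    # relying on the Neuron runtime's out-of-bounds error at execution time.
    max_slot = num_blocks * block_size
    if bool((slot_mapping < 0).any()) or bool((slot_mapping >= max_slot).any()):
        raise ValueError(
            f"slot_mapping values must be in [0, {max_slot}), got "
            f"min={int(slot_mapping.min())}, max={int(slot_mapping.max())}")
```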
```python
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
```
I would add a check to make sure they are on the same device.
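Such a device-consistency check could look roughly like this (helper name hypothetical; a sketch, not the PR's code):

```python
import torch

def check_same_device(*tensors: torch.Tensor) -> None:
    # Hypothetical helper: all kernel inputs (key, value, caches, slot_mapping)
    # must live on the same device before the cache update runs.
    devices = {t.device for t in tensors}
    if len(devices) > 1:
        raise ValueError(f"expected all tensors on one device, got {devices}")
```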
aarondou left a comment:
Approved, with some non-blocking questions.
```python
        value (torch.Tensor): Value tensor with shape
            (num_tokens, n_kv_head, d_head)
        key_cache (torch.Tensor): Key cache tensor with shape
            (num_blocks, n_kv_head, block_size, d_head)
```
Non-blocking question: why this layout rather than (num_blocks, block_size, n_kv_head, d_head)?
The KV cache layout was changed from (num_blocks, block_size, n_kv_head, d_head) to (num_blocks, n_kv_head, block_size, d_head) in #13245, in order to speed up KV cache loading and eliminate an unnecessary transpose.
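A small illustration of why the new layout can eliminate a transpose, under the assumption that attention consumes per-head (block_size, d_head) tiles: in the old layout such a tile is strided in memory, while in the new layout it is already contiguous.

```python
import torch

num_blocks, block_size, n_kv_head, d_head = 2, 4, 3, 8

# Old layout: (num_blocks, block_size, n_kv_head, d_head).
# A per-head tile interleaves across heads, so it is strided.
old = torch.randn(num_blocks, block_size, n_kv_head, d_head)
old_tile = old[0, :, 0, :]              # (block_size, d_head), strided view

# New layout: (num_blocks, n_kv_head, block_size, d_head).
# The same per-head tile is one contiguous run of memory.
new = old.permute(0, 2, 1, 3).contiguous()
new_tile = new[0, 0]                    # (block_size, d_head), contiguous
```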
Signed-off-by: Louis Ulmer <[email protected]>
Add `reshape_and_cache` function for Neuron KV cache updates. Implements a helper function to write key-value pairs into block-based KV cache tensors. Handles the layout mismatch between the input tensors, shaped (num_tokens, n_kv_head, d_head), and the cache tensors, shaped (num_blocks, n_kv_head, block_size, d_head).
Uses block index calculations and torch.index_put_ to efficiently map and write the inputs into the correct cache positions. Optimized for Neuron's memory layout requirements.
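The block index calculation plus torch.index_put_ described above can be sketched roughly as follows. This is my reading of the approach, not the PR's actual implementation; advanced-index assignment dispatches to index_put_ under the hood:

```python
import torch

def reshape_and_cache(
    key: torch.Tensor,          # (num_tokens, n_kv_head, d_head)
    value: torch.Tensor,        # (num_tokens, n_kv_head, d_head)
    key_cache: torch.Tensor,    # (num_blocks, n_kv_head, block_size, d_head)
    value_cache: torch.Tensor,  # (num_blocks, n_kv_head, block_size, d_head)
    slot_mapping: torch.Tensor, # (num_tokens,), flat slot ids
) -> None:
    """Sketch: write key/value into the block-based caches in-place."""
    block_size = key_cache.shape[2]
    # Decompose each flat slot id into a block index and an in-block offset.
    block_idx = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_offset = slot_mapping % block_size
    # The two integer index tensors are separated by a slice, so the indexed
    # result has shape (num_tokens, n_kv_head, d_head), matching the inputs.
    key_cache[block_idx, :, block_offset] = key
    value_cache[block_idx, :, block_offset] = value
```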