[neuron] add reshape_and_cache #14391
New test file (hunk `@@ -0,0 +1,83 @@`):

```python
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from vllm.attention.ops.nki_flash_attn import reshape_and_cache


@pytest.mark.parametrize(
    "num_tokens, n_kv_head, d_head, num_blocks, block_size",
    [
        # Small model configuration (e.g., GPT-2 small)
        (32, 12, 64, 4, 128),  # Typical sequence processing
        (1, 12, 64, 4, 128),  # Single token update
        (128, 12, 64, 4, 128),  # Longer sequence

        # Medium model configuration (e.g., GPT-2 medium)
        (64, 16, 96, 8, 256),  # Standard batch
        (256, 16, 96, 8, 256),  # Large batch

        # Large model configuration (e.g., GPT-3 style)
        (48, 32, 128, 16, 512),  # Typical processing window
        (512, 32, 128, 16, 512),  # Full context window

        # Edge cases and stress tests
        (1024, 8, 32, 32, 32),  # Many tokens, small heads
        (16, 64, 256, 4, 64),  # Few tokens, many heads
        (2048, 24, 128, 64, 128),  # Large-scale test

        # Minimal configurations for debugging
        (4, 2, 16, 2, 16),  # Tiny test case
        (1, 1, 8, 1, 8),  # Minimal possible
    ])
def test_reshape_and_cache(num_tokens, n_kv_head, d_head, num_blocks,
                           block_size):
    # Set random seed for reproducibility
    torch.manual_seed(42)

    # Create CPU tensors for the reference implementation
    key_cpu = torch.randn(num_tokens, n_kv_head, d_head) / torch.sqrt(
        torch.tensor(d_head))
    value_cpu = torch.randn(num_tokens, n_kv_head, d_head) / torch.sqrt(
        torch.tensor(d_head))
    key_cache_cpu = torch.zeros(num_blocks, n_kv_head, block_size, d_head)
    value_cache_cpu = torch.zeros(num_blocks, n_kv_head, block_size, d_head)
    slot_mapping_cpu = torch.randperm(num_blocks * block_size)[:num_tokens]

    # Run the reference implementation on CPU
    block_indices = torch.div(slot_mapping_cpu,
                              block_size,
                              rounding_mode="floor")
    block_offsets = slot_mapping_cpu % block_size

    for i in range(num_tokens):
        block_idx = block_indices[i]
        block_offset = block_offsets[i]
        key_cache_cpu[block_idx, :, block_offset, :] = key_cpu[i]
        value_cache_cpu[block_idx, :, block_offset, :] = value_cpu[i]

    # Create XLA device tensors
    device = torch.device('xla')
    key = key_cpu.to(device)
    value = value_cpu.to(device)
    key_cache = torch.zeros_like(key_cache_cpu, device=device)
    value_cache = torch.zeros_like(value_cache_cpu, device=device)
    slot_mapping = slot_mapping_cpu.to(device)

    # Run the vectorized implementation on the XLA device
    reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)

    # Move results back to CPU for comparison
    key_cache_result = key_cache.cpu()
    value_cache_result = value_cache.cpu()

    # Assert the results match
    torch.testing.assert_close(key_cache_result,
                               key_cache_cpu,
                               rtol=1e-5,
                               atol=1e-5)
    torch.testing.assert_close(value_cache_result,
                               value_cache_cpu,
                               rtol=1e-5,
                               atol=1e-5)
```
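The reference loop above splits each flat slot index into a block index and an in-block offset via floor division and modulo; for example, with `block_size = 128`, slot `300` lands in block `300 // 128 = 2` at offset `300 % 128 = 44`. A quick standalone illustration (values chosen here for demonstration, not taken from the test):

```python
import torch

block_size = 128
slot_mapping = torch.tensor([300, 5, 129])

block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
block_offsets = slot_mapping % block_size

print(block_indices.tolist())  # [2, 0, 1]
print(block_offsets.tolist())  # [44, 5, 1]
```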
`vllm/attention/ops/nki_flash_attn.py` (hunk `@@ -869,3 +869,46 @@ def flash_attn_varlen_nkifunc(`):

```python
    o = flash_paged_attention[1, n_kv_head](**kwargs)
    return o


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """
    Writes key-value pairs to the KV cache at specified positions.

    Args:
        key (torch.Tensor): Key tensor with shape
            (num_tokens, n_kv_head, d_head)
        value (torch.Tensor): Value tensor with shape
            (num_tokens, n_kv_head, d_head)
        key_cache (torch.Tensor): Key cache tensor with shape
            (num_blocks, n_kv_head, block_size, d_head)
```
> **Reviewer:** Non-blocking question: why this rather than …
>
> **Author:** The KV cache layout has been changed from …
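For context, under the `(num_blocks, n_kv_head, block_size, d_head)` layout described in the docstring, a token's cache entry sits at `[block_idx, :, block_offset, :]`. A minimal sketch with toy sizes (illustration only; the prior layout referenced in the truncated reply is not shown in this diff):

```python
import torch

# Toy sizes for illustration only
num_blocks, n_kv_head, block_size, d_head = 2, 3, 4, 8
key_cache = torch.zeros(num_blocks, n_kv_head, block_size, d_head)

# A token assigned flat slot 6 maps to block 1, offset 2
block_idx, block_offset = 6 // block_size, 6 % block_size

# One key slice is written across all KV heads at that position
key_cache[block_idx, :, block_offset, :] = torch.randn(n_kv_head, d_head)
```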
```python
        value_cache (torch.Tensor): Value cache tensor with shape
            (num_blocks, n_kv_head, block_size, d_head)
        slot_mapping (torch.Tensor): Mapping tensor indicating cache
            positions, with shape (num_tokens,)

    Returns:
        None: Updates the key_cache and value_cache tensors in-place
    """
```
> **Reviewer:** Would add an out-of-bounds check for …
>
> **Author:** Not sure which of the two cases you mean: for (1), I think the slot_mapping shape (aka num_batched_tokens) could be greater than …
>
> **Reviewer:** Yeah, I meant the value. What will happen if some value of …
>
> **Author:** If that's the case, there must be a bug in the scheduler, which allocates the slots. If a value exceeds the boundary, the Neuron runtime raises an out-of-bounds (OOB) error.
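For readers who want the guard the reviewer describes, a minimal sketch might look like the following (hypothetical helper, not part of the PR; as the author notes, the Neuron runtime already raises an OOB error, and reading tensor values back from an XLA device forces a host sync, which argues against checking on the hot path):

```python
import torch


def check_slot_mapping(slot_mapping: torch.Tensor,
                       key_cache: torch.Tensor) -> None:
    # Hypothetical out-of-bounds guard, suggested in review only.
    # Total number of slots = num_blocks * block_size.
    num_slots = key_cache.size(0) * key_cache.size(2)
    lo, hi = int(slot_mapping.min()), int(slot_mapping.max())
    if lo < 0 or hi >= num_slots:
        raise ValueError(
            f"slot_mapping values must lie in [0, {num_slots}), "
            f"got range [{lo}, {hi}]")
```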
```python
    block_size = key_cache.size(2)

    # Calculate indices with explicit floor division
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_offsets = slot_mapping % block_size

    # Update caches using index_put_
    key_cache.index_put_(
        (block_indices.unsqueeze(1),
         torch.arange(key_cache.size(1),
                      device=key.device), block_offsets.unsqueeze(1)), key)

    value_cache.index_put_(
        (block_indices.unsqueeze(1),
         torch.arange(value_cache.size(1),
                      device=value.device), block_offsets.unsqueeze(1)), value)
```
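The three index tensors passed to `index_put_` broadcast together: `block_indices.unsqueeze(1)` and `block_offsets.unsqueeze(1)` have shape `(num_tokens, 1)` while the `arange` over heads has shape `(n_kv_head,)`, so jointly they address `(num_tokens, n_kv_head)` cache rows, each receiving one `(d_head,)` vector from `key` or `value`. A CPU-only check of the equivalence with the reference loop from the test above (toy sizes, illustration only):

```python
import torch

num_tokens, n_kv_head, d_head = 5, 3, 8
num_blocks, block_size = 4, 4

key = torch.randn(num_tokens, n_kv_head, d_head)
cache = torch.zeros(num_blocks, n_kv_head, block_size, d_head)
slot_mapping = torch.randperm(num_blocks * block_size)[:num_tokens]

block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
block_offsets = slot_mapping % block_size

# Vectorized scatter: index tensors broadcast to (num_tokens, n_kv_head)
cache.index_put_((block_indices.unsqueeze(1),
                  torch.arange(n_kv_head),
                  block_offsets.unsqueeze(1)), key)

# The per-token reference loop produces the same result
ref = torch.zeros_like(cache)
for i in range(num_tokens):
    ref[block_indices[i], :, block_offsets[i], :] = key[i]
assert torch.equal(cache, ref)
```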
> **Reviewer:** Would add a check to make sure they are on the same device.
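If such a guard were added, a minimal sketch might be (hypothetical helper, not in the PR):

```python
import torch


def check_same_device(*tensors: torch.Tensor) -> None:
    # Hypothetical validation suggested in review; not in the PR itself.
    devices = {t.device for t in tensors}
    if len(devices) != 1:
        raise ValueError(f"Expected all tensors on one device, got {devices}")
```

which would be called at the top of `reshape_and_cache` as `check_same_device(key, value, key_cache, value_cache, slot_mapping)`.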