83 changes: 83 additions & 0 deletions tests/neuron/test_cache.py
@@ -0,0 +1,83 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from vllm.attention.ops.nki_flash_attn import reshape_and_cache


@pytest.mark.parametrize(
"num_tokens, n_kv_head, d_head, num_blocks, block_size",
[
# Small model configuration (e.g., GPT-2 small)
(32, 12, 64, 4, 128), # Typical sequence processing
(1, 12, 64, 4, 128), # Single token update
(128, 12, 64, 4, 128), # Longer sequence

# Medium model configuration (e.g., GPT-2 medium)
(64, 16, 96, 8, 256), # Standard batch
(256, 16, 96, 8, 256), # Large batch

# Large model configuration (e.g., GPT-3 style)
(48, 32, 128, 16, 512), # Typical processing window
(512, 32, 128, 16, 512), # Full context window

# Edge cases and stress tests
(1024, 8, 32, 32, 32), # Many tokens, small heads
(16, 64, 256, 4, 64), # Few tokens, many heads
(2048, 24, 128, 64, 128), # Large scale test

# Minimal configurations for debugging
(4, 2, 16, 2, 16), # Tiny test case
(1, 1, 8, 1, 8), # Minimal possible
])
def test_reshape_and_cache(num_tokens, n_kv_head, d_head, num_blocks,
block_size):
# Set random seed for reproducibility
torch.manual_seed(42)

# Create CPU tensors for reference implementation
key_cpu = torch.randn(num_tokens, n_kv_head, d_head) / torch.sqrt(
torch.tensor(d_head))
value_cpu = torch.randn(num_tokens, n_kv_head, d_head) / torch.sqrt(
torch.tensor(d_head))
key_cache_cpu = torch.zeros(num_blocks, n_kv_head, block_size, d_head)
value_cache_cpu = torch.zeros(num_blocks, n_kv_head, block_size, d_head)
slot_mapping_cpu = torch.randperm(num_blocks * block_size)[:num_tokens]
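    # randperm guarantees the sampled slots are unique, so no two tokens
    # collide on the same (block, offset) cache position.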

# Run reference implementation on CPU
block_indices = torch.div(slot_mapping_cpu,
block_size,
rounding_mode="floor")
block_offsets = slot_mapping_cpu % block_size

for i in range(num_tokens):
block_idx = block_indices[i]
block_offset = block_offsets[i]
key_cache_cpu[block_idx, :, block_offset, :] = key_cpu[i]
value_cache_cpu[block_idx, :, block_offset, :] = value_cpu[i]

# Create XLA device tensors
device = torch.device('xla')
key = key_cpu.to(device)
value = value_cpu.to(device)
key_cache = torch.zeros_like(key_cache_cpu, device=device)
value_cache = torch.zeros_like(value_cache_cpu, device=device)
slot_mapping = slot_mapping_cpu.to(device)

# Run vectorized implementation on XLA device
reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)

# Move results back to CPU for comparison
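    # (On XLA, moving tensors to CPU also forces the lazily-traced graph
    # to actually execute.)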
key_cache_result = key_cache.cpu()
value_cache_result = value_cache.cpu()

# Assert results match
torch.testing.assert_close(key_cache_result,
key_cache_cpu,
rtol=1e-5,
atol=1e-5)
torch.testing.assert_close(value_cache_result,
value_cache_cpu,
rtol=1e-5,
atol=1e-5)
43 changes: 43 additions & 0 deletions vllm/attention/ops/nki_flash_attn.py
@@ -869,3 +869,46 @@ def flash_attn_varlen_nkifunc(

o = flash_paged_attention[1, n_kv_head](**kwargs)
return o


def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
Comment on lines +875 to +879 (reviewer):
would add a check to make sure they are on the same device.
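A minimal sketch of the suggested guard (hypothetical; the merged PR does not add it):

def _assert_same_device(*tensors: torch.Tensor) -> None:
    # Hypothetical helper: all inputs and caches must live on one device,
    # since index_put_ expects its index and value tensors to match the
    # device of the cache being updated.
    devices = {t.device for t in tensors}
    assert len(devices) == 1, f"expected a single device, got {devices}"

It could be called as _assert_same_device(key, value, key_cache, value_cache, slot_mapping) at the top of reshape_and_cache.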

) -> None:
"""
Writes key-value pairs to the KV cache at specified positions.

Args:
key (torch.Tensor): Key tensor with shape
(num_tokens, n_kv_head, d_head)
value (torch.Tensor): Value tensor with shape
(num_tokens, n_kv_head, d_head)
key_cache (torch.Tensor): Key cache tensor with shape
(num_blocks, n_kv_head, block_size, d_head)

Reviewer: non-blocking question: why this rather than (num_blocks, block_size, n_kv_head, d_head)?

Contributor Author: the KV cache layout was changed from (num_blocks, block_size, n_kv_head, d_head) to (num_blocks, n_kv_head, block_size, d_head) in #13245, in order to speed up KV cache loading and eliminate an unnecessary transpose.
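As a quick illustration of the layout difference (toy sizes, not from the PR):

import torch

num_blocks, block_size, n_kv_head, d_head = 2, 4, 3, 8
old_layout = torch.randn(num_blocks, block_size, n_kv_head, d_head)
# The new layout groups each head's (block_size, d_head) tile contiguously,
# so a kernel can load one head's block without a transpose.
new_layout = old_layout.permute(0, 2, 1, 3).contiguous()
tile = new_layout[0, 0]  # head 0 of block 0, shape (block_size, d_head)
assert tile.is_contiguous()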

value_cache (torch.Tensor): Value cache tensor with shape
(num_blocks, n_kv_head, block_size, d_head)
slot_mapping (torch.Tensor): Mapping tensor indicating cache positions
with shape (num_tokens)

Returns:
None: Updates the key_cache and value_cache tensors in-place
"""
Reviewer: would add an out-of-bound check for slot_mapping so it stays below num_blocks * block_size.

Contributor Author: not sure which of the two cases you mean:
1/ a bound check on the slot_mapping shape
2/ a bound check on the slot_mapping values

For 1, I think the slot_mapping shape (aka num_batched_tokens) could be greater than num_blocks * block_size (e.g. with DMA skipping), although that would be a rare case.
For 2, the values are only checked at execution time; they are not controlled by the kernel or its compilation.

Reviewer: yeah, I meant the values. What will happen if some value of slot_mapping goes beyond num_blocks * block_size (how do we protect against it?), or could it happen at all?

Contributor Author (@liangfu, Mar 10, 2025): if that's the case, there must be a bug in the scheduler, which allocates the slots. If a value exceeds the boundary, the Neuron runtime would raise an out-of-bound (OOB) error.
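For completeness, a hypothetical host-side debug guard (not part of the PR; as noted above, the values are only known at execution time, so this would run eagerly before the kernel launch):

num_slots = key_cache.size(0) * key_cache.size(2)  # num_blocks * block_size
slots = slot_mapping.cpu()
assert int(slots.min()) >= 0, "negative slot in slot_mapping"
assert int(slots.max()) < num_slots, "slot_mapping value out of bounds"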

block_size = key_cache.size(2)

# Calculate indices with explicit floor division
block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
block_offsets = slot_mapping % block_size

    # Update caches in-place with index_put_. The three index tensors
    # broadcast to shape (num_tokens, n_kv_head) and address the first three
    # dims of the 4-D caches; the trailing d_head dim is written wholesale.
    head_indices = torch.arange(key_cache.size(1), device=key.device)

    key_cache.index_put_(
        (block_indices.unsqueeze(1), head_indices,
         block_offsets.unsqueeze(1)), key)

    value_cache.index_put_(
        (block_indices.unsqueeze(1), head_indices,
         block_offsets.unsqueeze(1)), value)
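To make the slot arithmetic above concrete, a tiny standalone example (illustrative values only):

import torch

slot = torch.tensor([300])
block_size = 128
block_idx = torch.div(slot, block_size, rounding_mode="floor")  # tensor([2])
block_off = slot % block_size                                   # tensor([44])
# The token at slot 300 lands in key_cache[2, :, 44, :].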