def insert_or_update_cache_chunked(input, cache, block_indices, block_offsets):
    """Write *input* rows into *cache* for the chunked-prefill path.

    Three layouts are handled, chosen from the shape of ``block_offsets``:
      * ``block_offsets is None`` — whole-block writes: ``input[j]`` replaces
        ``cache[block_indices[j]]``.
      * one offset per block index — single-slot writes via ``index_put_``.
      * otherwise — ``block_offsets`` carries a per-block list of slot
        positions (padded with ``-1``), and a contiguous slice of each cache
        block is overwritten with a prefix of the matching input row.

    The cache is mutated in place; nothing is returned.
    """
    if block_offsets is None:
        # One full cache block per entry of block_indices.
        cache.index_copy_(0, block_indices, input)
        return

    if block_offsets.numel() == block_indices.numel():
        # Exactly one target slot per block index: scatter row-wise.
        cache.index_put_((block_indices, block_offsets), input)
        return

    block_size = cache.shape[1]
    for row, block_id in enumerate(block_indices):
        # NOTE(review): the trailing "- 1" drops the final offset of every
        # block — Python slice ends are already exclusive. Confirm intended.
        chunk = block_offsets[row * block_size:(row + 1) * block_size - 1]
        padding = (chunk == -1).nonzero(as_tuple=True)
        first_slot = chunk[0].item()
        if padding[0].numel() == 0:
            # No -1 padding present: the final offset *value* doubles as the
            # count of valid entries. NOTE(review): this conflates value and
            # position — only holds for contiguous offsets; verify.
            valid = chunk[chunk.numel() - 1].item()
        else:
            # Count of valid entries = position of the first -1 marker.
            valid = padding[0][0].item()
        last_slot = chunk[valid - 1].item() + 1
        cache[block_id, first_slot:last_slot] = input[row][:valid]
def forward(self, input, cache, block_indices,
            block_offset, chunk_prefill_enabled=False):
    """Insert *input* into *cache*, dispatching on the prefill mode.

    ``chunk_prefill_enabled`` defaults to False so existing callers keep
    the original single-shot insert path. The cache is updated in place
    and also returned for convenience.
    """
    writer = (insert_or_update_cache_chunked
              if chunk_prefill_enabled else insert_or_update_cache)
    writer(input, cache, block_indices, block_offset)
    return cache

def fetch_from_cache(self, cache, blocks):
    """Gather cache blocks; with contiguous PA a plain prefix slice suffices."""
    if not self.use_contiguous_pa:
        return cache.index_select(0, blocks)
    return cache[:blocks.size(0)]

def fetch_from_cache_chunked_prefill(self, cache, blocks):
    """Chunked-prefill gather: always an explicit index_select, never a prefix slice."""
    return cache.index_select(0, blocks)