Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions vllm_hpu_extension/cache_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,27 @@ def swap_blocks_hpu_cpu(src: torch.Tensor, dst: torch.Tensor, block_mapping_t: t

dst.index_copy_(0, dst_indices, src.index_select(0, src_indices))

def insert_or_update_cache_chunked(input, cache, block_indices, block_offsets):
if block_offsets is None:
cache.index_copy_(0, block_indices, input)
else:
if block_offsets.numel() == block_indices.numel():
cache.index_put_((block_indices, block_offsets), input)
else:
offsets = None
block_size = cache.shape[1]
for i in range(block_indices.shape[0]):
offsets = block_offsets[i * block_size:(i + 1) * block_size - 1]
offset_indices = (offsets == -1)
offset_indices = offset_indices.nonzero(as_tuple=True)
start_index = offsets[0].item()
if offset_indices[0].numel() == 0:
temp_index = offsets[offsets.numel() - 1].item()
else:
temp_index = offset_indices[0][0].item()
end_index = offsets[temp_index - 1].item() + 1
cache[block_indices[i], start_index:end_index] = input[i][:temp_index]

def swap_blocks(src, dst, block_mapping):
if block_mapping.numel() == 0 or dst is None or src is None:
return
Expand Down
13 changes: 10 additions & 3 deletions vllm_hpu_extension/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import habana_frameworks.torch as htorch
import torch

from .cache_ops import insert_or_update_cache
from .cache_ops import insert_or_update_cache, insert_or_update_cache_chunked


@lru_cache(maxsize=None)
Expand Down Expand Up @@ -59,15 +59,22 @@ def __init__(self):
self.use_contiguous_pa = os.environ.get('VLLM_CONTIGUOUS_PA',
'true').lower() == 'true'

def forward(self, input, cache, block_indices, block_offset):
insert_or_update_cache(input, cache, block_indices, block_offset)
def forward(self, input, cache, block_indices,
block_offset, chunk_prefill_enabled=False):
if chunk_prefill_enabled:
insert_or_update_cache_chunked(input, cache, block_indices, block_offset)
else:
insert_or_update_cache(input, cache, block_indices, block_offset)
return cache

def fetch_from_cache(self, cache, blocks):
if self.use_contiguous_pa:
return cache[:blocks.size(0)]
else:
return cache.index_select(0, blocks)

def fetch_from_cache_chunked_prefill(self, cache, blocks):
return cache.index_select(0, blocks)


class ModuleFusedSDPA(torch.nn.Module):
Expand Down