Merged
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py: 41 changes (39 additions & 2 deletions)
@@ -935,6 +935,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
attn_backend=backend,
)
self._use_pallas = self.kv_topo._use_pallas
self._physical_blocks_per_logical_kv_block = 1

def _nixl_handshake(
self,
@@ -1133,6 +1134,22 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
if base_addr in seen_base_addresses:
continue

# TODO (NickLucche): Get kernel_block_size in a cleaner way
# NHD default "view" for non-MLA cache
kernel_block_size = cache.shape[-2] if self.use_mla else cache.shape[-3]

if self.block_size != kernel_block_size:
Contributor:
I have a silly question: when will this scenario happen? What is the max kernel block size for CUDA, and where is it set?
`if self.block_size != kernel_block_size`
@jikunshang, could you check?

Collaborator (Author):
It's backend-dependent. It happens whenever the supplied block_size is not one of the kernel-supported sizes listed in https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/flash_attn.py#L63; in that case the kernel block size is used for the physical tensors and block_size becomes purely logical.

Contributor:
Oh, I see
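
For illustration, a minimal standalone sketch of the split discussed above; the block sizes here are hypothetical and not taken from any particular backend:

```python
# Hypothetical numbers: the user asked for 128-token blocks, but the attention
# kernel only supports 16-token blocks, so 16 is used for the physical layout.
user_block_size = 128     # logical block size chosen by the user (hypothetical)
kernel_block_size = 16    # block size required by the attention kernel (hypothetical)

assert user_block_size % kernel_block_size == 0
physical_blocks_per_logical_kv_block = user_block_size // kernel_block_size  # -> 8

# From here on the connector works in kernel units: every logical KV block is
# backed by 8 contiguous kernel blocks, so block ids exchanged over NIXL must
# be expanded by this factor (see _logical_to_kernel_block_ids below).
block_size = kernel_block_size
```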

logger.info_once(
"User-specified logical block size (%s) does not match"
" physical kernel block size (%s). Using the latter. ",
self.block_size,
kernel_block_size,
)
self._physical_blocks_per_logical_kv_block = (
self.block_size // kernel_block_size
)
self.block_size = kernel_block_size

seen_base_addresses.append(base_addr)
curr_tensor_size_bytes = cache.numel() * cache.element_size()

@@ -1751,6 +1768,8 @@ def _read_blocks(
dst_engine_id: str,
request_id: str,
):
local_block_ids = self._logical_to_kernel_block_ids(local_block_ids)
remote_block_ids = self._logical_to_kernel_block_ids(remote_block_ids)
# NOTE(rob): having the staging blocks be on the READER side is
# not going to work well (since we will have to call rearrange tensors).
# after we detect the txn is complete (which means we cannot make the
@@ -1876,7 +1895,7 @@ def _read_blocks(
self._failed_recv_reqs.add(request_id)

def _get_block_descs_ids(
- self, engine_id: str, block_ids: list[int], layer_idx: int | None = None
+ self, engine_id: str, block_ids: np.ndarray, layer_idx: int | None = None
) -> np.ndarray:
"""
Get the descs ids for a set of block ids.
@@ -1902,10 +1921,28 @@ def _get_block_descs_ids(

# Compute the desc ids for each block.
region_ids = region_ids[:, None]
- block_ids = np.array(block_ids)[None, :]
+ block_ids = block_ids[None, :]
descs_ids = region_ids * num_blocks + block_ids
return descs_ids.flatten()
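
The broadcasting in the descriptor-id computation above can be checked standalone; a small sketch with hypothetical region and block counts:

```python
import numpy as np

# Hypothetical: 2 memory regions of 10 blocks each, descriptors for blocks 3 and 7.
num_blocks = 10
region_ids = np.array([0, 1])[:, None]   # shape (2, 1)
block_ids = np.array([3, 7])[None, :]    # shape (1, 2)
descs_ids = region_ids * num_blocks + block_ids
print(descs_ids.flatten())  # [ 3  7 13 17]
```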

def _logical_to_kernel_block_ids(self, block_ids: list[int]) -> np.ndarray:
"""
Convert logical block ids to kernel physical block ids.
This is required when the logical block size (the one set by the user)
does not match the one required by the attn backend.
"""
block_ids_np = np.array(block_ids)
if self._physical_blocks_per_logical_kv_block == 1:
return block_ids_np
block_arange = np.arange(0, self._physical_blocks_per_logical_kv_block).reshape(
1, -1
)
kernel_block_ids = (
block_ids_np.reshape(-1, 1) * self._physical_blocks_per_logical_kv_block
+ block_arange
)
return kernel_block_ids.reshape(-1)
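
As a quick sanity check, the expansion above can be reproduced with plain NumPy; a minimal sketch assuming a hypothetical factor of 4 kernel blocks per logical block:

```python
import numpy as np

# Hypothetical setup: each logical block maps to 4 contiguous kernel blocks.
physical_blocks_per_logical_kv_block = 4
logical_block_ids = [2, 5]

block_ids_np = np.array(logical_block_ids)
block_arange = np.arange(physical_blocks_per_logical_kv_block).reshape(1, -1)
kernel_block_ids = (
    block_ids_np.reshape(-1, 1) * physical_blocks_per_logical_kv_block + block_arange
).reshape(-1)
print(kernel_block_ids)  # [ 8  9 10 11 20 21 22 23]
```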

def get_backend_aware_kv_block_len(self, layer_idx: int):
"""
Get the block length for one K/V element (K and V have the same size).