6 changes: 4 additions & 2 deletions tests/v1/worker/test_gpu_model_runner.py
@@ -985,8 +985,10 @@ def test_hybrid_block_table_initialization():
     req_index = 0
     block_table.append_row(kvcache_manager_blocks, req_index)
     # Get expected kernel blocks from the implementation for verification.
-    expected_kernel_blocks = block_table._map_to_kernel_blocks(
-        np.array(kvcache_manager_blocks)
+    expected_kernel_blocks = block_table.map_to_kernel_blocks(
+        np.array(kvcache_manager_blocks),
+        block_table.blocks_per_kv_block,
+        block_table._kernel_block_arange,
     )
     # Verify block table state
     assert block_table.num_blocks_per_row[req_index] == len(expected_kernel_blocks)
67 changes: 57 additions & 10 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -49,6 +49,7 @@
 from vllm.utils.network_utils import make_zmq_path, make_zmq_socket
 from vllm.v1.attention.backends.utils import get_kv_cache_layout
 from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.worker.block_table import BlockTable

 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
@@ -112,6 +113,8 @@ class NixlAgentMetadata(KVConnectorHandshakeMetadata):
 @dataclass
 class ReqMeta:
     local_block_ids: list[int]
+    # To be used when logical block size does not match the kernel block size
+    local_physical_block_ids: list[int]
     remote_block_ids: list[int]
     remote_host: str
     remote_port: int
@@ -139,6 +142,7 @@ def add_new_req(
         assert load_remote_cache ^ save_to_host
         _req = ReqMeta(
             local_block_ids=local_block_ids,
+            local_physical_block_ids=local_block_ids,
             remote_block_ids=kv_transfer_params["remote_block_ids"],
             remote_engine_id=kv_transfer_params["remote_engine_id"],
             remote_host=kv_transfer_params["remote_host"],
@@ -935,6 +939,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
             attn_backend=backend,
         )
         self._use_pallas = self.kv_topo._use_pallas
+        self._physical_blocks_per_logical_kv_block = 1

     def _nixl_handshake(
         self,
@@ -1133,6 +1138,22 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             if base_addr in seen_base_addresses:
                 continue

+            # TODO (NickLucche): Get kernel_block_size in a cleaner way
+            # NHD default "view" for non-MLA cache
+            kernel_block_size = cache.shape[-2] if self.use_mla else cache.shape[-3]
+
+            if self.block_size != kernel_block_size:
Comment (Contributor):
I have a silly question: when will this scenario happen? What is the max kernel block size for CUDA, and where is it set?

    if self.block_size != kernel_block_size

@jikunshang, may you check?

Reply (Collaborator, author):
It's backend-dependent. It happens whenever the supplied block_size is not one of the sizes supported by the backend (see https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/flash_attn.py#L63), in which case the kernel block size is used for the physical tensors and block_size becomes purely logical.

Reply (Contributor):
Oh, I see.
+                logger.info_once(
+                    "User-specified logical block size (%s) does not match"
+                    " physical kernel block size (%s). Using the latter. ",
+                    self.block_size,
+                    kernel_block_size,
+                )
+                self._physical_blocks_per_logical_kv_block = (
+                    self.block_size // kernel_block_size
+                )
+                self.block_size = kernel_block_size
+
             seen_base_addresses.append(base_addr)
             curr_tensor_size_bytes = cache.numel() * cache.element_size()

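To make the reconciliation above concrete, here is a minimal sketch of how the kernel block size is read off a cache tensor and how the logical-to-physical ratio falls out. The shapes and sizes are hypothetical, chosen only for illustration:

```python
import torch

# Assumed non-MLA cache in the default NHD layout: the per-block token
# dimension sits at shape[-3] (for MLA it would be shape[-2]).
num_blocks, kernel_block_tokens, num_heads, head_dim = 128, 16, 8, 64
cache = torch.empty(2, num_blocks, kernel_block_tokens, num_heads, head_dim)
kernel_block_size = cache.shape[-3]  # 16

# Hypothetical user-facing block size of 32: each logical block then
# covers 32 // 16 = 2 kernel blocks, and block_size becomes purely logical.
block_size = 32
physical_blocks_per_logical_kv_block = block_size // kernel_block_size  # 2
```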
@@ -1479,7 +1500,7 @@ def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta):
         assert self.use_host_buffer
         assert self.copy_blocks is not None

-        local_block_ids = meta.local_block_ids
+        local_block_ids = meta.local_physical_block_ids
         self.copy_blocks(
             self.host_xfer_buffers,
             self.device_kv_caches,
@@ -1492,7 +1513,7 @@ def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta):
             "synced recved kv of request[%s] to device kv buffer,"
             "local_block_ids: %s. ",
             req_id,
-            ",".join(map(str, meta.local_block_ids)),
+            ",".join(map(str, local_block_ids)),
         )

     def save_kv_to_host(self, metadata: NixlConnectorMetadata):
@@ -1501,19 +1522,22 @@ def save_kv_to_host(self, metadata: NixlConnectorMetadata):
         assert self.copy_blocks is not None

         for req_id, meta in metadata.reqs_to_save.items():
+            meta.local_physical_block_ids = self._logical_to_kernel_block_ids(
+                meta.local_block_ids
+            )
             if logger.isEnabledFor(logging.DEBUG):
                 logger.debug(
                     "save_load_kv for request[%s] to host xfer buffer."
                     "local_block_ids: %s. ",
                     req_id,
-                    ",".join(map(str, meta.local_block_ids)),
+                    ",".join(map(str, meta.local_physical_block_ids)),
                 )
             # blocking
             self.copy_blocks(
                 self.device_kv_caches,
                 self.host_xfer_buffers,
-                meta.local_block_ids,
-                meta.local_block_ids,
+                meta.local_physical_block_ids,
+                meta.local_physical_block_ids,
                 "d2h",
             )

@@ -1582,7 +1606,7 @@ def get_finished(self) -> tuple[set[str], set[str]]:
             if self.use_host_buffer:
                 self.sync_recved_kv_to_device(req_id, meta)
             if self.enable_permute_local_kv:
-                block_ids_to_permute += meta.local_block_ids
+                block_ids_to_permute += meta.local_physical_block_ids
         if len(block_ids_to_permute) > 0:
             self.permute_device_kv(block_ids_to_permute)

@@ -1669,7 +1693,7 @@ def _pop_done_transfers(
                     req_id,
                     xfer_state,
                 )
-                # mark all blocks for this request as invalid
+                # mark all (logical) blocks for this request as invalid
                 if meta := self._recving_metadata.pop(req_id, None):
                     self._invalid_block_ids.update(meta.local_block_ids)
                 self._recving_metadata.pop(req_id, None)
@@ -1686,13 +1710,19 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
         We check for these trnxs to complete in each step().
         """
         for req_id, meta in metadata.reqs_to_recv.items():
+            meta.local_physical_block_ids = self._logical_to_kernel_block_ids(
+                meta.local_block_ids
+            )
+            meta.remote_block_ids = self._logical_to_kernel_block_ids(
+                meta.remote_block_ids
+            )
             remote_engine_id = meta.remote_engine_id
             logger.debug(
                 "start_load_kv for request %s from remote engine %s. "
                 "Num local_block_ids: %s. Num remote_block_ids: %s. ",
                 req_id,
                 remote_engine_id,
-                len(meta.local_block_ids),
+                len(meta.local_physical_block_ids),
                 len(meta.remote_block_ids),
             )
             # always store metadata for failure recovery
@@ -1740,7 +1770,7 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
         self._read_blocks(
             request_id=req_id,
             dst_engine_id=meta.remote_engine_id,
-            local_block_ids=meta.local_block_ids,
+            local_block_ids=meta.local_physical_block_ids,
             remote_block_ids=meta.remote_block_ids,
         )

@@ -1867,7 +1897,7 @@ def _read_blocks(
                 "Marking blocks as invalid.",
                 request_id,
             )
-            # mark all blocks for this request as invalid
+            # mark all (logical) blocks for this request as invalid
            if meta := self._recving_metadata.get(request_id):
                self._invalid_block_ids.update(meta.local_block_ids)
            self.xfer_stats.record_failed_transfer()
@@ -1906,6 +1936,23 @@ def _get_block_descs_ids(
         descs_ids = region_ids * num_blocks + block_ids
         return descs_ids.flatten()

+    def _logical_to_kernel_block_ids(self, block_ids: list[int]) -> list[int]:
+        """
+        Convert logical block ids to kernel physical block ids.
+        This is required when the logical block size (the one set by the user)
+        does not match the one required by the attn backend.
+        """
+        if self._physical_blocks_per_logical_kv_block == 1:
+            # No-op when physical and logical block sizes are the same
+            return block_ids
+        block_ids_np = np.array(block_ids)
+        block_arange = np.arange(0, self._physical_blocks_per_logical_kv_block).reshape(
+            1, -1
+        )
+        return BlockTable.map_to_kernel_blocks(
+            block_ids_np, self._physical_blocks_per_logical_kv_block, block_arange
+        ).tolist()
+
     def get_backend_aware_kv_block_len(self, layer_idx: int):
         """
         Get the block length for one K/V element (K and V have the same size).
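As an illustration of what this helper produces, here is a standalone replica of its logic. The function name and the sample ids are hypothetical, and it assumes each logical block spans two kernel blocks:

```python
import numpy as np

def logical_to_kernel_block_ids(block_ids: list[int], physical_per_logical: int) -> list[int]:
    # Standalone replica of the connector helper, for illustration only.
    if physical_per_logical == 1:
        return block_ids
    ids = np.array(block_ids).reshape(-1, 1) * physical_per_logical
    arange = np.arange(physical_per_logical).reshape(1, -1)
    return (ids + arange).reshape(-1).tolist()

print(logical_to_kernel_block_ids([0, 2, 5], 2))  # [0, 1, 4, 5, 10, 11]
```

Each logical id `b` expands to the contiguous kernel ids `[b * n, ..., b * n + n - 1]`, which is why both the local and the remote block lists can be converted independently before the transfer descriptors are built.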
17 changes: 12 additions & 5 deletions vllm/v1/worker/block_table.py
@@ -98,7 +98,9 @@ def append_row(
             return

         if self.use_hybrid_blocks:
-            block_ids = self._map_to_kernel_blocks(np.array(block_ids))
+            block_ids = self.map_to_kernel_blocks(
+                np.array(block_ids), self.blocks_per_kv_block, self._kernel_block_arange
+            )

         num_blocks = len(block_ids)
         start = self.num_blocks_per_row[row_idx]
@@ -188,7 +190,12 @@ def clear(self) -> None:
         self.block_table.gpu.fill_(0)
         self.block_table.cpu.fill_(0)

-    def _map_to_kernel_blocks(self, kv_manager_block_ids: np.ndarray) -> np.ndarray:
+    @staticmethod
+    def map_to_kernel_blocks(
+        kv_manager_block_ids: np.ndarray,
+        blocks_per_kv_block: int,
+        kernel_block_arange: np.ndarray,
+    ) -> np.ndarray:
         """Convert kv_manager_block_id IDs to kernel block IDs.

         Example:
@@ -203,12 +210,12 @@ def _map_to_kernel_blocks(self, kv_manager_block_ids: np.ndarray) -> np.ndarray:
         # kv_manager_block_id 1 → kernel block id [2, 3]
         # kv_manager_block_id 2 → kernel block id [4, 5]
         """
-        if not self.use_hybrid_blocks:
+        if blocks_per_kv_block == 1:
             return kv_manager_block_ids

         kernel_block_ids = (
-            kv_manager_block_ids.reshape(-1, 1) * self.blocks_per_kv_block
-            + self._kernel_block_arange
+            kv_manager_block_ids.reshape(-1, 1) * blocks_per_kv_block
+            + kernel_block_arange
        )

        return kernel_block_ids.reshape(-1)
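Since map_to_kernel_blocks is now a pure static function of its arguments, the docstring example can be reproduced directly. A minimal sketch (in the real class the arange is precomputed once per table rather than built per call):

```python
import numpy as np

blocks_per_kv_block = 2
kernel_block_arange = np.arange(blocks_per_kv_block).reshape(1, -1)

kv_manager_block_ids = np.array([0, 1, 2])
kernel_block_ids = (
    kv_manager_block_ids.reshape(-1, 1) * blocks_per_kv_block + kernel_block_arange
).reshape(-1)
print(kernel_block_ids)  # [0 1 2 3 4 5]: block 1 -> [2, 3], block 2 -> [4, 5]
```

Making the method static is what lets the NIXL connector reuse the exact same mapping without holding a BlockTable instance.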