From dcc1adedef9586a053579a06989eb503622f60bc Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Mon, 13 Oct 2025 18:33:08 -0700 Subject: [PATCH 01/27] enable hetero blocksize Signed-off-by: Chendi Xue --- .../kv_connector/v1/nixl_connector.py | 80 +++++++++++++------ 1 file changed, 56 insertions(+), 24 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 8c4c82f76ff2..8a5b6bf72cba 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -102,6 +102,7 @@ class NixlAgentMetadata( block_lens: list[int] attn_backend_name: str kv_cache_layout: str + block_size: int @dataclass @@ -1010,6 +1011,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): block_lens=self.block_len_per_layer, attn_backend_name=self.backend_name, kv_cache_layout=self.kv_cache_layout, + block_size=self.block_size, ) ready_event = threading.Event() self._nixl_handshake_listener_t = threading.Thread( @@ -1094,40 +1096,37 @@ def add_remote_agent( is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1 remote_block_len = nixl_agent_meta.block_lens[0] + # NOTE(Chendi): we want to support remote and local with different block_size. + # To achieve this goal, we need to make sure that + # remote_block_lens * remote_block_size = local_block_lens * local_block_size + remote_block_size = nixl_agent_meta.block_size + block_size_ratio = remote_block_size / self.block_size + self.block_size_ratio = block_size_ratio if self.use_mla or is_kv_replicated: # With replicated KV cache, only the number of blocks can differ. - assert self.block_len_per_layer == nixl_agent_meta.block_lens, ( + assert self.block_len_per_layer[0] * block_size_ratio == remote_block_len, ( "KV cache sizes must match between P and D when replicated" ) - remote_block_size = remote_block_len // (self.slot_size_per_layer[0]) else: # When MLA is not used, this is a list of the same block length for block_len in nixl_agent_meta.block_lens: assert block_len == remote_block_len, ( "All remote layers must have the same block size" ) - remote_block_size = remote_block_len // ( - self.slot_size_per_layer[0] * tp_ratio - ) - if self._use_flashinfer: - # With flashinfer, KV are sent in the same message. - remote_block_size //= 2 if tp_ratio > 1: # Heterogeneous TP expects same kv_cache_layout. assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout if self.device_type == "xpu": raise ValueError("Heterogeneous TP is not supported on XPU") - assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, ( + assert ( + remote_block_len + == self.block_len_per_layer[0] * tp_ratio * block_size_ratio + ), ( "Remote P worker KV layer cache must be of shape [2, N, " "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype." ) - assert self.block_size == remote_block_size, ( - "Remote P worker with different page/block size is not supported " - f"{self.block_size=}, {remote_block_size=}" - ) - # Create dst descs and xfer side handles. TP workers have same #blocks. if engine_id in self.dst_num_blocks: assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks @@ -1150,7 +1149,11 @@ def add_remote_agent( if not (self.use_mla or is_kv_replicated) else 0 ) - for block_id in range(nixl_agent_meta.num_blocks): + # NOTE(Chendi): In case remote and local use different block_size. 
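+            # (Illustration with hypothetical sizes: a remote prefill worker with
+            #  block_size=4 feeding a local decode worker with block_size=16 gives
+            #  block_size_ratio = 4 / 16 = 0.25, so 1000 remote blocks are viewed
+            #  as int(1000 * 0.25) = 250 local-sized blocks below.)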
+ local_num_blocks = int(nixl_agent_meta.num_blocks * block_size_ratio) + # print(f"{local_num_blocks*i=} {local_num_blocks=} " + # f "{nixl_agent_meta.num_blocks=} {nixl_agent_meta.block_lens[0]=}") + for block_id in range(local_num_blocks): block_offset = block_id * nixl_agent_meta.block_lens[i] # For each block, grab the heads chunk belonging to rank_i # of size remote_nheads // tp_ratio, which correspond to @@ -1161,8 +1164,10 @@ def add_remote_agent( if self._use_flashinfer: # With FlashInfer index V separately to allow head splitting. - for block_id in range(nixl_agent_meta.num_blocks): - block_offset = block_id * nixl_agent_meta.block_lens[i] + for block_id in range(local_num_blocks): + block_offset = block_id * math.ceil( + nixl_agent_meta.block_lens[i] * block_size_ratio + ) addr = base_addr + block_offset + rank_offset v_addr = addr + nixl_agent_meta.block_lens[i] // 2 blocks_data.append((v_addr, kv_block_len, remote_tp_rank)) @@ -1373,11 +1378,15 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): continue # Handshake already completed, start async read xfer. - self._read_blocks_for_req(req_id, meta) + # FIXME(Chendi): should store per engine + self._read_blocks_for_req(req_id, meta, self.block_size_ratio) # Start transfers for requests whose handshakes have now finished. while not self._ready_requests.empty(): - self._read_blocks_for_req(*self._ready_requests.get_nowait()) + # FIXME(Chendi): should store per engine + self._read_blocks_for_req( + *self._ready_requests.get_nowait(), self.block_size_ratio + ) # Keep around the requests that have been part of a batch. This is # needed because async scheduling pushes the misalignment between the @@ -1399,7 +1408,7 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): if req_id in self._reqs_to_process: self._reqs_to_send[req_id] = expiration_time - def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): + def _read_blocks_for_req(self, req_id: str, meta: ReqMeta, block_size_ratio: float): logger.debug( "Remote agent %s available, calling _read_blocks for req %s", meta.remote_engine_id, @@ -1410,6 +1419,7 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): dst_engine_id=meta.remote_engine_id, local_block_ids=meta.local_block_ids, remote_block_ids=meta.remote_block_ids, + block_size_ratio=block_size_ratio, ) def _read_blocks( @@ -1418,7 +1428,22 @@ def _read_blocks( remote_block_ids: list[int], dst_engine_id: str, request_id: str, + block_size_ratio: float, ): + # FIXME(Chendi): Very naive codes to re-calculate remote block + # Only works for remote block_size < local block_size now, + # remote block_size > local block_size needs extra map + # print(f"before {local_block_ids=} {remote_block_ids=}") + block_size_ratio_inv = int(1 / block_size_ratio) + remote_block_ids = [ + i // block_size_ratio_inv + for i in remote_block_ids + if i % block_size_ratio_inv == 0 + ] + if len(remote_block_ids) < len(local_block_ids): + remote_block_ids.append(remote_block_ids[-1] + 1) + # print(f"after {local_block_ids=} {remote_block_ids=}") + # NOTE(rob): having the staging blocks be on the READER side is # not going to work well (since we will have to call rearrange tensors). 
# after we detect the txn is complete (which means we cannot make the @@ -1472,7 +1497,7 @@ def _read_blocks( if not self.block_window_per_layer: # Default case: assume global attention remote_block_descs_ids = self._get_block_descs_ids( - dst_engine_id, remote_block_ids + dst_engine_id, remote_block_ids, block_size_ratio=block_size_ratio ) local_block_descs_ids = self._get_block_descs_ids( self.engine_id, local_block_ids @@ -1499,7 +1524,10 @@ def _read_blocks( self.engine_id, layer_local_block_ids, layer_idx ) layer_remote_desc_ids = self._get_block_descs_ids( - dst_engine_id, layer_remote_block_ids, layer_idx + dst_engine_id, + layer_remote_block_ids, + layer_idx, + block_size_ratio=block_size_ratio, ) local_descs_list.append(layer_local_desc_ids) @@ -1542,7 +1570,11 @@ def _read_blocks( self._failed_recv_reqs.add(request_id) def _get_block_descs_ids( - self, engine_id: str, block_ids: list[int], layer_idx: int | None = None + self, + engine_id: str, + block_ids: list[int], + layer_idx: int | None = None, + block_size_ratio: float = 1, ) -> np.ndarray: """ Get the descs ids for a set of block ids. @@ -1569,7 +1601,7 @@ def _get_block_descs_ids( # Compute the desc ids for each block. region_ids = region_ids[:, None] block_ids = np.array(block_ids)[None, :] - descs_ids = region_ids * num_blocks + block_ids + descs_ids = region_ids * int(num_blocks * block_size_ratio) + block_ids return descs_ids.flatten() def get_backend_aware_kv_block_len(self, layer_idx: int): From 87b1de85a1365079395c863cfd485b13799ba538 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Tue, 21 Oct 2025 07:17:49 -0700 Subject: [PATCH 02/27] update and cleanup Current codes only works for NHD Signed-off-by: Chendi Xue --- .../kv_connector/v1/nixl_connector.py | 113 +++++++++++------- vllm/platforms/cuda.py | 2 + 2 files changed, 75 insertions(+), 40 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index b98eabe590e4..de4ad6bc243b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1111,6 +1111,7 @@ def add_remote_agent( # remote_block_lens * remote_block_size = local_block_lens * local_block_size remote_block_size = nixl_agent_meta.block_size block_size_ratio = remote_block_size / self.block_size + block_size_ratio_inv = self.block_size / remote_block_size self.block_size_ratio = block_size_ratio if nixl_agent_meta.kv_cache_layout != self.kv_cache_layout: if ( @@ -1171,6 +1172,9 @@ def add_remote_agent( self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer) + # NOTE(Chendi): In case remote and local use different block_size. + local_num_blocks = int(nixl_agent_meta.num_blocks * block_size_ratio) + self.remote_remain_blocks = [] # Register all remote blocks, but only the corresponding kv heads. for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) @@ -1179,27 +1183,43 @@ def add_remote_agent( if not (self.use_mla or is_kv_replicated) else 0 ) - # NOTE(Chendi): In case remote and local use different block_size. 
-                local_num_blocks = int(nixl_agent_meta.num_blocks * block_size_ratio)
-                # print(f"{local_num_blocks*i=} {local_num_blocks=} "
-                #       f "{nixl_agent_meta.num_blocks=} {nixl_agent_meta.block_lens[0]=}")
+
+            if block_size_ratio < 1:
+                # We skip block 0 in each layer: block 0 at the remote uses the
+                # smaller block size, so to keep the mapping correct we skip
+                # block 0 from the remote.
+                # Example for block_size_ratio = 0.25:
+                # remote block 0 -> skip
+                # remote block 1, 2, 3, 4 => local block 0
+                # remote block 5, 6, 7, 8 => local block 1
+                # Also, we keep track of the last few remote blocks which can't
+                # fill the last local block, and use them in read_block for
+                # mapping remote/local block_id.
+                shift_for_remote_block_0 = remote_block_len
+                remain_nblocks_remote = nixl_agent_meta.num_blocks - int(
+                    local_num_blocks * block_size_ratio_inv
+                )
+                self.remote_remain_blocks.append(remain_nblocks_remote)
+            local_block_len = int(self.block_len_per_layer[i] * tp_ratio)
             for block_id in range(local_num_blocks):
-                block_offset = block_id * nixl_agent_meta.block_lens[i]
+                block_offset = block_id * local_block_len
                 # For each block, grab the heads chunk belonging to rank_i
                 # of size remote_nheads // tp_ratio, which correspond to
                 # self.block_len == remote_block_len//tp_ratio bytes.
-                addr = base_addr + block_offset + rank_offset
+                addr = base_addr + block_offset + rank_offset + shift_for_remote_block_0
                 # (addr, len, device id)
                 blocks_data.append((addr, kv_block_len, remote_tp_rank))

             if self._use_flashinfer:
                 # With FlashInfer index V separately to allow head splitting.
                 for block_id in range(local_num_blocks):
-                    block_offset = block_id * math.ceil(
-                        nixl_agent_meta.block_lens[i] * block_size_ratio
+                    block_offset = block_id * local_block_len
+                    addr = (
+                        base_addr
+                        + block_offset
+                        + rank_offset
+                        + shift_for_remote_block_0
                     )
-                    addr = base_addr + block_offset + rank_offset
-                    v_addr = addr + nixl_agent_meta.block_lens[i] // 2
+                    v_addr = addr + local_block_len // 2
                     blocks_data.append((v_addr, kv_block_len, remote_tp_rank))

         logger.debug(
@@ -1320,8 +1340,8 @@ def get_finished(self) -> tuple[set[str], set[str]]:
         # clean up metadata for completed requests
         for req_id in done_recving:
-            meta = self._recving_metadata.pop(req_id, None)
-            if self.use_host_buffer and meta:
+            if self.use_host_buffer:
+                meta = self._recving_metadata.pop(req_id)
                 self.sync_recved_kv_to_device(req_id, meta)

         # Handle timeout to avoid stranding blocks on remote.
@@ -1504,19 +1524,7 @@ def _read_blocks(
         request_id: str,
         block_size_ratio: float,
     ):
-        # FIXME(Chendi): Very naive codes to re-calculate remote block
-        # Only works for remote block_size < local block_size now,
-        # remote block_size > local block_size needs extra map
-        # print(f"before {local_block_ids=} {remote_block_ids=}")
-        block_size_ratio_inv = int(1 / block_size_ratio)
-        remote_block_ids = [
-            i // block_size_ratio_inv
-            for i in remote_block_ids
-            if i % block_size_ratio_inv == 0
-        ]
-        if len(remote_block_ids) < len(local_block_ids):
-            remote_block_ids.append(remote_block_ids[-1] + 1)
-        # print(f"after {local_block_ids=} {remote_block_ids=}")
+        print(f"before {local_block_ids=} {remote_block_ids=}")

         # NOTE(rob): having the staging blocks be on the READER side is
         # not going to work well (since we will have to call rearrange tensors).
@@ -1551,12 +1559,6 @@ def _read_blocks(
             self.xfer_stats.record_failed_notification()
             return

-        # Partial prefix cache hit: just read uncomputed blocks.
- num_remote_blocks = len(remote_block_ids) - assert num_local_blocks <= num_remote_blocks - if num_local_blocks < num_remote_blocks: - remote_block_ids = remote_block_ids[-num_local_blocks:] - # Get side handles. local_xfer_side_handle = self.src_xfer_side_handle remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id] @@ -1571,7 +1573,9 @@ def _read_blocks( if not self.block_window_per_layer: # Default case: assume global attention remote_block_descs_ids = self._get_block_descs_ids( - dst_engine_id, remote_block_ids, block_size_ratio=block_size_ratio + dst_engine_id, + remote_block_ids, + remain_nblocks=self.remote_remain_blocks, ) local_block_descs_ids = self._get_block_descs_ids( self.engine_id, local_block_ids @@ -1595,13 +1599,10 @@ def _read_blocks( # Get descs ids for the layer. layer_local_desc_ids = self._get_block_descs_ids( - self.engine_id, layer_local_block_ids, layer_idx + dst_engine_id, layer_local_block_ids, layer_idx ) layer_remote_desc_ids = self._get_block_descs_ids( - dst_engine_id, - layer_remote_block_ids, - layer_idx, - block_size_ratio=block_size_ratio, + self.engine_id, layer_remote_block_ids, layer_idx ) local_descs_list.append(layer_local_desc_ids) @@ -1610,6 +1611,12 @@ def _read_blocks( local_block_descs_ids = np.concatenate(local_descs_list) remote_block_descs_ids = np.concatenate(remote_descs_list) + # NOTE(Chendi): Update remote remote_block_descs_ids using local block_size + # print(f"{remote_block_descs_ids=}{local_block_descs_ids=}") + remote_block_descs_ids = self.get_mapped_blocks( + remote_block_descs_ids, block_size_ratio + ) + # print(f"after {remote_block_descs_ids=}") assert len(local_block_descs_ids) == len(remote_block_descs_ids) # Prepare transfer with Nixl. @@ -1643,12 +1650,34 @@ def _read_blocks( self.nixl_wrapper.release_xfer_handle(handle) self._failed_recv_reqs.add(request_id) + def get_mapped_blocks(self, block_ids, block_size_ratio): + """ + Calculates the new set of block IDs by mapping every element + in the (potentially sparse) input array. + """ + if block_ids.size == 0: + return np.array([], dtype=np.int64) + + # 1. Scale all block IDs and truncate (floor) + # (This is the logic from your original code block) + mapped_ids = (block_ids * block_size_ratio).astype(np.int64) + + # 2. Get the unique resulting IDs + unique_blocks = np.unique(mapped_ids) + + # # 3. Filter out 0, as block IDs are 1-based. + # # (This logic matches your Example 1) + # if unique_blocks.size > 0 and unique_blocks[0] == 0: + # unique_blocks = unique_blocks[unique_blocks > 0] + + return unique_blocks + def _get_block_descs_ids( self, engine_id: str, block_ids: list[int], layer_idx: int | None = None, - block_size_ratio: float = 1, + remain_nblocks: list[int] | None = None, ) -> np.ndarray: """ Get the descs ids for a set of block ids. @@ -1671,11 +1700,15 @@ def _get_block_descs_ids( region_ids = np.arange(layer_idx, layer_idx + 1) num_blocks = self.dst_num_blocks[engine_id] + num_blocks = np.full((self.num_regions), num_blocks) + if remain_nblocks is not None: + num_blocks = num_blocks - np.array(remain_nblocks) # Compute the desc ids for each block. 
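+        # (Desc ids are region-major: each region (one layer, or one K/V half
+        #  of a layer with FlashInfer) owns a contiguous id range, so
+        #  desc_id = region_id * num_blocks_in_region + block_id.)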
- region_ids = region_ids[:, None] block_ids = np.array(block_ids)[None, :] - descs_ids = region_ids * int(num_blocks * block_size_ratio) + block_ids + region_nblocks = region_ids * num_blocks + region_nblocks = region_nblocks[:, None] + descs_ids = region_nblocks + block_ids return descs_ids.flatten() def get_backend_aware_kv_block_len(self, layer_idx: int): diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index c736e084a38d..ed01cbaf873a 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -532,6 +532,7 @@ def insert_blocks_to_device( """Copy blocks from src_cache to dst_cache on GPU.""" _src_cache = src_cache[:, src_block_indices] dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device) + torch.cuda.synchronize() @classmethod def swap_out_blocks_to_host( @@ -544,6 +545,7 @@ def swap_out_blocks_to_host( """Copy blocks from GPU to host (CPU).""" _src_cache = src_cache[:, src_block_indices] dst_cache[:, dst_block_indices] = _src_cache.cpu() + torch.cuda.synchronize() @classmethod def support_hybrid_kv_cache(cls) -> bool: From 414215ae2b9e66d98aca0ba00e2cee1f12babd58 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Wed, 22 Oct 2025 19:44:19 -0700 Subject: [PATCH 03/27] naive post process for HND Signed-off-by: Chendi Xue --- .../kv_connector/v1/nixl_connector.py | 78 ++++++++++++++++++- 1 file changed, 77 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index de4ad6bc243b..a8624d2d4ac8 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1316,6 +1316,75 @@ def permute_device_kv(self, block_ids: list[int]): ) cache.index_copy_(0, indices, permuted_blocks) + def blocksize_post_process(self, block_ids: list[int]): + def _process_local_gt_remote(blocks_to_update): + n_kv_heads, block_size, head_size = blocks_to_update.shape[1:] + remote_block_size = int(block_size * self.block_size_ratio) + n_blocks = int(1 / self.block_size_ratio) + # actual permute is to convert + # for local blocksize > remote blocksize + # ex: local blocksize = 16 tokens, remote blocksize = 4 tokens + # local block0 = remote [block0, 1, 2, 3] + # remote is |h0-b0|h1-b0|h2-b0|h3-b0|h0-b1|h1-b1|h2-b1|h3-b1|... + # local is |h0-b0..................|h1-b0..................|... + # permute is to: + # 1. view => view remote as n_blocks * remote_shape(H,remoteN,D) + # 2. permute => (H, nblocks, remoteN, D) + # 3. flatten => (H, nblocks, remoteN) + permuted_blocks = ( + blocks_to_update.reshape( + -1, n_blocks, n_kv_heads, remote_block_size, head_size + ) + .permute(0, 2, 1, 3, 4) + .flatten(2, 3) + ) + return permuted_blocks + + def _process_local_lt_remote(blocks_to_update): + n_kv_heads, block_size, head_size = blocks_to_update.shape[1:] + remote_block_size = int(block_size * self.block_size_ratio) + n_blocks = int(1 / self.block_size_ratio) + # actual permute is to convert + # for local blocksize > remote blocksize + # ex: local blocksize = 16 tokens, remote blocksize = 4 tokens + # local block0 = remote [block0, 1, 2, 3] + # remote is |h0-b0|h1-b0|h2-b0|h3-b0|h0-b1|h1-b1|h2-b1|h3-b1|... + # local is |h0-b0..................|h1-b0..................|... + # permute is to: + # 1. view => view remote as n_blocks * remote_shape(H,remoteN,D) + # 2. permute => (H, nblocks, remoteN, D) + # 3. 
flatten => (H, nblocks, remoteN) + permuted_blocks = ( + blocks_to_update.reshape( + -1, n_blocks, n_kv_heads, remote_block_size, head_size + ) + .permute(0, 2, 1, 3, 4) + .flatten(2, 3) + ) + return permuted_blocks + + split_k_and_v = not (self.use_mla or self._use_pallas or self._use_flashinfer) + sample_cache = list(self.device_kv_caches.values())[0][0] + indices = torch.tensor(block_ids, device=sample_cache.device) + fn = ( + _process_local_gt_remote + if self.block_size_ratio < 1 + else _process_local_lt_remote + ) + + for _, cache_or_caches in self.device_kv_caches.items(): + cache_list = cache_or_caches if split_k_and_v else [cache_or_caches] + for cache in cache_list: + blocks_to_update = cache.index_select(0, indices) + # because kv_cache is always using original layout NHD as virtual shape + # while stride can be either HND / NHD at initialization. + # we need to firstly get physical view of the tensor + cache.index_copy_( + 0, + indices, + fn(blocks_to_update.permute(0, 2, 1, 3)).permute(0, 2, 1, 3), + ) + def get_finished(self) -> tuple[set[str], set[str]]: """ Get requests that are done sending or recving on this specific worker. @@ -1339,11 +1408,18 @@ def get_finished(self) -> tuple[set[str], set[str]]: ) # clean up metadata for completed requests + block_ids_for_blocksize_post_process = [] for req_id in done_recving: + meta = self._recving_metadata.pop(req_id) if self.use_host_buffer: - meta = self._recving_metadata.pop(req_id) self.sync_recved_kv_to_device(req_id, meta) + # post processing for heteroblocksize + if self.block_size_ratio < 1 and self.kv_cache_layout == "HND": + block_ids_for_blocksize_post_process += meta.local_block_ids + if len(block_ids_for_blocksize_post_process) > 0: + self.blocksize_post_process(block_ids_for_blocksize_post_process) + # Handle timeout to avoid stranding blocks on remote. now = time.perf_counter() while self._reqs_to_send: From e6e3d928db5afdc44533e29edd83a4f8a14797ba Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Thu, 23 Oct 2025 15:07:54 -0700 Subject: [PATCH 04/27] Both block_size_ratio < 1 or > 1 works Signed-off-by: Chendi Xue --- .../kv_connector/v1/nixl_connector.py | 251 ++++++++++-------- vllm/platforms/cuda.py | 2 - 2 files changed, 139 insertions(+), 114 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index a8624d2d4ac8..2fde517e6d9c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -626,6 +626,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # Map of engine_id -> num_blocks. All ranks in the same deployment will # have the same number of blocks. self.dst_num_blocks: dict[EngineId, int] = {} + self.dst_block_size_ratio: dict[EngineId, float] = {} self._registered_descs: list[Any] = [] # In progress transfers. @@ -1111,8 +1112,7 @@ def add_remote_agent( # remote_block_lens * remote_block_size = local_block_lens * local_block_size remote_block_size = nixl_agent_meta.block_size block_size_ratio = remote_block_size / self.block_size - block_size_ratio_inv = self.block_size / remote_block_size - self.block_size_ratio = block_size_ratio + self.dst_block_size_ratio[engine_id] = block_size_ratio if nixl_agent_meta.kv_cache_layout != self.kv_cache_layout: if ( self.vllm_config.kv_transfer_config is not None @@ -1158,11 +1158,36 @@ def add_remote_agent( "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype." 
) + # All attn in vLLM uses blocks starts with 1st(0 is for empty) + # For hetero block size case, block 0 should always remote block_len + # Example: + # block_size_ratio < 1: + # remote: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12| + # local: | 0| 1| 8| 12| + # block_size_ratio > 1: + # remote: | 0| 1| 8| 12| + # local: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12| + shift_block_0 = remote_block_len if block_size_ratio != 1 else 0 + if block_size_ratio == 1: + num_blocks = nixl_agent_meta.num_blocks + elif block_size_ratio < 1: + num_blocks = self.num_blocks - 1 + else: + num_blocks = int((nixl_agent_meta.num_blocks - 1) * block_size_ratio) + + # For remaining blocks from remote, we might need to drop last few + # Example: + # block_size_ratio < 1: Drop last 3 remote blocks + # remote: |89|90|91|92|93|94|95|96|97|98|99| + # local: | 22| 23| + # block_size_ratio > 1: do not drop + # remote: | 22| 23| + # local: |89|90|91|92|93|94|95|96|97|98|99| # Create dst descs and xfer side handles. TP workers have same #blocks. if engine_id in self.dst_num_blocks: - assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks + assert self.dst_num_blocks[engine_id] == num_blocks else: - self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks + self.dst_num_blocks[engine_id] = num_blocks blocks_data = [] # With homogeneous TP, D pulls the whole kv cache from corresponding @@ -1172,54 +1197,31 @@ def add_remote_agent( self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer) - # NOTE(Chendi): In case remote and local use different block_size. - local_num_blocks = int(nixl_agent_meta.num_blocks * block_size_ratio) - self.remote_remain_blocks = [] # Register all remote blocks, but only the corresponding kv heads. for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) + local_block_len_TP = int(self.block_len_per_layer[i] * tp_ratio) rank_offset = ( self.tp_rank % tp_ratio * kv_block_len if not (self.use_mla or is_kv_replicated) else 0 ) - if block_size_ratio < 1: - # we skip block 0 in each layer, block 0 at remote is using smaller size - # in order to correct mapping, we can skip block0 from remote - # example for block_size_ratio = 0.25 - # remote block 0 -> skip - # remote block 1, 2, 3, 4 => local blocl 0 - # remote block 5, 6, 7, 8 => local blocl 1 - # Also, we keep track of last few remote blocks which can't match to - # last local block size and use them in read_block for mapping - # remote/local block_id. - shift_for_remote_block_0 = remote_block_len - remain_nblocks_remote = nixl_agent_meta.num_blocks - int( - local_num_blocks * block_size_ratio_inv - ) - self.remote_remain_blocks.append(remain_nblocks_remote) - local_block_len = int(self.block_len_per_layer[i] * tp_ratio) - for block_id in range(local_num_blocks): - block_offset = block_id * local_block_len + for block_id in range(num_blocks): + block_offset = block_id * local_block_len_TP # For each block, grab the heads chunk belonging to rank_i # of size remote_nheads // tp_ratio, which correspond to # self.block_len == remote_block_len//tp_ratio bytes. - addr = base_addr + block_offset + rank_offset + shift_for_remote_block_0 + addr = base_addr + block_offset + rank_offset + shift_block_0 # (addr, len, device id) blocks_data.append((addr, kv_block_len, remote_tp_rank)) if self._use_flashinfer: # With FlashInfer index V separately to allow head splitting. 
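+                # (K and V of a layer share one region laid out K-first, so the
+                #  V descriptors below sit half a block length past the K base.)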
- for block_id in range(local_num_blocks): - block_offset = block_id * local_block_len - addr = ( - base_addr - + block_offset - + rank_offset - + shift_for_remote_block_0 - ) - v_addr = addr + local_block_len // 2 + for block_id in range(num_blocks): + block_offset = block_id * local_block_len_TP + addr = base_addr + block_offset + rank_offset + shift_block_0 + v_addr = addr + local_block_len_TP // 2 blocks_data.append((v_addr, kv_block_len, remote_tp_rank)) logger.debug( @@ -1316,21 +1318,21 @@ def permute_device_kv(self, block_ids: list[int]): ) cache.index_copy_(0, indices, permuted_blocks) - def blocksize_post_process(self, block_ids: list[int]): - def _process_local_gt_remote(blocks_to_update): + def blocksize_post_process(self, block_ids_per_ratio: dict[float, list[list[int]]]): + def _process_local_gt_remote(blocks_to_update, block_size_ratio): n_kv_heads, block_size, head_size = blocks_to_update.shape[1:] - remote_block_size = int(block_size * self.block_size_ratio) - n_blocks = int(1 / self.block_size_ratio) + remote_block_size = int(block_size * block_size_ratio) + n_blocks = int(1 / block_size_ratio) # actual permute is to convert # for local blocksize > remote blocksize # ex: local blocksize = 16 tokens, remote blocksize = 4 tokens - # local block0 = remote [block0, 1, 2, 3] + # local block[0] = remote block[0, 1, 2, 3] # remote is |h0-b0|h1-b0|h2-b0|h3-b0|h0-b1|h1-b1|h2-b1|h3-b1|... # local is |h0-b0..................|h1-b0..................|... # permute is to: # 1. view => view remote as n_blocks * remote_shape(H,remoteN,D) # 2. permute => (H, nblocks, remoteN, D) - # 3. flatten => (H, nblocks, remoteN) + # 3. flatten => (H, localN, D) permuted_blocks = ( blocks_to_update.reshape( -1, n_blocks, n_kv_heads, remote_block_size, head_size @@ -1340,50 +1342,55 @@ def _process_local_gt_remote(blocks_to_update): ) return permuted_blocks - def _process_local_lt_remote(blocks_to_update): + def _process_local_lt_remote(blocks_to_update, block_size_ratio): n_kv_heads, block_size, head_size = blocks_to_update.shape[1:] - remote_block_size = int(block_size * self.block_size_ratio) - n_blocks = int(1 / self.block_size_ratio) + n_blocks = int(block_size_ratio) # actual permute is to convert - # for local blocksize > remote blocksize - # ex: local blocksize = 16 tokens, remote blocksize = 4 tokens - # local block0 = remote [block0, 1, 2, 3] - # remote is |h0-b0|h1-b0|h2-b0|h3-b0|h0-b1|h1-b1|h2-b1|h3-b1|... - # local is |h0-b0..................|h1-b0..................|... + # for local blocksize < remote blocksize + # ex: local blocksize = 4 tokens, remote blocksize = 16 tokens + # local block[0, 1, 2, 3] = remote block[0] + # remote is |h0-b0..................|h1-b0..................|... + # local is |h0-b0|h1-b0|h2-b0|h3-b0|h0-b1|h1-b1|h2-b1|h3-b1|... # permute is to: - # 1. view => view remote as n_blocks * remote_shape(H,remoteN,D) - # 2. permute => (H, nblocks, remoteN, D) - # 3. flatten => (H, nblocks, remoteN) + # 1. view => view remote as (-1, H, n_blocks, localN, D) + # 2. permute => (-1, nblocks, H, localN, D) + # 3. 
flatten => (-1, H, localN, D) + print(f"{blocks_to_update.shape=} {n_blocks=}") permuted_blocks = ( blocks_to_update.reshape( - -1, n_blocks, n_kv_heads, remote_block_size, head_size + -1, n_kv_heads, n_blocks, block_size, head_size ) .permute(0, 2, 1, 3, 4) - .flatten(2, 3) + .flatten(0, 1) ) return permuted_blocks split_k_and_v = not (self.use_mla or self._use_pallas or self._use_flashinfer) sample_cache = list(self.device_kv_caches.values())[0][0] - indices = torch.tensor(block_ids, device=sample_cache.device) - fn = ( - _process_local_gt_remote - if self.block_size_ratio < 1 - else _process_local_lt_remote - ) - - for _, cache_or_caches in self.device_kv_caches.items(): - cache_list = cache_or_caches if split_k_and_v else [cache_or_caches] - for cache in cache_list: - blocks_to_update = cache.index_select(0, indices) - # because kv_cache is always using original layout NHD as virtual shape - # while stride can be either HND / NHD at initialization. - # we need to firstly get physical view of the tensor - cache.index_copy_( - 0, - indices, - fn(blocks_to_update.permute(0, 2, 1, 3)).permute(0, 2, 1, 3), - ) + for block_size_ratio, block_ids_list in block_ids_per_ratio.items(): + if block_size_ratio < 1: + fn = _process_local_gt_remote + else: + fn = _process_local_lt_remote + block_ids_list = [[item for sublist in block_ids_list for item in sublist]] + if len(block_ids_list[0]) == 0: + continue + print(f"{block_ids_list=}") + for block_ids in block_ids_list: + indices = torch.tensor(block_ids, device=sample_cache.device) + + for _, cache_or_caches in self.device_kv_caches.items(): + cache_list = cache_or_caches if split_k_and_v else [cache_or_caches] + for cache in cache_list: + blocks_to_update = cache.index_select(0, indices) + # because kv_cache is always using original layout NHD as + # virtual shape while stride can be either HND / NHD at + # initialization. + # we need to firstly get physical view of the tensor + permuted_blocks = fn( + blocks_to_update.permute(0, 2, 1, 3), block_size_ratio + ).permute(0, 2, 1, 3) + cache.index_copy_(0, indices, permuted_blocks) def get_finished(self) -> tuple[set[str], set[str]]: """ @@ -1408,17 +1415,21 @@ def get_finished(self) -> tuple[set[str], set[str]]: ) # clean up metadata for completed requests - block_ids_for_blocksize_post_process = [] + block_ids_for_blocksize_post_process: dict[float, list[list[int]]] = {} for req_id in done_recving: meta = self._recving_metadata.pop(req_id) if self.use_host_buffer: self.sync_recved_kv_to_device(req_id, meta) # post processing for heteroblocksize - if self.block_size_ratio < 1 and self.kv_cache_layout == "HND": - block_ids_for_blocksize_post_process += meta.local_block_ids - if len(block_ids_for_blocksize_post_process) > 0: - self.blocksize_post_process(block_ids_for_blocksize_post_process) + block_size_ratio = self.dst_block_size_ratio[meta.remote_engine_id] + if block_size_ratio not in block_ids_for_blocksize_post_process: + block_ids_for_blocksize_post_process[block_size_ratio] = [] + if block_size_ratio != 1 and self.kv_cache_layout == "HND": + block_ids_for_blocksize_post_process[block_size_ratio].append( + meta.local_block_ids + ) + self.blocksize_post_process(block_ids_for_blocksize_post_process) # Handle timeout to avoid stranding blocks on remote. now = time.perf_counter() @@ -1549,14 +1560,12 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): # Handshake already completed, start async read xfer. 
# FIXME(Chendi): should store per engine - self._read_blocks_for_req(req_id, meta, self.block_size_ratio) + self._read_blocks_for_req(req_id, meta) # Start transfers for requests whose handshakes have now finished. while not self._ready_requests.empty(): # FIXME(Chendi): should store per engine - self._read_blocks_for_req( - *self._ready_requests.get_nowait(), self.block_size_ratio - ) + self._read_blocks_for_req(*self._ready_requests.get_nowait()) # Keep around the requests that have been part of a batch. This is # needed because async scheduling pushes the misalignment between the @@ -1578,7 +1587,7 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): if req_id in self._reqs_to_process: self._reqs_to_send[req_id] = expiration_time - def _read_blocks_for_req(self, req_id: str, meta: ReqMeta, block_size_ratio: float): + def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): logger.debug( "Remote agent %s available, calling _read_blocks for req %s", meta.remote_engine_id, @@ -1589,7 +1598,6 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta, block_size_ratio: flo dst_engine_id=meta.remote_engine_id, local_block_ids=meta.local_block_ids, remote_block_ids=meta.remote_block_ids, - block_size_ratio=block_size_ratio, ) def _read_blocks( @@ -1598,9 +1606,21 @@ def _read_blocks( remote_block_ids: list[int], dst_engine_id: str, request_id: str, - block_size_ratio: float, ): - print(f"before {local_block_ids=} {remote_block_ids=}") + block_size_ratio = self.dst_block_size_ratio[dst_engine_id] + if block_size_ratio != 1: + remote_block_ids = self.get_mapped_blocks( + remote_block_ids, block_size_ratio + ) + # NOTE(Chendi): over assigned local block_id here as temp buffer to receive + # remote data, the padding block_ids will be all zero-after we did permute + # so it will not impact exsiting kv block assign behaviour + if len(local_block_ids) < len(remote_block_ids): + padding_needed = len(remote_block_ids) - len(local_block_ids) + pad_start = local_block_ids[-1] + 1 + local_block_ids += [ + i for i in range(pad_start, padding_needed + pad_start) + ] # NOTE(rob): having the staging blocks be on the READER side is # not going to work well (since we will have to call rearrange tensors). @@ -1635,6 +1655,12 @@ def _read_blocks( self.xfer_stats.record_failed_notification() return + # Partial prefix cache hit: just read uncomputed blocks. + num_remote_blocks = len(remote_block_ids) + assert num_local_blocks <= num_remote_blocks + if num_local_blocks < num_remote_blocks: + remote_block_ids = remote_block_ids[-num_local_blocks:] + # Get side handles. local_xfer_side_handle = self.src_xfer_side_handle remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id] @@ -1646,15 +1672,16 @@ def _read_blocks( # Get descs ids. local_block_descs_ids: np.ndarray remote_block_descs_ids: np.ndarray + if not self.block_window_per_layer: # Default case: assume global attention remote_block_descs_ids = self._get_block_descs_ids( dst_engine_id, remote_block_ids, - remain_nblocks=self.remote_remain_blocks, ) local_block_descs_ids = self._get_block_descs_ids( - self.engine_id, local_block_ids + self.engine_id, + local_block_ids, ) else: # TODO(mgoin): remove this once we have hybrid memory allocator @@ -1675,10 +1702,14 @@ def _read_blocks( # Get descs ids for the layer. 
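+                # (Layers with a sliding attention window only read their most
+                #  recent blocks, so desc ids are computed per layer here
+                #  instead of once globally.)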
layer_local_desc_ids = self._get_block_descs_ids( - dst_engine_id, layer_local_block_ids, layer_idx + dst_engine_id, + layer_local_block_ids, + layer_idx, ) layer_remote_desc_ids = self._get_block_descs_ids( - self.engine_id, layer_remote_block_ids, layer_idx + self.engine_id, + layer_remote_block_ids, + layer_idx, ) local_descs_list.append(layer_local_desc_ids) @@ -1687,12 +1718,7 @@ def _read_blocks( local_block_descs_ids = np.concatenate(local_descs_list) remote_block_descs_ids = np.concatenate(remote_descs_list) - # NOTE(Chendi): Update remote remote_block_descs_ids using local block_size - # print(f"{remote_block_descs_ids=}{local_block_descs_ids=}") - remote_block_descs_ids = self.get_mapped_blocks( - remote_block_descs_ids, block_size_ratio - ) - # print(f"after {remote_block_descs_ids=}") + # print(f"{local_block_descs_ids[:256]=}{remote_block_descs_ids[:256]=}") assert len(local_block_descs_ids) == len(remote_block_descs_ids) # Prepare transfer with Nixl. @@ -1731,29 +1757,29 @@ def get_mapped_blocks(self, block_ids, block_size_ratio): Calculates the new set of block IDs by mapping every element in the (potentially sparse) input array. """ + block_ids = np.array(block_ids) if block_ids.size == 0: return np.array([], dtype=np.int64) - # 1. Scale all block IDs and truncate (floor) - # (This is the logic from your original code block) - mapped_ids = (block_ids * block_size_ratio).astype(np.int64) + block_ids -= 1 + + if block_size_ratio < 1: + mapped_ids = (block_ids * block_size_ratio).astype(np.int64) - # 2. Get the unique resulting IDs - unique_blocks = np.unique(mapped_ids) + return np.unique(mapped_ids) - # # 3. Filter out 0, as block IDs are 1-based. - # # (This logic matches your Example 1) - # if unique_blocks.size > 0 and unique_blocks[0] == 0: - # unique_blocks = unique_blocks[unique_blocks > 0] + elif block_size_ratio > 1: + start_ids = block_ids * block_size_ratio + offsets = np.arange(block_size_ratio) + mapped_2d = start_ids[:, None] + offsets[None, :] - return unique_blocks + return mapped_2d.flatten().astype(np.int64) def _get_block_descs_ids( self, engine_id: str, block_ids: list[int], layer_idx: int | None = None, - remain_nblocks: list[int] | None = None, ) -> np.ndarray: """ Get the descs ids for a set of block ids. @@ -1777,8 +1803,6 @@ def _get_block_descs_ids( num_blocks = self.dst_num_blocks[engine_id] num_blocks = np.full((self.num_regions), num_blocks) - if remain_nblocks is not None: - num_blocks = num_blocks - np.array(remain_nblocks) # Compute the desc ids for each block. block_ids = np.array(block_ids)[None, :] @@ -1787,7 +1811,9 @@ def _get_block_descs_ids( descs_ids = region_nblocks + block_ids return descs_ids.flatten() - def get_backend_aware_kv_block_len(self, layer_idx: int): + def get_backend_aware_kv_block_len( + self, layer_idx: int, block_len_per_layer: int | None = None + ): """ Get the block length for one K/V element (K and V have the same size). @@ -1796,11 +1822,12 @@ def get_backend_aware_kv_block_len(self, layer_idx: int): For FlashInfer, this is half the length of the whole block, as K and V share the same region. """ + block_len_per_layer = block_len_per_layer or self.block_len_per_layer[layer_idx] if self._use_flashinfer: # For indexing only half (either just the K or V part). 
- block_len = self.block_len_per_layer[layer_idx] // 2 + block_len = block_len_per_layer // 2 else: - block_len = self.block_len_per_layer[layer_idx] + block_len = block_len_per_layer return block_len def get_kv_connector_stats(self) -> KVConnectorStats | None: diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index ed01cbaf873a..c736e084a38d 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -532,7 +532,6 @@ def insert_blocks_to_device( """Copy blocks from src_cache to dst_cache on GPU.""" _src_cache = src_cache[:, src_block_indices] dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device) - torch.cuda.synchronize() @classmethod def swap_out_blocks_to_host( @@ -545,7 +544,6 @@ def swap_out_blocks_to_host( """Copy blocks from GPU to host (CPU).""" _src_cache = src_cache[:, src_block_indices] dst_cache[:, dst_block_indices] = _src_cache.cpu() - torch.cuda.synchronize() @classmethod def support_hybrid_kv_cache(cls) -> bool: From f0d8b3a5c7fa6d84ecbaf8a513047133fb6540af Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Fri, 24 Oct 2025 12:24:33 -0700 Subject: [PATCH 05/27] enable BlockAllocator for prefill/decode block_size_ratio > 1 Signed-off-by: Chendi Xue --- .../kv_connector/v1/nixl_connector.py | 123 +++++++++++++++--- 1 file changed, 105 insertions(+), 18 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 2fde517e6d9c..a6a9684e57eb 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib import copy +import heapq import logging import math import os @@ -627,6 +628,9 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # have the same number of blocks. self.dst_num_blocks: dict[EngineId, int] = {} self.dst_block_size_ratio: dict[EngineId, float] = {} + self.block_allocator_for_hetero_blksize: ( + BlockAllocatorForHeteroBlockSize | None + ) = None self._registered_descs: list[Any] = [] # In progress transfers. @@ -926,6 +930,10 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): assert len(self.block_len_per_layer) == len(seen_base_addresses) assert self.num_blocks != 0 + self.block_allocator_for_hetero_blksize = BlockAllocatorForHeteroBlockSize( + self.num_blocks + ) + self.kv_caches_base_addr[self.engine_id] = seen_base_addresses self.num_regions = len(caches_data) self.num_layers = len(xfer_buffers.keys()) @@ -1355,7 +1363,6 @@ def _process_local_lt_remote(blocks_to_update, block_size_ratio): # 1. view => view remote as (-1, H, n_blocks, localN, D) # 2. permute => (-1, nblocks, H, localN, D) # 3. 
flatten => (-1, H, localN, D) - print(f"{blocks_to_update.shape=} {n_blocks=}") permuted_blocks = ( blocks_to_update.reshape( -1, n_kv_heads, n_blocks, block_size, head_size @@ -1370,13 +1377,21 @@ def _process_local_lt_remote(blocks_to_update, block_size_ratio): for block_size_ratio, block_ids_list in block_ids_per_ratio.items(): if block_size_ratio < 1: fn = _process_local_gt_remote + block_ids_list = [ + [item for sublist in block_ids_list for item in sublist] + ] else: fn = _process_local_lt_remote - block_ids_list = [[item for sublist in block_ids_list for item in sublist]] - if len(block_ids_list[0]) == 0: - continue - print(f"{block_ids_list=}") + assert self.block_allocator_for_hetero_blksize is not None + block_ids_list = [ + self.block_allocator_for_hetero_blksize.padding_block_ids(sublist) + for sublist in block_ids_list + ] + # block_ids_list.sort(key=lambda sublist: sublist[0]) for block_ids in block_ids_list: + if len(block_ids) == 0: + # we don't need to do permute for this req + continue indices = torch.tensor(block_ids, device=sample_cache.device) for _, cache_or_caches in self.device_kv_caches.items(): @@ -1430,6 +1445,10 @@ def get_finished(self) -> tuple[set[str], set[str]]: meta.local_block_ids ) self.blocksize_post_process(block_ids_for_blocksize_post_process) + assert self.block_allocator_for_hetero_blksize is not None + self.block_allocator_for_hetero_blksize.free_block( + block_ids_for_blocksize_post_process + ) # Handle timeout to avoid stranding blocks on remote. now = time.perf_counter() @@ -1612,16 +1631,22 @@ def _read_blocks( remote_block_ids = self.get_mapped_blocks( remote_block_ids, block_size_ratio ) - # NOTE(Chendi): over assigned local block_id here as temp buffer to receive - # remote data, the padding block_ids will be all zero-after we did permute - # so it will not impact exsiting kv block assign behaviour - if len(local_block_ids) < len(remote_block_ids): + # FIXME(Chendi): We need find free blocks to pad for local, because + # when we receive remote buffer with bigger blockSize, it might happen + # that local n_blocks scheduled less to match n*local_blksize=remote_blksize + # remote is |h0-b0......|h1-b0......|h3-b0......|h4-b0......| + # local is |h0-b0|h1-b0|h3-b0|h4-b0|no need | + # In order to get entire buffer, we need to assign free blocks to local, + # so we can receive entire buffer from remote, And actually after permute + # done, the free blocks will be all zero and not needed. + if len(local_block_ids) < len(remote_block_ids) and block_size_ratio > 1: + assert self.block_allocator_for_hetero_blksize is not None padding_needed = len(remote_block_ids) - len(local_block_ids) - pad_start = local_block_ids[-1] + 1 - local_block_ids += [ - i for i in range(pad_start, padding_needed + pad_start) - ] - + local_block_ids = ( + self.block_allocator_for_hetero_blksize.padding_block_ids( + local_block_ids, padding_needed + ) + ) # NOTE(rob): having the staging blocks be on the READER side is # not going to work well (since we will have to call rearrange tensors). # after we detect the txn is complete (which means we cannot make the @@ -1718,7 +1743,6 @@ def _read_blocks( local_block_descs_ids = np.concatenate(local_descs_list) remote_block_descs_ids = np.concatenate(remote_descs_list) - # print(f"{local_block_descs_ids[:256]=}{remote_block_descs_ids[:256]=}") assert len(local_block_descs_ids) == len(remote_block_descs_ids) # Prepare transfer with Nixl. 
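+        # (The paired desc-id lists built above are what is handed to NIXL;
+        #  with hetero block sizes the local list was padded through the block
+        #  allocator earlier, so the two lengths line up.)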
@@ -1764,9 +1788,12 @@ def get_mapped_blocks(self, block_ids, block_size_ratio): block_ids -= 1 if block_size_ratio < 1: - mapped_ids = (block_ids * block_size_ratio).astype(np.int64) - - return np.unique(mapped_ids) + mapped_ids = [ + int(i * block_size_ratio) + for i in block_ids + if float(i * block_size_ratio).is_integer() + ] + return np.array(mapped_ids) elif block_size_ratio > 1: start_ids = block_ids * block_size_ratio @@ -1875,6 +1902,66 @@ def shutdown(self): self._registered_descs.clear() +class BlockAllocatorForHeteroBlockSize: + def __init__(self, num_blocks: int): + assert num_blocks > 0, "No Available blocks" + self.available_block_ids: list[int] = [-id for id in range(num_blocks)] + + heapq.heapify(self.available_block_ids) + + self.allocated_blocks: set[int] = set() + self.padding_cache: dict[int, list[int]] = {} + + def padding_block_ids( + self, block_ids: list[int], to_pad_len: int = -1 + ) -> list[int]: + if to_pad_len == 0: + return list(block_ids) + + block_ids_tuple = tuple(block_ids) + key = hash(block_ids_tuple) + + if key in self.padding_cache: + padding_blocks = self.padding_cache[key] + + return block_ids + padding_blocks + + # Check for available blocks + if len(self.available_block_ids) < to_pad_len: + raise ValueError( + f"Not enough available blocks for hash {key}. " + f"Requested {to_pad_len} padding blocks, " + f"but only {len(self.available_block_ids)} are available." + ) + + # Allocate new blocks + padding_blocks = [] + for _ in range(to_pad_len): + negative_block_id = heapq.heappop(self.available_block_ids) + new_block = -negative_block_id + + self.allocated_blocks.add(new_block) + padding_blocks.append(new_block) + + self.padding_cache[key] = padding_blocks + + return block_ids + padding_blocks + + def free_block(self, block_ids_dict: dict[float, list[list[int]]]): + block_ids_list = [i for sublist in block_ids_dict.values() for i in sublist] + for block_ids in block_ids_list: + block_ids_tuple = tuple(block_ids) + key = hash(block_ids_tuple) + if key in self.padding_cache: + padding_blocks = self.padding_cache[key] + for block_id in padding_blocks: + if block_id not in self.allocated_blocks: + continue + + self.allocated_blocks.remove(block_id) + heapq.heappush(self.available_block_ids, -block_id) + + @contextlib.contextmanager def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: """Context manager for a ZMQ socket""" From 669ea19937fce1f9b0edaa0c8f72a49de4569b2b Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Fri, 24 Oct 2025 15:08:36 -0700 Subject: [PATCH 06/27] Tested both prefill/decode block_size ratio > 1 and < 1 accuracy Signed-off-by: Chendi Xue --- .../kv_connector/v1/nixl_connector.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index a6a9684e57eb..7ddae3385f2a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1631,6 +1631,12 @@ def _read_blocks( remote_block_ids = self.get_mapped_blocks( remote_block_ids, block_size_ratio ) + # FIXME(Chendi): This is not right, remote with small blocksize + # block id will not be contiguous, in that case, we should always + # copy with small block size, but that will lead to local nixl_register + # using remote block_len, need to double think on how to. 
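+        # (e.g. with block_size_ratio=0.25, remote ids [1, 2, 3, 4] collapse
+        #  into a single mapped local id, while a sparse remote run such as
+        #  [1, 2, 7, 8] has no contiguous local image, hence this stopgap.)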
+ if len(local_block_ids) > len(remote_block_ids): + local_block_ids = local_block_ids[: len(remote_block_ids)] # FIXME(Chendi): We need find free blocks to pad for local, because # when we receive remote buffer with bigger blockSize, it might happen # that local n_blocks scheduled less to match n*local_blksize=remote_blksize @@ -1788,12 +1794,11 @@ def get_mapped_blocks(self, block_ids, block_size_ratio): block_ids -= 1 if block_size_ratio < 1: - mapped_ids = [ - int(i * block_size_ratio) - for i in block_ids - if float(i * block_size_ratio).is_integer() - ] - return np.array(mapped_ids) + start_id = block_ids[0] + shift = start_id - int(int(start_id * block_size_ratio) / block_size_ratio) + block_ids += shift + mapped_ids = np.unique((block_ids * block_size_ratio).astype(np.int64)) + return mapped_ids elif block_size_ratio > 1: start_ids = block_ids * block_size_ratio From ce273588ad0b2bf1c94bc26f30eb59a10cdb659b Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Thu, 30 Oct 2025 15:18:59 -0700 Subject: [PATCH 07/27] Enable a second local xfer handler register Signed-off-by: Chendi Xue --- .../nixl_integration/run_accuracy_test.sh | 4 + .../kv_connector/v1/nixl_connector.py | 134 +++++++++++++----- 2 files changed, 102 insertions(+), 36 deletions(-) diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh index a9817313cf02..ebc8575e5b39 100755 --- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh @@ -49,6 +49,8 @@ NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-1} # Default to 1 PREFILLER_TP_SIZE=${PREFILLER_TP_SIZE:-1} DECODER_TP_SIZE=${DECODER_TP_SIZE:-1} GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.2} +PREFILL_BLOCK_SIZE=${PREFILL_BLOCK_SIZE:-16} +DECODE_BLOCK_SIZE=${DECODE_BLOCK_SIZE:-16} # Find the git repository root directory GIT_ROOT=$(git rev-parse --show-toplevel) @@ -136,6 +138,7 @@ run_tests_for_model() { vllm serve $model_name \ --port $PORT \ --enforce-eager \ + --block-size ${PREFILL_BLOCK_SIZE} \ --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \ --tensor-parallel-size $PREFILLER_TP_SIZE \ --kv-transfer-config '$KV_CONFIG'" @@ -177,6 +180,7 @@ run_tests_for_model() { vllm serve $model_name \ --port $PORT \ --enforce-eager \ + --block-size ${DECODE_BLOCK_SIZE} \ --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \ --kv-transfer-config '$KV_CONFIG'" diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 7802092ae435..bf0146613c9b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -699,6 +699,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # nixl_prepped_dlist_handle. self.src_xfer_side_handle: int = 0 + self.src_xfer_side_handles: dict[int, int] = {} # Map of engine_id -> nixl_prepped_dlist_handle (int)]. 
         self.dst_xfer_side_handles: dict[EngineId, int] = {}

@@ -863,6 +864,12 @@ def _nixl_handshake(
             remote_agent_name = self.add_remote_agent(
                 metadata, p_remote_rank, remote_tp_size
             )
+            if metadata.block_size < self.block_size:
+                # When the prefiller uses a smaller block_size, we need to init
+                # a new handler with a matching block_len.
+                self.src_xfer_side_handles[metadata.block_size] = (
+                    self.register_local_xfer_handler(metadata.block_size)
+                )
             setup_agent_time = time.perf_counter()
             logger.debug(
                 "NIXL handshake: add agent took: %s",
@@ -1071,8 +1078,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             self.num_regions *= 2

         # Register local/src descr for NIXL xfer.
+        self.seen_base_addresses = seen_base_addresses
         blocks_data = []
-        for i, base_addr in enumerate(seen_base_addresses):
+        for i, base_addr in enumerate(self.seen_base_addresses):
             kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
             # NOTE With heter-TP, more blocks are prepared than what are
             # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
             # could create fewer, but then _get_block_descs_ids needs to
             # select agent_meta.num_blocks instead of self.num_blocks for
             # local descr, and that makes handling regular flow less clean.
@@ -1108,6 +1116,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
             "NIXL_INIT_AGENT", descs
         )
+        self.src_xfer_side_handles[self.block_size] = self.src_xfer_side_handle

         # TODO(mgoin): Hybrid memory allocator is currently disabled for
         # models with local attention (Llama 4). Can remove this once enabled.
@@ -1161,6 +1170,54 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         self._nixl_handshake_listener_stop_event = stop_event
         ready_event.wait()  # Wait for listener ZMQ socket to be ready.

+    def register_local_xfer_handler(
+        self,
+        block_size: int,
+    ) -> int:
+        """
+        Only serves the case where the local (decode) block size is larger
+        than the prefill block size. In that case, we need to re-register the
+        local xfer addresses using the remote block_len.
+        Otherwise it can do one-to-one remote_block <-> local_block transfers.
+        """
+        block_size_ratio = self.block_size // block_size
+        blocks_data = []
+        for i, base_addr in enumerate(self.seen_base_addresses):
+            # The new block_len uses the prefill block_len, and num_blocks
+            # grows by the same factor.
+            kv_block_len = (
+                self.get_backend_aware_kv_block_len(layer_idx=i) // block_size_ratio
+            )
+            block_len_per_layer = self.block_len_per_layer[i] // block_size_ratio
+            num_blocks = self.num_blocks * block_size_ratio
+            for block_id in range(num_blocks):
+                block_offset = block_id * block_len_per_layer
+                addr = base_addr + block_offset
+                # (addr, len, device id)
+                blocks_data.append((addr, kv_block_len, self.device_id))
+
+            if self._use_flashinfer:
+                # Separate and interleave K/V regions to maintain the same
+                # descs ordering. This is needed for selecting contiguous heads
+                # when split across TP ranks.
+                for block_id in range(num_blocks):
+                    block_offset = block_id * block_len_per_layer
+                    addr = base_addr + block_offset
+                    # Register addresses for V cache (K registered first).
+                    v_addr = addr + kv_block_len
+                    blocks_data.append((v_addr, kv_block_len, self.device_id))
+        logger.debug(
+            "Created %s blocks for src engine %s and rank %s on device id %s",
+            len(blocks_data),
+            self.engine_id,
+            self.tp_rank,
+            self.device_id,
+        )
+
+        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
+        # NIXL_INIT_AGENT to be used for preparations of local descs.
+ return self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs) + def add_remote_agent( self, nixl_agent_meta: NixlAgentMetadata, @@ -1233,8 +1290,9 @@ def add_remote_agent( # For hetero block size case, block 0 should always remote block_len # Example: # block_size_ratio < 1: - # remote: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12| - # local: | 0| 1| 8| 12| + # remote: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12| + # local origin:| 0| 1| 8| 12| + # local mapped:| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15| # block_size_ratio > 1: # remote: | 0| 1| 8| 12| # local: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12| @@ -1242,11 +1300,9 @@ def add_remote_agent( remote_block_size = nixl_agent_meta.block_size block_size_ratio = remote_block_size / self.block_size self.dst_block_size_ratio[engine_id] = block_size_ratio - shift_block_0 = remote_block_len if block_size_ratio != 1 else 0 - if block_size_ratio == 1: + shift_block_0 = remote_block_len if block_size_ratio > 1 else 0 + if block_size_ratio <= 1: num_blocks = nixl_agent_meta.num_blocks - elif block_size_ratio < 1: - num_blocks = self.num_blocks - 1 else: num_blocks = int((nixl_agent_meta.num_blocks - 1) * block_size_ratio) @@ -1274,13 +1330,18 @@ def add_remote_agent( # Register all remote blocks, but only the corresponding kv heads. for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) - local_block_len_TP = int(self.block_len_per_layer[i] * tp_ratio) + if block_size_ratio <= 1: + # using remote kv_block_len as transfer unit + kv_block_len = int(kv_block_len * block_size_ratio) + block_len_per_layer = nixl_agent_meta.block_lens[i] + else: + block_len_per_layer = int(self.block_len_per_layer[i] * tp_ratio) rank_offset = ( self.tp_rank % tp_ratio * kv_block_len if not replicates_kv_cache else 0 ) for block_id in range(num_blocks): - block_offset = block_id * local_block_len_TP + block_offset = block_id * block_len_per_layer # For each block, grab the heads chunk belonging to rank_i # of size remote_nheads // tp_ratio, which correspond to # self.block_len == remote_block_len//tp_ratio bytes. @@ -1291,9 +1352,9 @@ def add_remote_agent( if self._use_flashinfer: # With FlashInfer index V separately to allow head splitting. for block_id in range(num_blocks): - block_offset = block_id * local_block_len_TP + block_offset = block_id * block_len_per_layer addr = base_addr + block_offset + rank_offset + shift_block_0 - v_addr = addr + local_block_len_TP // 2 + v_addr = addr + block_len_per_layer // 2 blocks_data.append((v_addr, kv_block_len, remote_tp_rank)) logger.debug( @@ -1758,16 +1819,11 @@ def _read_blocks( request_id: str, ): block_size_ratio = self.dst_block_size_ratio[dst_engine_id] - if block_size_ratio != 1: + block_size_ratio_inv = None + if block_size_ratio > 1: remote_block_ids = self.get_mapped_blocks( - remote_block_ids, block_size_ratio + np.asarray(remote_block_ids) - 1, block_size_ratio ) - # FIXME(Chendi): This is not right, remote with small blocksize - # block id will not be contiguous, in that case, we should always - # copy with small block size, but that will lead to local nixl_register - # using remote block_len, need to double think on how to. 
- if len(local_block_ids) > len(remote_block_ids): - local_block_ids = local_block_ids[: len(remote_block_ids)] # FIXME(Chendi): We need find free blocks to pad for local, because # when we receive remote buffer with bigger blockSize, it might happen # that local n_blocks scheduled less to match n*local_blksize=remote_blksize @@ -1776,7 +1832,7 @@ def _read_blocks( # In order to get entire buffer, we need to assign free blocks to local, # so we can receive entire buffer from remote, And actually after permute # done, the free blocks will be all zero and not needed. - if len(local_block_ids) < len(remote_block_ids) and block_size_ratio > 1: + if len(local_block_ids) < len(remote_block_ids): assert self.block_allocator_for_hetero_blksize is not None padding_needed = len(remote_block_ids) - len(local_block_ids) local_block_ids = ( @@ -1784,6 +1840,13 @@ def _read_blocks( local_block_ids, padding_needed ) ) + elif block_size_ratio < 1: + block_size_ratio_inv = int(1 / block_size_ratio) + local_block_ids = self.get_mapped_blocks( + np.asarray(local_block_ids), block_size_ratio_inv + ) + if len(local_block_ids) > len(remote_block_ids): + local_block_ids = local_block_ids[: len(remote_block_ids)] # NOTE(rob): having the staging blocks be on the READER side is # not going to work well (since we will have to call rearrange tensors). # after we detect the txn is complete (which means we cannot make the @@ -1826,7 +1889,12 @@ def _read_blocks( remote_block_ids = remote_block_ids[-num_local_blocks:] # Get side handles. - local_xfer_side_handle = self.src_xfer_side_handle + block_size = ( + self.block_size + if block_size_ratio >= 1 + else int(self.block_size * block_size_ratio) + ) + local_xfer_side_handle = self.src_xfer_side_handles[block_size] remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id] # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from @@ -1846,6 +1914,7 @@ def _read_blocks( local_block_descs_ids = self._get_block_descs_ids( self.engine_id, local_block_ids, + block_size_ratio_inv=block_size_ratio_inv, ) else: # TODO(mgoin): remove this once we have hybrid memory allocator @@ -1874,6 +1943,7 @@ def _read_blocks( self.engine_id, layer_remote_block_ids, layer_idx, + block_size_ratio_inv=block_size_ratio_inv, ) local_descs_list.append(layer_local_desc_ids) @@ -1920,31 +1990,21 @@ def get_mapped_blocks(self, block_ids, block_size_ratio): Calculates the new set of block IDs by mapping every element in the (potentially sparse) input array. """ - block_ids = np.array(block_ids) if block_ids.size == 0: return np.array([], dtype=np.int64) - block_ids -= 1 - - if block_size_ratio < 1: - start_id = block_ids[0] - shift = start_id - int(int(start_id * block_size_ratio) / block_size_ratio) - block_ids += shift - mapped_ids = np.unique((block_ids * block_size_ratio).astype(np.int64)) - return mapped_ids - - elif block_size_ratio > 1: - start_ids = block_ids * block_size_ratio - offsets = np.arange(block_size_ratio) - mapped_2d = start_ids[:, None] + offsets[None, :] + start_ids = block_ids * block_size_ratio + offsets = np.arange(block_size_ratio) + mapped_2d = start_ids[:, None] + offsets[None, :] - return mapped_2d.flatten().astype(np.int64) + return mapped_2d.flatten().astype(np.int64) def _get_block_descs_ids( self, engine_id: str, block_ids: list[int], layer_idx: int | None = None, + block_size_ratio_inv: int | None = None, ) -> np.ndarray: """ Get the descs ids for a set of block ids. 
@@ -1967,6 +2027,8 @@ def _get_block_descs_ids( region_ids = np.arange(layer_idx, layer_idx + 1) num_blocks = self.dst_num_blocks[engine_id] + if block_size_ratio_inv is not None: + num_blocks = num_blocks * block_size_ratio_inv num_blocks = np.full((self.num_regions), num_blocks) # Compute the desc ids for each block. From 59a32446e01437c950914a99586b28af30dcc148 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Thu, 30 Oct 2025 17:11:05 -0700 Subject: [PATCH 08/27] remove FIXME Signed-off-by: Chendi Xue --- .../distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 1f8117d3185c..be100c744465 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1787,12 +1787,10 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): continue # Handshake already completed, start async read xfer. - # FIXME(Chendi): should store per engine self._read_blocks_for_req(req_id, meta) # Start transfers for requests whose handshakes have now finished. while not self._ready_requests.empty(): - # FIXME(Chendi): should store per engine self._read_blocks_for_req(*self._ready_requests.get_nowait()) # Keep around the requests that have been part of a batch. This is @@ -1841,7 +1839,7 @@ def _read_blocks( remote_block_ids = self.get_mapped_blocks( np.asarray(remote_block_ids) - 1, block_size_ratio ) - # FIXME(Chendi): We need find free blocks to pad for local, because + # NOTE: We need find free blocks to pad for local, because # when we receive remote buffer with bigger blockSize, it might happen # that local n_blocks scheduled less to match n*local_blksize=remote_blksize # remote is |h0-b0......|h1-b0......|h3-b0......|h4-b0......| From 4caca02a1f02ff3adb0034989d02b00034c097f3 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Thu, 30 Oct 2025 17:13:37 -0700 Subject: [PATCH 09/27] remove duplicate func Signed-off-by: Chendi Xue --- .../kv_connector/v1/nixl_connector.py | 43 +------------------ 1 file changed, 1 insertion(+), 42 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index be100c744465..7d70e0cd7ec7 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1096,43 +1096,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # Register local/src descr for NIXL xfer. self.seen_base_addresses = seen_base_addresses - blocks_data = [] - for i, base_addr in enumerate(self.seen_base_addresses): - kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) - # NOTE With heter-TP, more blocks are prepared than what are - # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We - # could create fewer, but then _get_block_descs_ids needs to - # select agent_meta.num_blocks instead of self.num_blocks for - # local descr, and that makes handling regular flow less clean. 
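The `_get_block_descs_ids` change earlier in this patch boils down to a region-major flattening, with the per-region block count scaled up once the local side is registered at the remote (smaller) block granularity. A toy sketch with assumed sizes, not the connector's real geometry:

    import numpy as np

    num_blocks, ratio_inv = 2, 4          # local blocks, local/remote size ratio
    per_region = num_blocks * ratio_inv   # descriptors per region after scaling

    region_ids = np.arange(2)[:, None]    # two layers/regions
    block_ids = np.array([0, 3, 5])[None, :]
    print((region_ids * per_region + block_ids).flatten())
    # -> [ 0  3  5  8 11 13]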
- for block_id in range(self.num_blocks):
- block_offset = block_id * self.block_len_per_layer[i]
- addr = base_addr + block_offset
- # (addr, len, device id)
- blocks_data.append((addr, kv_block_len, self.device_id))
+ self.src_xfer_side_handle = self.register_local_xfer_handler(self.block_size)
 
- if self._use_flashinfer:
- # Separate and interleave K/V regions to maintain the same
- # descs ordering. This is needed for selecting contiguous heads
- # when split across TP ranks.
- for block_id in range(self.num_blocks):
- block_offset = block_id * self.block_len_per_layer[i]
- addr = base_addr + block_offset
- # Register addresses for V cache (K registered first).
- v_addr = addr + kv_block_len
- blocks_data.append((v_addr, kv_block_len, self.device_id))
- logger.debug(
- "Created %s blocks for src engine %s and rank %s on device id %s",
- len(blocks_data),
- self.engine_id,
- self.tp_rank,
- self.device_id,
- )
-
- descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
- # NIXL_INIT_AGENT to be used for preparations of local descs.
- self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
- "NIXL_INIT_AGENT", descs
- )
 self.src_xfer_side_handles[self.block_size] = self.src_xfer_side_handle
 
 # TODO(mgoin): Hybrid memory allocator is currently disabled for
@@ -1191,12 +1156,6 @@ def register_local_xfer_handler(
 self,
 block_size: int,
 ) -> int:
- """
- Only serves the use case where local is decode and the local block size
- is larger than the prefill block size. In that case, we need to
- re-register the local xfer addresses using the remote block_len.
- Otherwise it can do one-on-one remote_block <-> local_block transfers.
- """
 block_size_ratio = self.block_size // block_size
 blocks_data = []
 for i, base_addr in enumerate(self.seen_base_addresses):

From 90c26f449a607fb0caaef40440070fe3c076a327 Mon Sep 17 00:00:00 2001
From: Chendi Xue
Date: Fri, 31 Oct 2025 12:02:38 -0700
Subject: [PATCH 10/27] Limit nP > nD buffer blocks length and print warning
 when overlapping

Signed-off-by: Chendi Xue
---
 .../kv_connector/v1/nixl_connector.py         | 55 ++++++++++++++-----
 1 file changed, 42 insertions(+), 13 deletions(-)

diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 7d70e0cd7ec7..7d3ce1040eef 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -887,6 +887,18 @@ def _nixl_handshake(
 self.src_xfer_side_handles[metadata.block_size] = (
 self.register_local_xfer_handler(metadata.block_size)
 )
+ elif metadata.block_size > self.block_size:
+ # let's steal some blocks at the tail as temp blocks to cache the larger block_size
+ assigned_num_blocks = self.kv_transfer_config.get_from_extra_config(
+ "num_buffer_blocks_for_hetero_block_size", 200
+ )
+
+ self.block_allocator_for_hetero_blksize = (
+ BlockAllocatorForHeteroBlockSize(
+ total_num_blocks=self.num_blocks,
+ assigned_num_blocks=assigned_num_blocks,
+ )
+ )
 setup_agent_time = time.perf_counter()
 logger.debug(
 "NIXL handshake: add agent took: %s",
@@ -1065,10 +1077,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
 assert len(self.block_len_per_layer) == len(seen_base_addresses)
 assert self.num_blocks != 0
 
- self.block_allocator_for_hetero_blksize = BlockAllocatorForHeteroBlockSize(
- self.num_blocks
- )
-
 self.kv_caches_base_addr[self.engine_id] = seen_base_addresses
 self.num_regions = len(caches_data)
 self.num_layers = len(xfer_buffers.keys())
@@
-1620,10 +1628,10 @@ def get_finished(self) -> tuple[set[str], set[str]]:
 meta.local_block_ids
 )
 self.blocksize_post_process(block_ids_for_blocksize_post_process)
- assert self.block_allocator_for_hetero_blksize is not None
- self.block_allocator_for_hetero_blksize.free_block(
- block_ids_for_blocksize_post_process
- )
+ if self.block_allocator_for_hetero_blksize is not None:
+ self.block_allocator_for_hetero_blksize.free_block(
+ block_ids_for_blocksize_post_process
+ )
 
 if len(block_ids_to_permute) > 0:
 self.permute_device_kv(block_ids_to_permute)
@@ -1798,6 +1806,7 @@ def _read_blocks(
 remote_block_ids = self.get_mapped_blocks(
 np.asarray(remote_block_ids) - 1, block_size_ratio
 )
+
 # NOTE: We need find free blocks to pad for local, because
 # when we receive remote buffer with bigger blockSize, it might happen
 # that local n_blocks scheduled less to match n*local_blksize=remote_blksize
@@ -1806,6 +1815,18 @@ def _read_blocks(
 # In order to get entire buffer, we need to assign free blocks to local,
 # so we can receive entire buffer from remote, And actually after permute
 # done, the free blocks will be all zero and not needed.
+ if self.block_allocator_for_hetero_blksize is not None:
+ buffer_blocks = (
+ self.block_allocator_for_hetero_blksize.assigned_num_blocks
+ )
+ if max(local_block_ids) >= self.num_blocks - buffer_blocks:
+ logger.warning(
+ "assigned block_id %s is overlapping with buffer "
+ "block_ids range (%d, %d), accuracy will get impacted.",
+ str(local_block_ids),
+ (self.num_blocks - buffer_blocks),
+ self.num_blocks,
+ )
 if len(local_block_ids) < len(remote_block_ids):
 assert self.block_allocator_for_hetero_blksize is not None
 padding_needed = len(remote_block_ids) - len(local_block_ids)
 local_block_ids = (
 self.block_allocator_for_hetero_blksize.padding_block_ids(
 local_block_ids, padding_needed
 )
 )
@@ -2085,9 +2106,16 @@
 
 
 class BlockAllocatorForHeteroBlockSize:
- def __init__(self, num_blocks: int):
- assert num_blocks > 0, "No Available blocks"
- self.available_block_ids: list[int] = [-id for id in range(num_blocks)]
+ def __init__(self, total_num_blocks: int, assigned_num_blocks: int):
+ assert total_num_blocks > 0, "No Available blocks"
+ self.available_block_ids: list[int] = [-id for id in range(total_num_blocks)]
+ self.assigned_num_blocks = assigned_num_blocks
+ logger.info(
+ "BlockAllocatorForHeteroBlockSize is initialized, using %d - %d blocks "
+ "as buffer blocks for temporary remote larger block_size cache",
+ (total_num_blocks - assigned_num_blocks),
+ total_num_blocks,
+ )
 
 heapq.heapify(self.available_block_ids)
 
@@ -2109,11 +2137,12 @@ def padding_block_ids(
 return block_ids + padding_blocks
 
 # Check for available blocks
- if len(self.available_block_ids) < to_pad_len:
+ if self.assigned_num_blocks - len(self.allocated_blocks) < to_pad_len:
+ available = self.assigned_num_blocks - len(self.allocated_blocks)
 raise ValueError(
 f"Not enough available blocks for hash {key}. "
 f"Requested {to_pad_len} padding blocks, "
- f"but only {len(self.available_block_ids)} are available."
+ f"but only {available} are available."
)

 # Allocate new blocks

From 087221409cf29267faaa6998cf7f6d7bc7cefc40 Mon Sep 17 00:00:00 2001
From: Chendi Xue
Date: Fri, 31 Oct 2025 12:18:04 -0700
Subject: [PATCH 11/27] small fix on default buffer setting

Signed-off-by: Chendi Xue
---
 vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index d264cbd63a0a..80a599b47c4e 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -955,7 +955,7 @@ def _nixl_handshake(
 elif metadata.block_size > self.block_size:
 # let's steal some blocks at the tail as temp blocks to cache the larger block_size
 assigned_num_blocks = self.kv_transfer_config.get_from_extra_config(
- "num_buffer_blocks_for_hetero_block_size", 200
+ "num_buffer_blocks_for_hetero_block_size", 1000
 )

From 8ebde3f0ec39ebf98530781f76ae12bf1a1a3b48 Mon Sep 17 00:00:00 2001
From: Chendi Xue
Date: Fri, 31 Oct 2025 14:51:51 -0700
Subject: [PATCH 12/27] Fix for nP > nD + TP_ratio != 1 scenario

Signed-off-by: Chendi Xue
---
 .../kv_connector/v1/nixl_connector.py         | 69 ++++++++++++-------
 1 file changed, 44 insertions(+), 25 deletions(-)

diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 80a599b47c4e..61907d985f33 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -1339,19 +1339,22 @@ def add_remote_agent(
 block_size_ratio = remote_block_size / self.block_size
 self.dst_block_size_ratio[engine_id] = block_size_ratio
 shift_block_0 = remote_block_len if block_size_ratio > 1 else 0
- if block_size_ratio <= 1:
- num_blocks = nixl_agent_meta.num_blocks
- else:
- num_blocks = int((nixl_agent_meta.num_blocks - 1) * block_size_ratio)
+ num_blocks = nixl_agent_meta.num_blocks
+ expanded_num_blocks = math.ceil(block_size_ratio)
+ if block_size_ratio > 1:
+ num_blocks -= 1
+
+ # when block_size_ratio > 1, one prefill block is n decode_block
+ # loop n times of decode block_len to match to prefill
 if engine_id not in self.dst_num_blocks:
- self.dst_num_blocks[engine_id] = num_blocks
+ self.dst_num_blocks[engine_id] = num_blocks * expanded_num_blocks
 
 # Keep track of remote agent kv caches base addresses.
 self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr
 
 self._validate_remote_agent_handshake(
- nixl_agent_meta, remote_tp_size, num_blocks
+ nixl_agent_meta, remote_tp_size, num_blocks * expanded_num_blocks
 )
 
 # Number of D TP workers reading from a single P TP worker. This is
@@ -1368,34 +1371,48 @@ def add_remote_agent(
 # Register all remote blocks, but only the corresponding kv heads.
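The loop that follows registers every remote (larger) prefill block as `expanded_num_blocks` local-sized slices. A toy sketch of that address arithmetic, with made-up byte sizes rather than real cache geometry:

    base_addr = 0x1000
    remote_block_len = 4096      # bytes per (larger) prefill block, assumed
    local_block_len = 1024       # bytes per decode block, assumed
    expanded_num_blocks = remote_block_len // local_block_len

    addrs = [
        base_addr + block_id * remote_block_len + sub * local_block_len
        for block_id in range(2)
        for sub in range(expanded_num_blocks)
    ]
    print([hex(a) for a in addrs])
    # -> ['0x1000', '0x1400', '0x1800', '0x1c00',
    #     '0x2000', '0x2400', '0x2800', '0x2c00']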
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
 kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
+ remote_kv_block_len = int(kv_block_len * block_size_ratio)
 if block_size_ratio <= 1:
 # using remote kv_block_len as transfer unit
- kv_block_len = int(kv_block_len * block_size_ratio)
- block_len_per_layer = nixl_agent_meta.block_lens[i]
- else:
- block_len_per_layer = int(self.block_len_per_layer[i] * tp_ratio)
+ kv_block_len = remote_kv_block_len
 rank_offset = (
- self.tp_rank % tp_ratio * kv_block_len if not replicates_kv_cache else 0
+ self.tp_rank % tp_ratio * remote_kv_block_len
+ if not replicates_kv_cache
+ else 0
 )
 for block_id in range(num_blocks):
- block_offset = block_id * block_len_per_layer
- # For each block, grab the heads chunk belonging to rank_i
- # of size remote_nheads // tp_ratio, which correspond to
- # self.block_len == remote_block_len//tp_ratio bytes.
- addr = base_addr + block_offset + rank_offset + shift_block_0
- # (addr, len, device id)
- blocks_data.append((addr, kv_block_len, nixl_agent_meta.device_id))
+ block_offset = block_id * nixl_agent_meta.block_lens[i]
+ for sub_block_id in range(expanded_num_blocks):
+ expanded_block_offset = (
+ block_offset + sub_block_id * self.block_len_per_layer[i]
+ )
+ # For each block, grab the heads chunk belonging to rank_i
+ # of size remote_nheads // tp_ratio, which correspond to
+ # self.block_len == remote_block_len//tp_ratio bytes.
+ addr = (
+ base_addr + expanded_block_offset + rank_offset + shift_block_0
+ )
+ # (addr, len, device id)
+ blocks_data.append((addr, kv_block_len, nixl_agent_meta.device_id))
 
 if self._use_flashinfer:
 # With FlashInfer index V separately to allow head splitting.
 for block_id in range(num_blocks):
- block_offset = block_id * block_len_per_layer
- addr = base_addr + block_offset + rank_offset + shift_block_0
- v_addr = addr + block_len_per_layer // 2
- blocks_data.append(
- (v_addr, kv_block_len, nixl_agent_meta.device_id)
- )
+ block_offset = block_id * nixl_agent_meta.block_lens[i]
+ for sub_block_id in range(expanded_num_blocks):
+ expanded_block_offset = (
+ block_offset + sub_block_id * self.block_len_per_layer[i]
+ )
+ addr = (
+ base_addr
+ + expanded_block_offset
+ + rank_offset
+ + shift_block_0
+ )
+ v_addr = addr + nixl_agent_meta.block_lens[i] // 2
+ blocks_data.append(
+ (v_addr, kv_block_len, nixl_agent_meta.device_id)
+ )
 
 logger.debug(
 "Created %s blocks for dst engine %s with remote rank %s and local rank %s",
@@ -1610,6 +1627,8 @@ def _process_local_lt_remote(blocks_to_update, block_size_ratio):
 split_k_and_v = not (self.use_mla or self._use_pallas or self._use_flashinfer)
 sample_cache = list(self.device_kv_caches.values())[0][0]
 for block_size_ratio, block_ids_list in block_ids_per_ratio.items():
+ if len(block_ids_list) == 0:
+ continue
 if block_size_ratio < 1:
 fn = _process_local_gt_remote
 block_ids_list = [
 [item for sublist in block_ids_list for item in sublist]
 ]

From c9dfb5180abed91dc0e67f62831a86da51c0c5e4 Mon Sep 17 00:00:00 2001
From: Chendi Xue
Date: Mon, 3 Nov 2025 17:24:20 -0800
Subject: [PATCH 13/27] clean up and remove unnecessary shift

Signed-off-by: Chendi Xue
---
 .../kv_connector/v1/nixl_connector.py         | 35 ++++++++-----------
 1 file changed, 14 insertions(+), 21 deletions(-)

diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 61907d985f33..ff20a4e6bcef 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++
b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1332,29 +1332,29 @@ def add_remote_agent( # local origin:| 0| 1| 8| 12| # local mapped:| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15| # block_size_ratio > 1: - # remote: | 0| 1| 8| 12| - # local: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12| - remote_block_len = nixl_agent_meta.block_lens[0] + # remote: | 0| 1| 8| 12| + # new remote: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15| + # local: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12| remote_block_size = nixl_agent_meta.block_size block_size_ratio = remote_block_size / self.block_size self.dst_block_size_ratio[engine_id] = block_size_ratio - shift_block_0 = remote_block_len if block_size_ratio > 1 else 0 - num_blocks = nixl_agent_meta.num_blocks expanded_num_blocks = math.ceil(block_size_ratio) - if block_size_ratio > 1: - num_blocks -= 1 # when block_size_ratio > 1, one prefill block is n decode_block # loop n times of decode block_len to match to prefill if engine_id not in self.dst_num_blocks: - self.dst_num_blocks[engine_id] = num_blocks * expanded_num_blocks + self.dst_num_blocks[engine_id] = ( + nixl_agent_meta.num_blocks * expanded_num_blocks + ) # Keep track of remote agent kv caches base addresses. self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr self._validate_remote_agent_handshake( - nixl_agent_meta, remote_tp_size, num_blocks * expanded_num_blocks + nixl_agent_meta, + remote_tp_size, + nixl_agent_meta.num_blocks * expanded_num_blocks, ) # Number of D TP workers reading from a single P TP worker. This is @@ -1380,7 +1380,7 @@ def add_remote_agent( if not replicates_kv_cache else 0 ) - for block_id in range(num_blocks): + for block_id in range(nixl_agent_meta.num_blocks): block_offset = block_id * nixl_agent_meta.block_lens[i] for sub_block_id in range(expanded_num_blocks): expanded_block_offset = ( @@ -1389,26 +1389,19 @@ def add_remote_agent( # For each block, grab the heads chunk belonging to rank_i # of size remote_nheads // tp_ratio, which correspond to # self.block_len == remote_block_len//tp_ratio bytes. - addr = ( - base_addr + expanded_block_offset + rank_offset + shift_block_0 - ) + addr = base_addr + expanded_block_offset + rank_offset # (addr, len, device id) blocks_data.append((addr, kv_block_len, nixl_agent_meta.device_id)) if self._use_flashinfer: # With FlashInfer index V separately to allow head splitting. 
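As a reminder of the FlashInfer layout assumed throughout these hunks: K and V of a block live back-to-back in one region, so the V descriptor simply starts half a block_len past the K address. A sketch with toy numbers:

    base_addr, block_len, num_blocks = 0x8000, 2048, 3

    k_addrs = [base_addr + b * block_len for b in range(num_blocks)]
    v_addrs = [k + block_len // 2 for k in k_addrs]
    print([hex(k) for k in k_addrs])  # -> ['0x8000', '0x8800', '0x9000']
    print([hex(v) for v in v_addrs])  # -> ['0x8400', '0x8c00', '0x9400']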
- for block_id in range(num_blocks):
+ for block_id in range(nixl_agent_meta.num_blocks):
 block_offset = block_id * nixl_agent_meta.block_lens[i]
 for sub_block_id in range(expanded_num_blocks):
 expanded_block_offset = (
 block_offset + sub_block_id * self.block_len_per_layer[i]
 )
- addr = (
- base_addr
- + expanded_block_offset
- + rank_offset
- + shift_block_0
- )
+ addr = base_addr + expanded_block_offset + rank_offset
 v_addr = addr + nixl_agent_meta.block_lens[i] // 2
 blocks_data.append(
 (v_addr, kv_block_len, nixl_agent_meta.device_id)
 )
@@ -1879,7 +1872,7 @@ def _read_blocks(
 block_size_ratio = self.dst_block_size_ratio[dst_engine_id]
 block_size_ratio_inv = None
 if block_size_ratio > 1:
 remote_block_ids = self.get_mapped_blocks(
- np.asarray(remote_block_ids) - 1, block_size_ratio
+ np.asarray(remote_block_ids), block_size_ratio
 )

From 5eb54e181f406ec0ace405aaea67dc04817d9e47 Mon Sep 17 00:00:00 2001
From: Chendi Xue
Date: Tue, 4 Nov 2025 15:24:44 -0800
Subject: [PATCH 14/27] remove nP > nD path, will do it in separate PR

Signed-off-by: Chendi Xue
---
 .../kv_connector/v1/nixl_connector.py         | 212 ++----------------
 1 file changed, 24 insertions(+), 188 deletions(-)

diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index ff20a4e6bcef..68f499002316 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import contextlib
 import copy
-import heapq
 import logging
 import math
 import os
@@ -825,9 +824,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
 # have the same number of blocks.
 self.dst_num_blocks: dict[EngineId, int] = {}
 self.dst_block_size_ratio: dict[EngineId, float] = {}
- self.block_allocator_for_hetero_blksize: (
- BlockAllocatorForHeteroBlockSize | None
- ) = None
 self._registered_descs: list[Any] = []
 
 # In progress transfers.
@@ -946,24 +942,16 @@ def _nixl_handshake(
 remote_agent_name = self.add_remote_agent(
 metadata, p_remote_rank, remote_tp_size
 )
+ assert metadata.block_size <= self.block_size, (
+ "nP > nD is not supported yet."
+ )
 if metadata.block_size < self.block_size:
 # when prefill with small block_size, we need to init a
 # new handler with same block_len to match
 self.src_xfer_side_handles[metadata.block_size] = (
 self.register_local_xfer_handler(metadata.block_size)
 )
- elif metadata.block_size > self.block_size:
- # let's steal some blocks at the tail as temp blocks to cache the larger block_size
- assigned_num_blocks = self.kv_transfer_config.get_from_extra_config(
- "num_buffer_blocks_for_hetero_block_size", 1000
- )
-
- self.block_allocator_for_hetero_blksize = (
- BlockAllocatorForHeteroBlockSize(
- total_num_blocks=self.num_blocks,
- assigned_num_blocks=assigned_num_blocks,
- )
- )
 
 setup_agent_time = time.perf_counter()
 logger.debug(
 "NIXL handshake: add agent took: %s",
@@ -1331,22 +1319,14 @@ def add_remote_agent(
 # remote: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|
 # local origin:| 0| 1| 8| 12|
 # local mapped:| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15|
- # block_size_ratio > 1:
- # remote: | 0| 1| 8| 12|
- # new remote: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15|
- # local: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|
 remote_block_size = nixl_agent_meta.block_size
 block_size_ratio = remote_block_size / self.block_size
 self.dst_block_size_ratio[engine_id] = block_size_ratio
- expanded_num_blocks = math.ceil(block_size_ratio)
-
- # when block_size_ratio > 1, one prefill block is n decode_block
- # loop n times of decode block_len to match to prefill
 if engine_id not in self.dst_num_blocks:
- self.dst_num_blocks[engine_id] = (
- nixl_agent_meta.num_blocks * expanded_num_blocks
- )
+ self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
 
 # Keep track of remote agent kv caches base addresses.
 self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr
@@ -1354,7 +1334,6 @@ def add_remote_agent(
 self._validate_remote_agent_handshake(
 nixl_agent_meta,
 remote_tp_size,
- nixl_agent_meta.num_blocks * expanded_num_blocks,
 )
 
 # Number of D TP workers reading from a single P TP worker. This is
@@ -1379,33 +1358,27 @@ def add_remote_agent(
 rank_offset = (
 self.tp_rank % tp_ratio * remote_kv_block_len
 if not replicates_kv_cache
 else 0
+ if not replicates_kv_cache
+ else 0
 )
 for block_id in range(nixl_agent_meta.num_blocks):
 block_offset = block_id * nixl_agent_meta.block_lens[i]
- for sub_block_id in range(expanded_num_blocks):
- expanded_block_offset = (
- block_offset + sub_block_id * self.block_len_per_layer[i]
- )
- # For each block, grab the heads chunk belonging to rank_i
- # of size remote_nheads // tp_ratio, which correspond to
- # self.block_len == remote_block_len//tp_ratio bytes.
- addr = base_addr + expanded_block_offset + rank_offset
- # (addr, len, device id)
- blocks_data.append((addr, kv_block_len, nixl_agent_meta.device_id))
+ # For each block, grab the heads chunk belonging to rank_i
+ # of size remote_nheads // tp_ratio, which correspond to
+ # self.block_len == remote_block_len//tp_ratio bytes.
+ addr = base_addr + block_offset + rank_offset
+ # (addr, len, device id)
+ blocks_data.append((addr, kv_block_len, nixl_agent_meta.device_id))
 
 if self._use_flashinfer:
 # With FlashInfer index V separately to allow head splitting.
for block_id in range(nixl_agent_meta.num_blocks): block_offset = block_id * nixl_agent_meta.block_lens[i] - for sub_block_id in range(expanded_num_blocks): - expanded_block_offset = ( - block_offset + sub_block_id * self.block_len_per_layer[i] - ) - addr = base_addr + expanded_block_offset + rank_offset - v_addr = addr + nixl_agent_meta.block_lens[i] // 2 - blocks_data.append( - (v_addr, kv_block_len, nixl_agent_meta.device_id) - ) + addr = base_addr + block_offset + rank_offset + v_addr = addr + nixl_agent_meta.block_lens[i] // 2 + blocks_data.append( + (v_addr, kv_block_len, nixl_agent_meta.device_id) + ) logger.debug( "Created %s blocks for dst engine %s with remote rank %s and local rank %s", @@ -1424,7 +1397,7 @@ def add_remote_agent( return remote_agent_name def _validate_remote_agent_handshake( - self, nixl_agent_meta: NixlAgentMetadata, remote_tp_size: int, num_blocks: int + self, nixl_agent_meta: NixlAgentMetadata, remote_tp_size: int ): """ Validate the remote agent handshake metadata ensuring the @@ -1489,7 +1462,7 @@ def _validate_remote_agent_handshake( ) # TP workers have same #blocks. - assert self.dst_num_blocks[remote_engine_id] == num_blocks + assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer) @@ -1595,46 +1568,17 @@ def _process_local_gt_remote(blocks_to_update, block_size_ratio): ) return permuted_blocks - def _process_local_lt_remote(blocks_to_update, block_size_ratio): - n_kv_heads, block_size, head_size = blocks_to_update.shape[1:] - n_blocks = int(block_size_ratio) - # actual permute is to convert - # for local blocksize < remote blocksize - # ex: local blocksize = 4 tokens, remote blocksize = 16 tokens - # local block[0, 1, 2, 3] = remote block[0] - # remote is |h0-b0..................|h1-b0..................|... - # local is |h0-b0|h1-b0|h2-b0|h3-b0|h0-b1|h1-b1|h2-b1|h3-b1|... - # permute is to: - # 1. view => view remote as (-1, H, n_blocks, localN, D) - # 2. permute => (-1, nblocks, H, localN, D) - # 3. flatten => (-1, H, localN, D) - permuted_blocks = ( - blocks_to_update.reshape( - -1, n_kv_heads, n_blocks, block_size, head_size - ) - .permute(0, 2, 1, 3, 4) - .flatten(0, 1) - ) - return permuted_blocks - split_k_and_v = not (self.use_mla or self._use_pallas or self._use_flashinfer) sample_cache = list(self.device_kv_caches.values())[0][0] for block_size_ratio, block_ids_list in block_ids_per_ratio.items(): if len(block_ids_list) == 0: continue + assert block_size_ratio < 1, "nP > nD is not supported yet." 
if block_size_ratio < 1:
 fn = _process_local_gt_remote
 block_ids_list = [
 [item for sublist in block_ids_list for item in sublist]
 ]
- else:
- fn = _process_local_lt_remote
- assert self.block_allocator_for_hetero_blksize is not None
- block_ids_list = [
- self.block_allocator_for_hetero_blksize.padding_block_ids(sublist)
- for sublist in block_ids_list
- ]
- # block_ids_list.sort(key=lambda sublist: sublist[0])
 for block_ids in block_ids_list:
 if len(block_ids) == 0:
 # we don't need to do permute for this req
@@ -1696,10 +1640,6 @@ def get_finished(self) -> tuple[set[str], set[str]]:
 meta.local_block_ids
 )
 self.blocksize_post_process(block_ids_for_blocksize_post_process)
- if self.block_allocator_for_hetero_blksize is not None:
- self.block_allocator_for_hetero_blksize.free_block(
- block_ids_for_blocksize_post_process
- )
 
 if len(block_ids_to_permute) > 0:
 self.permute_device_kv(block_ids_to_permute)
@@ -1870,40 +1810,7 @@ def _read_blocks(
 ):
 block_size_ratio = self.dst_block_size_ratio[dst_engine_id]
 block_size_ratio_inv = None
- if block_size_ratio > 1:
- remote_block_ids = self.get_mapped_blocks(
- np.asarray(remote_block_ids), block_size_ratio
- )
-
- # NOTE: We need find free blocks to pad for local, because
- # when we receive remote buffer with bigger blockSize, it might happen
- # that local n_blocks scheduled less to match n*local_blksize=remote_blksize
- # remote is |h0-b0......|h1-b0......|h3-b0......|h4-b0......|
- # local is |h0-b0|h1-b0|h3-b0|h4-b0|no need |
- # In order to get entire buffer, we need to assign free blocks to local,
- # so we can receive entire buffer from remote, And actually after permute
- # done, the free blocks will be all zero and not needed.
- if self.block_allocator_for_hetero_blksize is not None:
- buffer_blocks = (
- self.block_allocator_for_hetero_blksize.assigned_num_blocks
- )
- if max(local_block_ids) >= self.num_blocks - buffer_blocks:
- logger.warning(
- "assigned block_id %s is overlapping with buffer "
- "block_ids range (%d, %d), accuracy will get impacted.",
- str(local_block_ids),
- (self.num_blocks - buffer_blocks),
- self.num_blocks,
- )
- if len(local_block_ids) < len(remote_block_ids):
- assert self.block_allocator_for_hetero_blksize is not None
- padding_needed = len(remote_block_ids) - len(local_block_ids)
- local_block_ids = (
- self.block_allocator_for_hetero_blksize.padding_block_ids(
- local_block_ids, padding_needed
- )
- )
- elif block_size_ratio < 1:
+ if block_size_ratio < 1:
 block_size_ratio_inv = int(1 / block_size_ratio)
 local_block_ids = self.get_mapped_blocks(
 np.asarray(local_block_ids), block_size_ratio_inv
@@ -2101,9 +2008,7 @@ def _get_block_descs_ids(
 descs_ids = region_nblocks + block_ids
 return descs_ids.flatten()
 
- def get_backend_aware_kv_block_len(
- self, layer_idx: int, block_len_per_layer: int | None = None
- ):
+ def get_backend_aware_kv_block_len(self, layer_idx: int):
 """
 Get the block length for one K/V element (K and V have the same
 size).
@@ -2112,12 +2017,11 @@ def get_backend_aware_kv_block_len(self, layer_idx: int):
 For FlashInfer, this is half the length of the whole block,
 as K and V share the same region.
 """
- block_len_per_layer = block_len_per_layer or self.block_len_per_layer[layer_idx]
 if self._use_flashinfer:
 # For indexing only half (either just the K or V part).
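Stepping outside the diff for a moment: the retained `_process_local_gt_remote` post-processing (and the `_process_local_lt_remote` helper removed above) is, at its core, the reshape-permute regrouping below. Toy dimensions only; the real helpers also handle the batch dim and HND/NHD strides:

    import torch

    # H heads, n remote sub-blocks of r tokens each, head_dim D.
    H, n, r, D = 2, 2, 4, 3
    # One received local block: each head holds its n sub-blocks end-to-end.
    received = torch.arange(H * n * r * D).reshape(H, n * r, D)

    # Regroup into one (H, r, D) tensor per remote-sized sub-block.
    subs = received.reshape(H, n, r, D).permute(1, 0, 2, 3)
    assert torch.equal(subs[1, 0], received[0, r:])  # sub-block 1, head 0
    print(subs.shape)  # -> torch.Size([2, 2, 4, 3])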
- block_len = block_len_per_layer // 2
+ block_len = self.block_len_per_layer[layer_idx] // 2
 else:
- block_len = block_len_per_layer
+ block_len = self.block_len_per_layer[layer_idx]
 return block_len
 
 def get_kv_connector_stats(self) -> KVConnectorStats | None:
@@ -2165,74 +2069,6 @@ def shutdown(self):
 self._registered_descs.clear()
 
 
-class BlockAllocatorForHeteroBlockSize:
- def __init__(self, total_num_blocks: int, assigned_num_blocks: int):
- assert total_num_blocks > 0, "No Available blocks"
- self.available_block_ids: list[int] = [-id for id in range(total_num_blocks)]
- self.assigned_num_blocks = assigned_num_blocks
- logger.info(
- "BlockAllocatorForHeteroBlockSize is initialized, using %d - %d blocks "
- "as buffer blocks for temporary remote larger block_size cache",
- (total_num_blocks - assigned_num_blocks),
- total_num_blocks,
- )
-
- heapq.heapify(self.available_block_ids)
-
- self.allocated_blocks: set[int] = set()
- self.padding_cache: dict[int, list[int]] = {}
-
- def padding_block_ids(
- self, block_ids: list[int], to_pad_len: int = -1
- ) -> list[int]:
- if to_pad_len == 0:
- return list(block_ids)
-
- block_ids_tuple = tuple(block_ids)
- key = hash(block_ids_tuple)
-
- if key in self.padding_cache:
- padding_blocks = self.padding_cache[key]
-
- return block_ids + padding_blocks
-
- # Check for available blocks
- if self.assigned_num_blocks - len(self.allocated_blocks) < to_pad_len:
- available = self.assigned_num_blocks - len(self.allocated_blocks)
- raise ValueError(
- f"Not enough available blocks for hash {key}. "
- f"Requested {to_pad_len} padding blocks, "
- f"but only {available} are available."
- )
-
- # Allocate new blocks
- padding_blocks = []
- for _ in range(to_pad_len):
- negative_block_id = heapq.heappop(self.available_block_ids)
- new_block = -negative_block_id
-
- self.allocated_blocks.add(new_block)
- padding_blocks.append(new_block)
-
- self.padding_cache[key] = padding_blocks
-
- return block_ids + padding_blocks
-
- def free_block(self, block_ids_dict: dict[float, list[list[int]]]):
- block_ids_list = [i for sublist in block_ids_dict.values() for i in sublist]
- for block_ids in block_ids_list:
- block_ids_tuple = tuple(block_ids)
- key = hash(block_ids_tuple)
- if key in self.padding_cache:
- padding_blocks = self.padding_cache[key]
- for block_id in padding_blocks:
- if block_id not in self.allocated_blocks:
- continue
-
- self.allocated_blocks.remove(block_id)
- heapq.heappush(self.available_block_ids, -block_id)
-
-
 @contextlib.contextmanager
 def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
 """Context manager for a ZMQ socket"""

From 7fa82cb4ba4a58a6108f3edfa7641a4e0522323e Mon Sep 17 00:00:00 2001
From: Chendi Xue
Date: Tue, 4 Nov 2025 20:08:58 -0800
Subject: [PATCH 15/27] move block_size_ratio to kv_topo

Signed-off-by: Chendi Xue
---
 .../kv_connector/v1/nixl_connector.py         | 52 +++++++++++++++----
 1 file changed, 41 insertions(+), 11 deletions(-)

diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 68f499002316..f3c0e211b64b 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -662,6 +662,8 @@ class TpKVTopology:
 remote_tp_size: dict[EngineId, int]
 is_mla: bool
 total_num_kv_heads: int
+ block_size: int
+ remote_block_size: dict[EngineId, int]
 
 def tp_ratio(
 self,
@@ -679,6 +681,22 @@ def tp_ratio(
 )
 return self.tp_size // remote_tp_size
 
+ def
block_size_ratio(
+ self,
+ remote_block_size: int,
+ ) -> float:
+ """
+ Calculate the ratio between the local and remote block sizes.
+ """
+ assert (
+ self.block_size % remote_block_size == 0
+ or remote_block_size % self.block_size == 0
+ ), (
+ f"Local block size {self.block_size} is not divisible "
+ f"by remote block size {remote_block_size} or vice versa."
+ )
+ return remote_block_size / self.block_size
+
 def tp_ratio_from_engine_id(
 self,
 remote_engine_id: EngineId,
@@ -686,6 +704,13 @@ def tp_ratio_from_engine_id(
 remote_tp_size = self.remote_tp_size[remote_engine_id]
 return self.tp_ratio(remote_tp_size)
 
+ def block_size_ratio_from_engine_id(
+ self,
+ remote_engine_id: EngineId,
+ ) -> float:
+ remote_block_size = self.remote_block_size[remote_engine_id]
+ return self.block_size_ratio(remote_block_size)
+
 def is_kv_replicated(self, engine_id: EngineId) -> bool:
 """
 Whether the KV cache is replicated across TP workers due to the
@@ -823,7 +848,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
 # Map of engine_id -> num_blocks. All ranks in the same deployment will
 # have the same number of blocks.
 self.dst_num_blocks: dict[EngineId, int] = {}
- self.dst_block_size_ratio: dict[EngineId, float] = {}
 self._registered_descs: list[Any] = []
 
 # In progress transfers.
@@ -880,6 +904,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
 logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
 
 self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
+ self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size}
 # With heterogeneous TP, P must wait for all assigned D TP workers to
 # finish reading before safely freeing the blocks.
 self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
@@ -891,6 +916,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
 remote_tp_size=self._tp_size,  # shared state
 is_mla=self.use_mla,
 total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
+ block_size=self.block_size,
+ remote_block_size=self._block_size,
 )
 
 def _nixl_handshake(
@@ -939,12 +966,12 @@ def _nixl_handshake(
 )
 
 # Register Remote agent.
- remote_agent_name = self.add_remote_agent(
- metadata, p_remote_rank, remote_tp_size
- )
 assert metadata.block_size <= self.block_size, (
 "nP > nD is not supported yet."
) + remote_agent_name = self.add_remote_agent( + metadata, p_remote_rank, remote_tp_size + ) if metadata.block_size < self.block_size: # when prefill with small block_size, we need to init a # new handler with same block_len to match @@ -1302,6 +1329,8 @@ def add_remote_agent( ### Register remote agent metadata if engine_id not in self._tp_size: self._tp_size[engine_id] = remote_tp_size + if engine_id not in self._block_size: + self._block_size[engine_id] = nixl_agent_meta.block_size remote_agent_name = self.nixl_wrapper.add_remote_agent( nixl_agent_meta.agent_metadata @@ -1319,9 +1348,7 @@ def add_remote_agent( # remote: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12| # local origin:| 0| 1| 8| 12| # local mapped:| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15| - remote_block_size = nixl_agent_meta.block_size - block_size_ratio = remote_block_size / self.block_size - self.dst_block_size_ratio[engine_id] = block_size_ratio + block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(engine_id) # when block_size_ratio > 1, one prefill block is n decode_block # loop n times of decode block_len to match to prefill @@ -1410,6 +1437,9 @@ def _validate_remote_agent_handshake( assert nixl_agent_meta.attn_backend_name == self.backend_name tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id) + block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id( + remote_engine_id + ) assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP" assert not self._use_pallas or tp_ratio == 1, ( "TPU (pallas_v1) DOES NOT support heterogeneous TP yet." @@ -1438,7 +1468,6 @@ def _validate_remote_agent_handshake( # Block len can only vary across layers when using MLA. remote_block_len = nixl_agent_meta.block_lens[0] - block_size_ratio = nixl_agent_meta.block_size / self.block_size if self.use_mla or self.kv_topo.is_kv_replicated(remote_engine_id): # With replicated KV cache, only the number of blocks can differ. for i in range(len(self.block_len_per_layer)): @@ -1573,7 +1602,6 @@ def _process_local_gt_remote(blocks_to_update, block_size_ratio): for block_size_ratio, block_ids_list in block_ids_per_ratio.items(): if len(block_ids_list) == 0: continue - assert block_size_ratio < 1, "nP > nD is not supported yet." 
if block_size_ratio < 1: fn = _process_local_gt_remote block_ids_list = [ @@ -1632,7 +1660,9 @@ def get_finished(self) -> tuple[set[str], set[str]]: block_ids_to_permute += meta.local_block_ids # post processing for heteroblocksize - block_size_ratio = self.dst_block_size_ratio[meta.remote_engine_id] + block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id( + meta.remote_engine_id + ) if block_size_ratio not in block_ids_for_blocksize_post_process: block_ids_for_blocksize_post_process[block_size_ratio] = [] if block_size_ratio != 1 and self.kv_cache_layout == "HND": @@ -1808,7 +1838,7 @@ def _read_blocks( dst_engine_id: str, request_id: str, ): - block_size_ratio = self.dst_block_size_ratio[dst_engine_id] + block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id) block_size_ratio_inv = None if block_size_ratio < 1: block_size_ratio_inv = int(1 / block_size_ratio) From 402fadd04462a175271dfece11b1ad6931ff06f3 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Tue, 4 Nov 2025 20:32:20 -0800 Subject: [PATCH 16/27] use default(list) Signed-off-by: Chendi Xue --- .../kv_connector/v1/nixl_connector.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index f3c0e211b64b..d4debb3a71cd 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -972,12 +972,6 @@ def _nixl_handshake( remote_agent_name = self.add_remote_agent( metadata, p_remote_rank, remote_tp_size ) - if metadata.block_size < self.block_size: - # when prefill with small block_size, we need to init a - # new handler with same block_len to match - self.src_xfer_side_handles[metadata.block_size] = ( - self.register_local_xfer_handler(metadata.block_size) - ) setup_agent_time = time.perf_counter() logger.debug( @@ -1341,8 +1335,6 @@ def add_remote_agent( # Create dst descs and xfer side handles. TP workers have same #blocks # so we only register once per engine_id. 
- # All attn in vLLM uses blocks starts with 1st(0 is for empty) - # For hetero block size case, block 0 should always remote block_len # Example: # block_size_ratio < 1: # remote: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12| @@ -1421,6 +1413,13 @@ def add_remote_agent( remote_agent_name, descs ) + if block_size_ratio < 1: + # when prefill with small block_size, we need to init a + # new handler with same block_len to match + self.src_xfer_side_handles[nixl_agent_meta.block_size] = ( + self.register_local_xfer_handler(nixl_agent_meta.block_size) + ) + return remote_agent_name def _validate_remote_agent_handshake( @@ -1649,7 +1648,7 @@ def get_finished(self) -> tuple[set[str], set[str]]: ) block_ids_to_permute = [] - block_ids_for_blocksize_post_process: dict[float, list[list[int]]] = {} + block_ids_for_blocksize_post_process = defaultdict(list) for req_id in done_recving: # clean up metadata for completed requests meta = self._recving_metadata.pop(req_id, None) @@ -1663,8 +1662,6 @@ def get_finished(self) -> tuple[set[str], set[str]]: block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id( meta.remote_engine_id ) - if block_size_ratio not in block_ids_for_blocksize_post_process: - block_ids_for_blocksize_post_process[block_size_ratio] = [] if block_size_ratio != 1 and self.kv_cache_layout == "HND": block_ids_for_blocksize_post_process[block_size_ratio].append( meta.local_block_ids From 9bf9c7f3082f2f3a31b3ae15afa54245b1609f06 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Tue, 4 Nov 2025 20:55:38 -0800 Subject: [PATCH 17/27] remove unnecessary changes to _get_block_descs_ids Signed-off-by: Chendi Xue --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index d4debb3a71cd..59e5cc6c68d5 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -2026,13 +2026,11 @@ def _get_block_descs_ids( num_blocks = self.dst_num_blocks[engine_id] if block_size_ratio_inv is not None: num_blocks = num_blocks * block_size_ratio_inv - num_blocks = np.full((self.num_regions), num_blocks) # Compute the desc ids for each block. 
+ region_ids = region_ids[:, None]
 block_ids = np.array(block_ids)[None, :]
- region_nblocks = region_ids * num_blocks
- region_nblocks = region_nblocks[:, None]
- descs_ids = region_nblocks + block_ids
+ descs_ids = region_ids * num_blocks + block_ids
 return descs_ids.flatten()
 
 def get_backend_aware_kv_block_len(self, layer_idx: int):

From bcce3a7c0e7555b731f124d4e22514acad6996c4 Mon Sep 17 00:00:00 2001
From: Chendi Xue
Date: Tue, 4 Nov 2025 21:06:32 -0800
Subject: [PATCH 18/27] Clean up

Signed-off-by: Chendi Xue
---
 .../kv_transfer/kv_connector/v1/nixl_connector.py | 9 +--------
 1 file changed, 1 insertion(+), 8 deletions(-)

diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 59e5cc6c68d5..59afc7600014 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -1342,18 +1342,13 @@ def add_remote_agent(
 # local mapped:| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15|
 block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(engine_id)
 
- # when block_size_ratio > 1, one prefill block is n decode_block
- # loop n times of decode block_len to match to prefill
 if engine_id not in self.dst_num_blocks:
 self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
 
 # Keep track of remote agent kv caches base addresses.
 self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr
 
- self._validate_remote_agent_handshake(
- nixl_agent_meta,
- remote_tp_size,
- )
+ self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size)
 
 # Number of D TP workers reading from a single P TP worker. This is
 # 1 when P and D `--tensor-parallel-size` match.
@@ -1377,8 +1377,6 @@ def add_remote_agent(
 self.tp_rank % tp_ratio * remote_kv_block_len
 if not replicates_kv_cache
 else 0
- if not replicates_kv_cache
- else 0
 )
 for block_id in range(nixl_agent_meta.num_blocks):
 block_offset = block_id * nixl_agent_meta.block_lens[i]

From b99156a8bcd5f7887c8da82749dbbb7404e13716 Mon Sep 17 00:00:00 2001
From: Chendi Xue
Date: Thu, 6 Nov 2025 22:42:56 -0800
Subject: [PATCH 19/27] Inverse block_size_ratio

Signed-off-by: Chendi Xue
---
 .../kv_connector/v1/nixl_connector.py         | 61 +++++++++++--------
 1 file changed, 35 insertions(+), 26 deletions(-)

diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 59afc7600014..a138c160951e 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -695,7 +695,8 @@ def block_size_ratio(
 f"Local block size {self.block_size} is not divisible "
 f"by remote block size {remote_block_size} or vice versa."
 )
- return remote_block_size / self.block_size
+ ret = self.block_size / remote_block_size
+ return ret if ret < 1 else int(ret)
 
 def tp_ratio_from_engine_id(
 self,
@@ -1227,6 +1228,17 @@ def register_local_xfer_handler(
 self,
 block_size: int,
 ) -> int:
+ """
+ Used to register a local xfer handler with either the local block_size
+ or the remote block_size.
+
+ When the local block_size is the same as the remote block_size, we use
+ the local block_size to register the local_xfer_handler during init.
+
+ When the remote block size is less than the local block size, we need
+ to register another local_xfer_handler using the remote block len to
+ ensure data copy correctness.
+ """ block_size_ratio = self.block_size // block_size blocks_data = [] for i, base_addr in enumerate(self.seen_base_addresses): @@ -1336,7 +1348,7 @@ def add_remote_agent( # Create dst descs and xfer side handles. TP workers have same #blocks # so we only register once per engine_id. # Example: - # block_size_ratio < 1: + # block_size_ratio > 1: # remote: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12| # local origin:| 0| 1| 8| 12| # local mapped:| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15| @@ -1364,8 +1376,8 @@ def add_remote_agent( # Register all remote blocks, but only the corresponding kv heads. for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) - remote_kv_block_len = int(kv_block_len * block_size_ratio) - if block_size_ratio <= 1: + remote_kv_block_len = kv_block_len // block_size_ratio + if block_size_ratio > 1: # using remote kv_block_len as transfer unit kv_block_len = remote_kv_block_len rank_offset = ( @@ -1406,8 +1418,8 @@ def add_remote_agent( remote_agent_name, descs ) - if block_size_ratio < 1: - # when prefill with small block_size, we need to init a + if block_size_ratio > 1: + # when prefill with smaller block_size, we need to init a # new handler with same block_len to match self.src_xfer_side_handles[nixl_agent_meta.block_size] = ( self.register_local_xfer_handler(nixl_agent_meta.block_size) @@ -1464,7 +1476,7 @@ def _validate_remote_agent_handshake( # With replicated KV cache, only the number of blocks can differ. for i in range(len(self.block_len_per_layer)): assert ( - self.block_len_per_layer[i] * block_size_ratio + self.block_len_per_layer[i] // block_size_ratio == nixl_agent_meta.block_lens[i] ), "KV cache sizes must match between P and D when replicated" else: @@ -1476,7 +1488,7 @@ def _validate_remote_agent_handshake( assert ( remote_block_len - == self.block_len_per_layer[0] * tp_ratio * block_size_ratio + == (self.block_len_per_layer[0] * tp_ratio) // block_size_ratio ), ( "Remote P worker KV layer cache must be of shape [2, N, " "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype." @@ -1568,8 +1580,8 @@ def permute_device_kv(self, block_ids: list[int]): def blocksize_post_process(self, block_ids_per_ratio: dict[float, list[list[int]]]): def _process_local_gt_remote(blocks_to_update, block_size_ratio): n_kv_heads, block_size, head_size = blocks_to_update.shape[1:] - remote_block_size = int(block_size * block_size_ratio) - n_blocks = int(1 / block_size_ratio) + remote_block_size = block_size // block_size_ratio + n_blocks = block_size_ratio # actual permute is to convert # for local blocksize > remote blocksize # ex: local blocksize = 16 tokens, remote blocksize = 4 tokens @@ -1594,11 +1606,10 @@ def _process_local_gt_remote(blocks_to_update, block_size_ratio): for block_size_ratio, block_ids_list in block_ids_per_ratio.items(): if len(block_ids_list) == 0: continue - if block_size_ratio < 1: - fn = _process_local_gt_remote - block_ids_list = [ - [item for sublist in block_ids_list for item in sublist] - ] + assert block_size_ratio > 1, "Only nP < nD supported currently." 
+ fn = _process_local_gt_remote + block_ids_list = [[item for sublist in block_ids_list for item in sublist]] + for block_ids in block_ids_list: if len(block_ids) == 0: # we don't need to do permute for this req @@ -1829,11 +1840,9 @@ def _read_blocks( request_id: str, ): block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id) - block_size_ratio_inv = None - if block_size_ratio < 1: - block_size_ratio_inv = int(1 / block_size_ratio) + if block_size_ratio > 1: local_block_ids = self.get_mapped_blocks( - np.asarray(local_block_ids), block_size_ratio_inv + np.asarray(local_block_ids), block_size_ratio ) if len(local_block_ids) > len(remote_block_ids): local_block_ids = local_block_ids[: len(remote_block_ids)] @@ -1881,8 +1890,8 @@ def _read_blocks( # Get side handles. block_size = ( self.block_size - if block_size_ratio >= 1 - else int(self.block_size * block_size_ratio) + if block_size_ratio <= 1 + else self.block_size // block_size_ratio ) local_xfer_side_handle = self.src_xfer_side_handles[block_size] remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id] @@ -1904,7 +1913,7 @@ def _read_blocks( local_block_descs_ids = self._get_block_descs_ids( self.engine_id, local_block_ids, - block_size_ratio_inv=block_size_ratio_inv, + block_size_ratio=block_size_ratio, ) else: # TODO(mgoin): remove this once we have hybrid memory allocator @@ -1933,7 +1942,7 @@ def _read_blocks( self.engine_id, layer_remote_block_ids, layer_idx, - block_size_ratio_inv=block_size_ratio_inv, + block_size_ratio=block_size_ratio, ) local_descs_list.append(layer_local_desc_ids) @@ -1994,7 +2003,7 @@ def _get_block_descs_ids( engine_id: str, block_ids: list[int], layer_idx: int | None = None, - block_size_ratio_inv: int | None = None, + block_size_ratio: float | None = None, ) -> np.ndarray: """ Get the descs ids for a set of block ids. @@ -2017,8 +2026,8 @@ def _get_block_descs_ids( region_ids = np.arange(layer_idx, layer_idx + 1) num_blocks = self.dst_num_blocks[engine_id] - if block_size_ratio_inv is not None: - num_blocks = num_blocks * block_size_ratio_inv + if block_size_ratio is not None: + num_blocks = int(num_blocks * block_size_ratio) # Compute the desc ids for each block. region_ids = region_ids[:, None] From 4d74dd2cc7097f2b76bd9086e69b5552d1e5bee6 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Fri, 7 Nov 2025 00:03:24 -0800 Subject: [PATCH 20/27] Remove unnecessary check Signed-off-by: Chendi Xue --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index a138c160951e..12c44e7e4cf9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1604,16 +1604,11 @@ def _process_local_gt_remote(blocks_to_update, block_size_ratio): split_k_and_v = not (self.use_mla or self._use_pallas or self._use_flashinfer) sample_cache = list(self.device_kv_caches.values())[0][0] for block_size_ratio, block_ids_list in block_ids_per_ratio.items(): - if len(block_ids_list) == 0: - continue assert block_size_ratio > 1, "Only nP < nD supported currently." 
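A condensed restatement of the post-inversion semantics, under the same divisibility assumption the class enforces. The two helpers below are standalone stand-ins for `TpKVTopology.block_size_ratio` and `get_mapped_blocks`, with toy sizes rather than real cache geometry:

    import numpy as np

    def block_size_ratio(local: int, remote: int):
        # After the inversion, ratio is local/remote; an int once it divides.
        assert local % remote == 0 or remote % local == 0
        ret = local / remote
        return ret if ret < 1 else int(ret)

    def get_mapped_blocks(block_ids: np.ndarray, ratio: int) -> np.ndarray:
        # Expand each local id into `ratio` consecutive remote-sized ids.
        return (block_ids[:, None] * ratio + np.arange(ratio)).flatten()

    ratio = block_size_ratio(local=64, remote=16)     # -> 4
    print(get_mapped_blocks(np.array([1, 3]), ratio))
    # -> [ 4  5  6  7 12 13 14 15]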
fn = _process_local_gt_remote block_ids_list = [[item for sublist in block_ids_list for item in sublist]] for block_ids in block_ids_list: - if len(block_ids) == 0: - # we don't need to do permute for this req - continue indices = torch.tensor(block_ids, device=sample_cache.device) for _, cache_or_caches in self.device_kv_caches.items(): From 456a8c6faf924ac5cf31ce8ca3031605f59fe93a Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Fri, 7 Nov 2025 01:46:21 -0800 Subject: [PATCH 21/27] don't do post_process for heter_block_size for mla Signed-off-by: Chendi Xue --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 5a7b6bffbd90..d8bdb68cd595 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1669,7 +1669,11 @@ def get_finished(self) -> tuple[set[str], set[str]]: block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id( meta.remote_engine_id ) - if block_size_ratio != 1 and self.kv_cache_layout == "HND": + if ( + not self.use_mla + and block_size_ratio != 1 + and self.kv_cache_layout == "HND" + ): block_ids_for_blocksize_post_process[block_size_ratio].append( meta.local_block_ids ) From d0f90355e682ff9a0547e978df76e2f97f920a40 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Fri, 7 Nov 2025 10:26:44 -0800 Subject: [PATCH 22/27] make pre-commit happy Signed-off-by: Chendi Xue --- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index d8bdb68cd595..017312ec611d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1898,7 +1898,7 @@ def _read_blocks( block_size = ( self.block_size if block_size_ratio <= 1 - else self.block_size // block_size_ratio + else int(self.block_size // block_size_ratio) ) local_xfer_side_handle = self.src_xfer_side_handles[block_size] remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id] From fd21da13bced58cde5196fbcc285d25dc7ecf063 Mon Sep 17 00:00:00 2001 From: "Chendi.Xue" Date: Wed, 12 Nov 2025 10:12:06 -0600 Subject: [PATCH 23/27] Update vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Nicolò Lucchesi Signed-off-by: Chendi.Xue --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 017312ec611d..c8515cb7e321 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1995,6 +1995,11 @@ def get_mapped_blocks(self, block_ids, block_size_ratio): """ Calculates the new set of block IDs by mapping every element in the (potentially sparse) input array. + Example: block_ids=[0, 2], block_size_ratio=2 + get_mapped_blocks 0 1 [2 3] 4 5 + # remote is |h0-b0|h1-b0||h0-b1|h1-b1||h0-b1|h1-b1|| + # local is |h0-b0......||h1-b0......||h2-b0........ 
+ local_block_ids 0 [1] 2
 """
 if block_ids.size == 0:
 return np.array([], dtype=np.int64)

From 8b4507af391b0186e1621b0f28a1872ae7c4d2bf Mon Sep 17 00:00:00 2001
From: Chendi Xue
Date: Wed, 12 Nov 2025 08:52:11 -0800
Subject: [PATCH 24/27] Fix comments

Signed-off-by: Chendi Xue
---
 .../nixl_integration/run_accuracy_test.sh | 4 ++--
 .../kv_connector/v1/nixl_connector.py | 24 ++++++++++++-------
 2 files changed, 17 insertions(+), 11 deletions(-)

diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
index ebc8575e5b39..87c9a105e936 100755
--- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
+++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
@@ -49,8 +49,8 @@ NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-1} # Default to 1
 PREFILLER_TP_SIZE=${PREFILLER_TP_SIZE:-1}
 DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
 GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.2}
-PREFILL_BLOCK_SIZE=${PREFILL_BLOCK_SIZE:-16}
-DECODE_BLOCK_SIZE=${DECODE_BLOCK_SIZE:-16}
+PREFILL_BLOCK_SIZE=${PREFILL_BLOCK_SIZE:-128}
+DECODE_BLOCK_SIZE=${DECODE_BLOCK_SIZE:-128}
 # Find the git repository root directory
 GIT_ROOT=$(git rev-parse --show-toplevel)
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index c182d45cab36..cdbfcf53407c 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -1657,7 +1657,6 @@ def _process_local_gt_remote(blocks_to_update, block_size_ratio):
 sample_cache = list(self.device_kv_caches.values())[0][0]
 for block_size_ratio, block_ids_list in block_ids_per_ratio.items():
 assert block_size_ratio > 1, "Only nP < nD supported currently."
- fn = _process_local_gt_remote
 block_ids_list = [[item for sublist in block_ids_list for item in sublist]]
 for block_ids in block_ids_list:
@@ -1671,7 +1670,7 @@ def _process_local_gt_remote(blocks_to_update, block_size_ratio):
 # virtual shape while stride can be either HND / NHD at
 # initialization.
 # we need to firstly get physical view of the tensor
- permuted_blocks = fn(
+ permuted_blocks = _process_local_gt_remote(
 blocks_to_update.permute(0, 2, 1, 3), block_size_ratio
 ).permute(0, 2, 1, 3)
 cache.index_copy_(0, indices, permuted_blocks)
@@ -1715,7 +1714,7 @@ def get_finished(self) -> tuple[set[str], set[str]]:
 )
 if (
 not self.use_mla
- and block_size_ratio != 1
+ and block_size_ratio > 1
 and self.kv_cache_layout == "HND"
 ):
 block_ids_for_blocksize_post_process[block_size_ratio].append(
@@ -1896,6 +1895,17 @@ def _read_blocks(
 np.asarray(local_block_ids), block_size_ratio
 )
 if len(local_block_ids) > len(remote_block_ids):
+ # NOTE:
+ # get_mapped_blocks always expands block_ids n times (n = block_size_ratio).
+ # ex:
+ # prefill block_ids with block_size as 4:
+ # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+ # Local decode block_ids with block_size as 16: [1, 2, 3]
+ # expand decode block_ids with get_mapped_blocks from [1, 2, 3] to
+ # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
+ # Then we clip local to align with prefill
+ # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] to
+ # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
 local_block_ids = local_block_ids[: len(remote_block_ids)]
 # NOTE(rob): having the staging blocks be on the READER side is
 # not going to work well (since we will have to call rearrange tensors).
@@ -1939,12 +1949,8 @@ def _read_blocks(
 remote_block_ids = remote_block_ids[-num_local_blocks:]
 # Get side handles.
- block_size = (
- self.block_size
- if block_size_ratio <= 1
- else int(self.block_size // block_size_ratio)
- )
- local_xfer_side_handle = self.src_xfer_side_handles[block_size]
+ remote_block_size = self.kv_topo.remote_block_size[dst_engine_id]
+ local_xfer_side_handle = self.src_xfer_side_handles[remote_block_size]
 remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
 # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from

From 984637db83374896c8599abeca4f7635258f9077 Mon Sep 17 00:00:00 2001
From: Chendi Xue
Date: Wed, 12 Nov 2025 13:55:56 -0800
Subject: [PATCH 25/27] update script

Signed-off-by: Chendi Xue
---
 tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh | 4 ++--
 tests/v1/kv_connector/nixl_integration/test_accuracy.py | 1 +
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
index 87c9a105e936..ebc8575e5b39 100755
--- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
+++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
@@ -49,8 +49,8 @@ NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-1} # Default to 1
 PREFILLER_TP_SIZE=${PREFILLER_TP_SIZE:-1}
 DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
 GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.2}
-PREFILL_BLOCK_SIZE=${PREFILL_BLOCK_SIZE:-128}
-DECODE_BLOCK_SIZE=${DECODE_BLOCK_SIZE:-128}
+PREFILL_BLOCK_SIZE=${PREFILL_BLOCK_SIZE:-16}
+DECODE_BLOCK_SIZE=${DECODE_BLOCK_SIZE:-16}
 # Find the git repository root directory
 GIT_ROOT=$(git rev-parse --show-toplevel)
diff --git a/tests/v1/kv_connector/nixl_integration/test_accuracy.py b/tests/v1/kv_connector/nixl_integration/test_accuracy.py
index a70f4caeb937..8217321b8a89 100644
--- a/tests/v1/kv_connector/nixl_integration/test_accuracy.py
+++ b/tests/v1/kv_connector/nixl_integration/test_accuracy.py
@@ -52,6 +52,7 @@ def test_accuracy():
 model="local-completions",
 model_args=model_args,
 tasks=TASK,
+ limit=256,
 )
 measured_value = results["results"][TASK][FILTER]

From 23808e5a40bc9e0bc16097df0f59511f641d8b8f Mon Sep 17 00:00:00 2001
From: Chendi Xue
Date: Thu, 13 Nov 2025 07:06:31 -0800
Subject: [PATCH 26/27] fix script and comments

Signed-off-by: Chendi Xue
---
 tests/v1/kv_connector/nixl_integration/test_accuracy.py | 1 -
 .../kv_transfer/kv_connector/v1/nixl_connector.py | 8 ++------
 2 files changed, 2 insertions(+), 7 deletions(-)

diff --git a/tests/v1/kv_connector/nixl_integration/test_accuracy.py b/tests/v1/kv_connector/nixl_integration/test_accuracy.py
index 8217321b8a89..a70f4caeb937 100644
--- a/tests/v1/kv_connector/nixl_integration/test_accuracy.py
+++ b/tests/v1/kv_connector/nixl_integration/test_accuracy.py
@@ -52,7 +52,6 @@ def test_accuracy():
 model="local-completions",
 model_args=model_args,
 tasks=TASK,
- limit=256,
 )
 measured_value = results["results"][TASK][FILTER]

diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index cdbfcf53407c..70962dbc4023 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -732,15 +732,11 @@ def block_size_ratio(
 """
 Calculate the block size ratio between local and remote TP.
""" - assert ( - self.block_size % remote_block_size == 0 - or remote_block_size % self.block_size == 0 - ), ( + assert self.block_size % remote_block_size == 0, ( f"Local block size {self.block_size} is not divisible " f"by remote block size {remote_block_size} or vice versa." ) - ret = self.block_size / remote_block_size - return ret if ret < 0 else int(ret) + return self.block_size // remote_block_size def tp_ratio_from_engine_id( self, From a6641c295141e18da59df418030e69c39afb91b6 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Thu, 13 Nov 2025 09:10:13 -0800 Subject: [PATCH 27/27] fix UT Signed-off-by: Chendi Xue --- tests/v1/kv_connector/unit/test_nixl_connector.py | 3 +++ .../kv_transfer/kv_connector/v1/nixl_connector.py | 6 +++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 8e421717fea3..b7d7a10057b8 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -407,6 +407,7 @@ def _nixl_handshake( # `self.kv_cache_layout` is only forced to HND when vllm engine # is started. We mock HND here. kv_cache_layout="HND", + block_size=self.block_size, ), remote_tp_size=remote_tp_size, ) @@ -652,6 +653,7 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init): block_lens=worker.block_len_per_layer, attn_backend_name=worker.backend_name, kv_cache_layout=mismatched_layout, + block_size=worker.block_size, ) with pytest.raises(RuntimeError): @@ -706,6 +708,7 @@ def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental( block_lens=[i * 2 for i in worker.block_len_per_layer], attn_backend_name=worker.backend_name, kv_cache_layout="HND", + block_size=worker.block_size, ) # We don't check layout for homogeneous TP and MLA for now, as the diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 70962dbc4023..8e100ac01270 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1647,6 +1647,8 @@ def _process_local_gt_remote(blocks_to_update, block_size_ratio): ) return permuted_blocks + if len(self.device_kv_caches) == 0: + return split_k_and_v = not ( self.use_mla or self._use_pallas or self.kv_topo.is_kv_layout_blocks_first ) @@ -1946,7 +1948,9 @@ def _read_blocks( # Get side handles. remote_block_size = self.kv_topo.remote_block_size[dst_engine_id] - local_xfer_side_handle = self.src_xfer_side_handles[remote_block_size] + local_xfer_side_handle = self.src_xfer_side_handles.get( + remote_block_size, self.src_xfer_side_handle + ) remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id] # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from