4949from vllm .utils .network_utils import make_zmq_path , make_zmq_socket
5050from vllm .v1 .attention .backends .utils import get_kv_cache_layout
5151from vllm .v1 .core .sched .output import SchedulerOutput
52+ from vllm .v1 .worker .block_table import BlockTable
5253
5354if TYPE_CHECKING :
5455 from vllm .attention .backends .abstract import AttentionMetadata
@@ -112,6 +113,8 @@ class NixlAgentMetadata(KVConnectorHandshakeMetadata):
112113@dataclass
113114class ReqMeta :
114115 local_block_ids : list [int ]
116+ # To be used when logical block size does not match the kernel block size
117+ local_physical_block_ids : list [int ]
115118 remote_block_ids : list [int ]
116119 remote_host : str
117120 remote_port : int
@@ -139,6 +142,7 @@ def add_new_req(
139142 assert load_remote_cache ^ save_to_host
140143 _req = ReqMeta (
141144 local_block_ids = local_block_ids ,
145+ local_physical_block_ids = local_block_ids ,
142146 remote_block_ids = kv_transfer_params ["remote_block_ids" ],
143147 remote_engine_id = kv_transfer_params ["remote_engine_id" ],
144148 remote_host = kv_transfer_params ["remote_host" ],
@@ -935,6 +939,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
935939 attn_backend = backend ,
936940 )
937941 self ._use_pallas = self .kv_topo ._use_pallas
942+ self ._physical_blocks_per_logical_kv_block = 1
938943
939944 def _nixl_handshake (
940945 self ,
@@ -1133,6 +1138,22 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
11331138 if base_addr in seen_base_addresses :
11341139 continue
11351140
1141+ # TODO (NickLucche): Get kernel_block_size in a cleaner way
1142+ # NHD default "view" for non-MLA cache
1143+ kernel_block_size = cache .shape [- 2 ] if self .use_mla else cache .shape [- 3 ]
1144+
1145+ if self .block_size != kernel_block_size :
1146+ logger .info_once (
1147+ "User-specified logical block size (%s) does not match"
1148+ " physical kernel block size (%s). Using the latter. " ,
1149+ self .block_size ,
1150+ kernel_block_size ,
1151+ )
1152+ self ._physical_blocks_per_logical_kv_block = (
1153+ self .block_size // kernel_block_size
1154+ )
1155+ self .block_size = kernel_block_size
1156+
11361157 seen_base_addresses .append (base_addr )
11371158 curr_tensor_size_bytes = cache .numel () * cache .element_size ()
11381159
@@ -1479,7 +1500,7 @@ def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta):
14791500 assert self .use_host_buffer
14801501 assert self .copy_blocks is not None
14811502
1482- local_block_ids = meta .local_block_ids
1503+ local_block_ids = meta .local_physical_block_ids
14831504 self .copy_blocks (
14841505 self .host_xfer_buffers ,
14851506 self .device_kv_caches ,
@@ -1492,7 +1513,7 @@ def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta):
14921513 "synced recved kv of request[%s] to device kv buffer,"
14931514 "local_block_ids: %s. " ,
14941515 req_id ,
1495- "," .join (map (str , meta . local_block_ids )),
1516+ "," .join (map (str , local_block_ids )),
14961517 )
14971518
14981519 def save_kv_to_host (self , metadata : NixlConnectorMetadata ):
@@ -1501,19 +1522,22 @@ def save_kv_to_host(self, metadata: NixlConnectorMetadata):
15011522 assert self .copy_blocks is not None
15021523
15031524 for req_id , meta in metadata .reqs_to_save .items ():
1525+ meta .local_physical_block_ids = self ._logical_to_kernel_block_ids (
1526+ meta .local_block_ids
1527+ )
15041528 if logger .isEnabledFor (logging .DEBUG ):
15051529 logger .debug (
15061530 "save_load_kv for request[%s] to host xfer buffer."
15071531 "local_block_ids: %s. " ,
15081532 req_id ,
1509- "," .join (map (str , meta .local_block_ids )),
1533+ "," .join (map (str , meta .local_physical_block_ids )),
15101534 )
15111535 # blocking
15121536 self .copy_blocks (
15131537 self .device_kv_caches ,
15141538 self .host_xfer_buffers ,
1515- meta .local_block_ids ,
1516- meta .local_block_ids ,
1539+ meta .local_physical_block_ids ,
1540+ meta .local_physical_block_ids ,
15171541 "d2h" ,
15181542 )
15191543
@@ -1582,7 +1606,7 @@ def get_finished(self) -> tuple[set[str], set[str]]:
15821606 if self .use_host_buffer :
15831607 self .sync_recved_kv_to_device (req_id , meta )
15841608 if self .enable_permute_local_kv :
1585- block_ids_to_permute += meta .local_block_ids
1609+ block_ids_to_permute += meta .local_physical_block_ids
15861610 if len (block_ids_to_permute ) > 0 :
15871611 self .permute_device_kv (block_ids_to_permute )
15881612
@@ -1669,7 +1693,7 @@ def _pop_done_transfers(
16691693 req_id ,
16701694 xfer_state ,
16711695 )
1672- # mark all blocks for this request as invalid
1696+ # mark all (logical) blocks for this request as invalid
16731697 if meta := self ._recving_metadata .pop (req_id , None ):
16741698 self ._invalid_block_ids .update (meta .local_block_ids )
16751699 self ._recving_metadata .pop (req_id , None )
@@ -1686,13 +1710,19 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
16861710 We check for these trnxs to complete in each step().
16871711 """
16881712 for req_id , meta in metadata .reqs_to_recv .items ():
1713+ meta .local_physical_block_ids = self ._logical_to_kernel_block_ids (
1714+ meta .local_block_ids
1715+ )
1716+ meta .remote_block_ids = self ._logical_to_kernel_block_ids (
1717+ meta .remote_block_ids
1718+ )
16891719 remote_engine_id = meta .remote_engine_id
16901720 logger .debug (
16911721 "start_load_kv for request %s from remote engine %s. "
16921722 "Num local_block_ids: %s. Num remote_block_ids: %s. " ,
16931723 req_id ,
16941724 remote_engine_id ,
1695- len (meta .local_block_ids ),
1725+ len (meta .local_physical_block_ids ),
16961726 len (meta .remote_block_ids ),
16971727 )
16981728 # always store metadata for failure recovery
@@ -1740,7 +1770,7 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
17401770 self ._read_blocks (
17411771 request_id = req_id ,
17421772 dst_engine_id = meta .remote_engine_id ,
1743- local_block_ids = meta .local_block_ids ,
1773+ local_block_ids = meta .local_physical_block_ids ,
17441774 remote_block_ids = meta .remote_block_ids ,
17451775 )
17461776
@@ -1867,7 +1897,7 @@ def _read_blocks(
18671897 "Marking blocks as invalid." ,
18681898 request_id ,
18691899 )
1870- # mark all blocks for this request as invalid
1900+ # mark all (logical) blocks for this request as invalid
18711901 if meta := self ._recving_metadata .get (request_id ):
18721902 self ._invalid_block_ids .update (meta .local_block_ids )
18731903 self .xfer_stats .record_failed_transfer ()
@@ -1906,6 +1936,23 @@ def _get_block_descs_ids(
19061936 descs_ids = region_ids * num_blocks + block_ids
19071937 return descs_ids .flatten ()
19081938
1939+ def _logical_to_kernel_block_ids (self , block_ids : list [int ]) -> list [int ]:
1940+ """
1941+ Convert logical block ids to kernel physical block ids.
1942+ This is required when the logical block size (the one set by the user)
1943+ does not match the one required by the attn backend.
1944+ """
1945+ if self ._physical_blocks_per_logical_kv_block == 1 :
1946+ # Noop when physical and logical block sizes are the same
1947+ return block_ids
1948+ block_ids_np = np .array (block_ids )
1949+ block_arange = np .arange (0 , self ._physical_blocks_per_logical_kv_block ).reshape (
1950+ 1 , - 1
1951+ )
1952+ return BlockTable .map_to_kernel_blocks (
1953+ block_ids_np , self ._physical_blocks_per_logical_kv_block , block_arange
1954+ ).tolist ()
1955+
19091956 def get_backend_aware_kv_block_len (self , layer_idx : int ):
19101957 """
19111958 Get the block length for one K/V element (K and V have the same size).
0 commit comments