
Commit 211049b

NickLucche authored and devpatelio committed
[Bugfix][Nixl] Fix kernel physical<>logical block_size issue (vllm-project#28677)
Signed-off-by: NickLucche <[email protected]>
1 parent 69c76b7 commit 211049b


3 files changed: +73 -17 lines changed


tests/v1/worker/test_gpu_model_runner.py

Lines changed: 4 additions & 2 deletions
@@ -985,8 +985,10 @@ def test_hybrid_block_table_initialization():
     req_index = 0
     block_table.append_row(kvcache_manager_blocks, req_index)
     # Get expected kernel blocks from the implementation for verification.
-    expected_kernel_blocks = block_table._map_to_kernel_blocks(
-        np.array(kvcache_manager_blocks)
+    expected_kernel_blocks = block_table.map_to_kernel_blocks(
+        np.array(kvcache_manager_blocks),
+        block_table.blocks_per_kv_block,
+        block_table._kernel_block_arange,
     )
     # Verify block table state
     assert block_table.num_blocks_per_row[req_index] == len(expected_kernel_blocks)

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 57 additions & 10 deletions
@@ -49,6 +49,7 @@
 from vllm.utils.network_utils import make_zmq_path, make_zmq_socket
 from vllm.v1.attention.backends.utils import get_kv_cache_layout
 from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.worker.block_table import BlockTable
 
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
@@ -112,6 +113,8 @@ class NixlAgentMetadata(KVConnectorHandshakeMetadata):
 @dataclass
 class ReqMeta:
     local_block_ids: list[int]
+    # To be used when logical block size does not match the kernel block size
+    local_physical_block_ids: list[int]
     remote_block_ids: list[int]
     remote_host: str
     remote_port: int
@@ -139,6 +142,7 @@ def add_new_req(
         assert load_remote_cache ^ save_to_host
         _req = ReqMeta(
             local_block_ids=local_block_ids,
+            local_physical_block_ids=local_block_ids,
             remote_block_ids=kv_transfer_params["remote_block_ids"],
             remote_engine_id=kv_transfer_params["remote_engine_id"],
             remote_host=kv_transfer_params["remote_host"],
@@ -935,6 +939,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
             attn_backend=backend,
         )
         self._use_pallas = self.kv_topo._use_pallas
+        self._physical_blocks_per_logical_kv_block = 1
 
     def _nixl_handshake(
         self,
@@ -1133,6 +1138,22 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
                 if base_addr in seen_base_addresses:
                     continue
 
+                # TODO (NickLucche): Get kernel_block_size in a cleaner way
+                # NHD default "view" for non-MLA cache
+                kernel_block_size = cache.shape[-2] if self.use_mla else cache.shape[-3]
+
+                if self.block_size != kernel_block_size:
+                    logger.info_once(
+                        "User-specified logical block size (%s) does not match"
+                        " physical kernel block size (%s). Using the latter. ",
+                        self.block_size,
+                        kernel_block_size,
+                    )
+                    self._physical_blocks_per_logical_kv_block = (
+                        self.block_size // kernel_block_size
+                    )
+                    self.block_size = kernel_block_size
+
                 seen_base_addresses.append(base_addr)
                 curr_tensor_size_bytes = cache.numel() * cache.element_size()
 
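The ratio computed in the hunk above drives every later logical-to-physical conversion. As a rough, hypothetical sketch (the numbers are illustrative, not taken from this diff): with a user-configured logical block size of 32 and an attention backend whose kernel cache layout uses 16-token blocks, the connector switches to 16-token physical blocks and records that each logical block spans two of them.

# Hypothetical values mirroring the check above; not part of the diff.
logical_block_size = 32   # user-specified block size (self.block_size before the check)
kernel_block_size = 16    # inferred from the KV cache tensor shape

physical_blocks_per_logical = logical_block_size // kernel_block_size  # -> 2
block_size = kernel_block_size  # the connector now operates on 16-token physical blocks
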
@@ -1479,7 +1500,7 @@ def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta):
         assert self.use_host_buffer
         assert self.copy_blocks is not None
 
-        local_block_ids = meta.local_block_ids
+        local_block_ids = meta.local_physical_block_ids
         self.copy_blocks(
             self.host_xfer_buffers,
             self.device_kv_caches,
@@ -1492,7 +1513,7 @@ def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta):
             "synced recved kv of request[%s] to device kv buffer,"
             "local_block_ids: %s. ",
             req_id,
-            ",".join(map(str, meta.local_block_ids)),
+            ",".join(map(str, local_block_ids)),
         )
 
     def save_kv_to_host(self, metadata: NixlConnectorMetadata):
@@ -1501,19 +1522,22 @@ def save_kv_to_host(self, metadata: NixlConnectorMetadata):
         assert self.copy_blocks is not None
 
         for req_id, meta in metadata.reqs_to_save.items():
+            meta.local_physical_block_ids = self._logical_to_kernel_block_ids(
+                meta.local_block_ids
+            )
             if logger.isEnabledFor(logging.DEBUG):
                 logger.debug(
                     "save_load_kv for request[%s] to host xfer buffer."
                     "local_block_ids: %s. ",
                     req_id,
-                    ",".join(map(str, meta.local_block_ids)),
+                    ",".join(map(str, meta.local_physical_block_ids)),
                 )
             # blocking
             self.copy_blocks(
                 self.device_kv_caches,
                 self.host_xfer_buffers,
-                meta.local_block_ids,
-                meta.local_block_ids,
+                meta.local_physical_block_ids,
+                meta.local_physical_block_ids,
                 "d2h",
             )
 
@@ -1582,7 +1606,7 @@ def get_finished(self) -> tuple[set[str], set[str]]:
             if self.use_host_buffer:
                 self.sync_recved_kv_to_device(req_id, meta)
             if self.enable_permute_local_kv:
-                block_ids_to_permute += meta.local_block_ids
+                block_ids_to_permute += meta.local_physical_block_ids
         if len(block_ids_to_permute) > 0:
             self.permute_device_kv(block_ids_to_permute)
 
@@ -1669,7 +1693,7 @@ def _pop_done_transfers(
                     req_id,
                     xfer_state,
                 )
-                # mark all blocks for this request as invalid
+                # mark all (logical) blocks for this request as invalid
                 if meta := self._recving_metadata.pop(req_id, None):
                     self._invalid_block_ids.update(meta.local_block_ids)
                 self._recving_metadata.pop(req_id, None)
@@ -1686,13 +1710,19 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
         We check for these trnxs to complete in each step().
         """
         for req_id, meta in metadata.reqs_to_recv.items():
+            meta.local_physical_block_ids = self._logical_to_kernel_block_ids(
+                meta.local_block_ids
+            )
+            meta.remote_block_ids = self._logical_to_kernel_block_ids(
+                meta.remote_block_ids
+            )
             remote_engine_id = meta.remote_engine_id
             logger.debug(
                 "start_load_kv for request %s from remote engine %s. "
                 "Num local_block_ids: %s. Num remote_block_ids: %s. ",
                 req_id,
                 remote_engine_id,
-                len(meta.local_block_ids),
+                len(meta.local_physical_block_ids),
                 len(meta.remote_block_ids),
             )
             # always store metadata for failure recovery
@@ -1740,7 +1770,7 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
         self._read_blocks(
             request_id=req_id,
             dst_engine_id=meta.remote_engine_id,
-            local_block_ids=meta.local_block_ids,
+            local_block_ids=meta.local_physical_block_ids,
             remote_block_ids=meta.remote_block_ids,
         )
 
@@ -1867,7 +1897,7 @@ def _read_blocks(
                 "Marking blocks as invalid.",
                 request_id,
             )
-            # mark all blocks for this request as invalid
+            # mark all (logical) blocks for this request as invalid
            if meta := self._recving_metadata.get(request_id):
                self._invalid_block_ids.update(meta.local_block_ids)
            self.xfer_stats.record_failed_transfer()
@@ -1906,6 +1936,23 @@ def _get_block_descs_ids(
         descs_ids = region_ids * num_blocks + block_ids
         return descs_ids.flatten()
 
+    def _logical_to_kernel_block_ids(self, block_ids: list[int]) -> list[int]:
+        """
+        Convert logical block ids to kernel physical block ids.
+        This is required when the logical block size (the one set by the user)
+        does not match the one required by the attn backend.
+        """
+        if self._physical_blocks_per_logical_kv_block == 1:
+            # Noop when physical and logical block sizes are the same
+            return block_ids
+        block_ids_np = np.array(block_ids)
+        block_arange = np.arange(0, self._physical_blocks_per_logical_kv_block).reshape(
+            1, -1
+        )
+        return BlockTable.map_to_kernel_blocks(
+            block_ids_np, self._physical_blocks_per_logical_kv_block, block_arange
+        ).tolist()
+
     def get_backend_aware_kv_block_len(self, layer_idx: int):
         """
         Get the block length for one K/V element (K and V have the same size).
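
For intuition, here is a minimal standalone sketch of the expansion that the new _logical_to_kernel_block_ids helper performs, assuming a hypothetical ratio of 2 physical blocks per logical block; the ids are illustrative, not from the diff.

import numpy as np

# Each logical block i expands to physical blocks [i*r, i*r + 1, ..., i*r + r - 1].
physical_per_logical = 2        # hypothetical ratio
logical_block_ids = [0, 3, 5]   # hypothetical ids from the KV cache manager

block_arange = np.arange(physical_per_logical).reshape(1, -1)
kernel_ids = (
    np.array(logical_block_ids).reshape(-1, 1) * physical_per_logical + block_arange
).reshape(-1)
print(kernel_ids.tolist())  # [0, 1, 6, 7, 10, 11]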

vllm/v1/worker/block_table.py

Lines changed: 12 additions & 5 deletions
@@ -98,7 +98,9 @@ def append_row(
             return
 
         if self.use_hybrid_blocks:
-            block_ids = self._map_to_kernel_blocks(np.array(block_ids))
+            block_ids = self.map_to_kernel_blocks(
+                np.array(block_ids), self.blocks_per_kv_block, self._kernel_block_arange
+            )
 
         num_blocks = len(block_ids)
         start = self.num_blocks_per_row[row_idx]
@@ -188,7 +190,12 @@ def clear(self) -> None:
         self.block_table.gpu.fill_(0)
         self.block_table.cpu.fill_(0)
 
-    def _map_to_kernel_blocks(self, kv_manager_block_ids: np.ndarray) -> np.ndarray:
+    @staticmethod
+    def map_to_kernel_blocks(
+        kv_manager_block_ids: np.ndarray,
+        blocks_per_kv_block: int,
+        kernel_block_arange: np.ndarray,
+    ) -> np.ndarray:
         """Convert kv_manager_block_id IDs to kernel block IDs.
 
         Example:
@@ -203,12 +210,12 @@ def _map_to_kernel_blocks(self, kv_manager_block_ids: np.ndarray) -> np.ndarray:
         # kv_manager_block_id 1 → kernel block id [2, 3]
         # kv_manager_block_id 2 → kernel block id [4, 5]
         """
-        if not self.use_hybrid_blocks:
+        if blocks_per_kv_block == 1:
             return kv_manager_block_ids
 
         kernel_block_ids = (
-            kv_manager_block_ids.reshape(-1, 1) * self.blocks_per_kv_block
-            + self._kernel_block_arange
+            kv_manager_block_ids.reshape(-1, 1) * blocks_per_kv_block
+            + kernel_block_arange
         )
 
         return kernel_block_ids.reshape(-1)
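
Since map_to_kernel_blocks is now a staticmethod, callers such as the Nixl connector above can invoke it without a BlockTable instance. A minimal usage sketch mirroring the docstring example (assumes vLLM is importable; the inputs are illustrative):

import numpy as np
from vllm.v1.worker.block_table import BlockTable

blocks_per_kv_block = 2  # hypothetical: two kernel blocks per KV-manager block
kernel_block_arange = np.arange(blocks_per_kv_block).reshape(1, -1)

kernel_ids = BlockTable.map_to_kernel_blocks(
    np.array([0, 1, 2]), blocks_per_kv_block, kernel_block_arange
)
print(kernel_ids.tolist())  # [0, 1, 2, 3, 4, 5]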
