Skip to content

Commit a6641c2

Browse files
committed
fix UT
Signed-off-by: Chendi Xue <[email protected]>
1 parent 23808e5 commit a6641c2

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

tests/v1/kv_connector/unit/test_nixl_connector.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,7 @@ def _nixl_handshake(
407407
# `self.kv_cache_layout` is only forced to HND when vllm engine
408408
# is started. We mock HND here.
409409
kv_cache_layout="HND",
410+
block_size=self.block_size,
410411
),
411412
remote_tp_size=remote_tp_size,
412413
)
@@ -652,6 +653,7 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init):
652653
block_lens=worker.block_len_per_layer,
653654
attn_backend_name=worker.backend_name,
654655
kv_cache_layout=mismatched_layout,
656+
block_size=worker.block_size,
655657
)
656658

657659
with pytest.raises(RuntimeError):
@@ -706,6 +708,7 @@ def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental(
706708
block_lens=[i * 2 for i in worker.block_len_per_layer],
707709
attn_backend_name=worker.backend_name,
708710
kv_cache_layout="HND",
711+
block_size=worker.block_size,
709712
)
710713

711714
# We don't check layout for homogeneous TP and MLA for now, as the

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1647,6 +1647,8 @@ def _process_local_gt_remote(blocks_to_update, block_size_ratio):
16471647
)
16481648
return permuted_blocks
16491649

1650+
if len(self.device_kv_caches) == 0:
1651+
return
16501652
split_k_and_v = not (
16511653
self.use_mla or self._use_pallas or self.kv_topo.is_kv_layout_blocks_first
16521654
)
@@ -1946,7 +1948,9 @@ def _read_blocks(
19461948

19471949
# Get side handles.
19481950
remote_block_size = self.kv_topo.remote_block_size[dst_engine_id]
1949-
local_xfer_side_handle = self.src_xfer_side_handles[remote_block_size]
1951+
local_xfer_side_handle = self.src_xfer_side_handles.get(
1952+
remote_block_size, self.src_xfer_side_handle
1953+
)
19501954
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
19511955

19521956
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from

0 commit comments

Comments
 (0)