Skip to content

Commit 1c85cfb

Browse files
committed
fix UT
Signed-off-by: Chendi Xue <[email protected]>
1 parent 23808e5 commit 1c85cfb

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

tests/v1/kv_connector/unit/test_nixl_connector.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,7 @@ def _nixl_handshake(
407407
# `self.kv_cache_layout` is only forced to HND when vllm engine
408408
# is started. We mock HND here.
409409
kv_cache_layout="HND",
410+
block_size=self.block_size,
410411
),
411412
remote_tp_size=remote_tp_size,
412413
)
@@ -652,6 +653,7 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init):
652653
block_lens=worker.block_len_per_layer,
653654
attn_backend_name=worker.backend_name,
654655
kv_cache_layout=mismatched_layout,
656+
block_size=worker.block_size,
655657
)
656658

657659
with pytest.raises(RuntimeError):
@@ -706,6 +708,7 @@ def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental(
706708
block_lens=[i * 2 for i in worker.block_len_per_layer],
707709
attn_backend_name=worker.backend_name,
708710
kv_cache_layout="HND",
711+
block_size=worker.block_size,
709712
)
710713

711714
# We don't check layout for homogeneous TP and MLA for now, as the

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -886,7 +886,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
886886

887887
# nixl_prepped_dlist_handle.
888888
self.src_xfer_side_handle: int = 0
889-
self.src_xfer_side_handles: dict[int, int] = {}
889+
self.src_xfer_side_handles = defaultdict[int, int](int)
890890
# Map of engine_id -> nixl_prepped_dlist_handle (int)].
891891
self.dst_xfer_side_handles: dict[EngineId, int] = {}
892892

@@ -1647,6 +1647,8 @@ def _process_local_gt_remote(blocks_to_update, block_size_ratio):
16471647
)
16481648
return permuted_blocks
16491649

1650+
if len(self.device_kv_caches) == 0:
1651+
return
16501652
split_k_and_v = not (
16511653
self.use_mla or self._use_pallas or self.kv_topo.is_kv_layout_blocks_first
16521654
)
@@ -1946,7 +1948,9 @@ def _read_blocks(
19461948

19471949
# Get side handles.
19481950
remote_block_size = self.kv_topo.remote_block_size[dst_engine_id]
1949-
local_xfer_side_handle = self.src_xfer_side_handles[remote_block_size]
1951+
local_xfer_side_handle = (
1952+
self.src_xfer_side_handles[remote_block_size] | self.src_xfer_side_handle
1953+
)
19501954
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
19511955

19521956
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from

0 commit comments

Comments
 (0)