211 changes: 210 additions & 1 deletion tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -18,12 +18,18 @@

from vllm import LLM
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorStats)
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
MultiKVConnectorStats)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
    NixlConnectorWorker, NixlKVConnectorStats)
from vllm.forward_context import ForwardContext
from vllm.sampling_params import SamplingParams
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput

from .utils import create_request, create_scheduler, create_vllm_config

@@ -475,6 +481,209 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init):
# NOTE: resource cleanup in the mp backend is a bit finicky, so the order in
# which the tests below run matters. Run the ray test first so it cleans up
# its resources, then the rest of the tests.
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper)
def test_kv_connector_stats(dist_init):
"""Test that KV transfer stats are properly recorded and retrieved."""
vllm_config = create_vllm_config()

# Test worker role in decode server.
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(vllm_config,
connector.engine_id,
hand_shake_latency=0)

# Verify that xfer_stats starts empty
initial_stats = connector.get_kv_connector_stats()
assert initial_stats is None

# Create transfer metadata
request_id = "test_req_for_stats"
metadata = NixlConnectorMetadata()
metadata.add_new_req(request_id=request_id,
local_block_ids=[1, 2, 3],
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_engine_id":
FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_host": "localhost",
"remote_port": 1234,
"remote_tp_size": 1,
})
connector.bind_connector_metadata(metadata)

# Start the transfer
dummy_ctx = ForwardContext(
no_compile_layers={},
attn_metadata={},
virtual_engine=0,
)
connector.start_load_kv(dummy_ctx)

# Verify stats are recorded after transfer is complete
max_iterations = 2
# Clear metadata before start_load_kv to prevent reprocessing same request
connector.bind_connector_metadata(NixlConnectorMetadata())
    for _ in range(max_iterations):
        # Need to call start_load_kv to process completed handshakes.
        connector.start_load_kv(dummy_ctx)
        _, done_recving = connector.get_finished(finished_req_ids=set())
        if len(done_recving) > 0 and request_id in done_recving:
            break
        time.sleep(0.1)  # Small delay to let the background handshake finish.
    else:
        pytest.fail("Transfer did not complete within expected iterations")

# Now check that stats were recorded
stats_after_transfer = connector.get_kv_connector_stats()
assert isinstance(stats_after_transfer, NixlKVConnectorStats)

# Verify stats values are recorded
assert not stats_after_transfer.is_empty()
assert stats_after_transfer.data["num_successful_transfers"] == 1

# Verify stats are reset after retrieval
stats_after_reset = connector.get_kv_connector_stats()
assert stats_after_reset is None


def test_kv_connector_stats_aggregation():
"""
Test KV transfer stats aggregation across TP ranks using
KVOutputAggregator (used by MultiprocExecutor).
"""

    # Create a KVOutputAggregator for 3 workers (simulating TP=3), mirroring
    # what MultiprocExecutor.execute_model does.
aggregator = KVOutputAggregator(world_size=3)

# Create stats for multiple workers with different transfer patterns
worker1_stats = NixlKVConnectorStats()
worker2_stats = NixlKVConnectorStats()
worker3_stats = NixlKVConnectorStats()

# Record different transfers on each worker
# Worker 1: 2 transfers
worker1_stats.record_transfer()
worker1_stats.record_transfer()

# Worker 2: 1 transfer
worker2_stats.record_transfer()

# Worker 3: 3 transfers
worker3_stats.record_transfer()
worker3_stats.record_transfer()
worker3_stats.record_transfer()

# Create ModelRunnerOutput instances for each worker
worker_outputs = []
for i, worker_stats in enumerate(
[worker1_stats, worker2_stats, worker3_stats]):
output = ModelRunnerOutput(
req_ids=[f"req_{i}"],
req_id_to_index={f"req_{i}": 0},
sampled_token_ids=[[123]], # dummy token
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[None],
kv_connector_output=KVConnectorOutput(
finished_sending=set([f"req_{i}_send"])
if i < 2 else None, # Workers 0,1 finished sending
finished_recving=set([f"req_{i}_recv"])
if i > 0 else None, # Workers 1,2 finished receiving
kv_connector_stats=worker_stats,
))
worker_outputs.append(output)

# Use the real aggregation mechanism (like MultiprocExecutor.execute_model)
aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0)
kv_connector_stats = \
aggregated_output.kv_connector_output.kv_connector_stats
assert isinstance(kv_connector_stats, NixlKVConnectorStats)
# Number of total transfers across all workers.
assert kv_connector_stats.data["num_successful_transfers"] == 6


def test_multi_kv_connector_stats_aggregation():
"""
Test MultiKVConnectorStats aggregation across TP ranks using
KVOutputAggregator (used by MultiprocExecutor).
"""

aggregator = KVOutputAggregator(world_size=3)

from dataclasses import dataclass

@dataclass
class FooKVConnectorStats(KVConnectorStats):

def reset(self):
self.data = {"num_foo_transfers": 0}

def record_transfer(self):
if "num_foo_transfers" not in self.data:
self.data["num_foo_transfers"] = 0
self.data["num_foo_transfers"] += 1

def is_empty(self) -> bool:
return self.data["num_foo_transfers"] == 0

def aggregate(self,
other: "FooKVConnectorStats") -> "FooKVConnectorStats":
if not other.is_empty():
self.data["num_foo_transfers"] += other.data[
"num_foo_transfers"]
return self

def make_multi_stats(nixl_count: int,
foo_count: int) -> MultiKVConnectorStats:
data: dict[str, KVConnectorStats] = {}
if nixl_count > 0:
nixl_stats = NixlKVConnectorStats()
for _ in range(nixl_count):
nixl_stats.record_transfer()
data["NixlConnector"] = nixl_stats
if foo_count > 0:
foo_stats = FooKVConnectorStats()
for _ in range(foo_count):
foo_stats.record_transfer()
data["FooConnector"] = foo_stats
return MultiKVConnectorStats(data=data)

# Create heterogeneous stats across 3 workers
worker_patterns = [(2, 1), (3, 0), (0, 5)] # (Nixl, Foo)

worker_outputs: list[ModelRunnerOutput] = []
for i, (nixl, foo) in enumerate(worker_patterns):
stats = make_multi_stats(nixl, foo)
output = ModelRunnerOutput(
req_ids=[f"req_{i}"],
req_id_to_index={f"req_{i}": 0},
sampled_token_ids=[[123]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[None],
kv_connector_output=KVConnectorOutput(
finished_sending=set([f"req_{i}_send"]) if i < 2 else None,
finished_recving=set([f"req_{i}_recv"]) if i > 0 else None,
kv_connector_stats=stats,
),
)
worker_outputs.append(output)

aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0)
kv_connector_stats = \
aggregated_output.kv_connector_output.kv_connector_stats
assert isinstance(kv_connector_stats, MultiKVConnectorStats)

# Validate per-connector totals across workers
assert kv_connector_stats["NixlConnector"].data[
"num_successful_transfers"] == 5
assert kv_connector_stats["FooConnector"].data["num_foo_transfers"] == 6


@pytest.mark.parametrize("distributed_executor_backend", ["ray", None])
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
21 changes: 18 additions & 3 deletions vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -129,7 +129,7 @@ def __init__(self, world_size: int):
def aggregate(self,
outputs: list[ModelRunnerOutput],
output_rank: int = 0) -> ModelRunnerOutput:
        # Aggregate kv_connector_output from all workers

def update_finished_set(req_ids: Optional[set[str]],
remaining_count_dict: dict[str, int],
@@ -142,21 +142,36 @@ def update_finished_set(req_ids: Optional[set[str]],

finished_sending = set[str]()
finished_recving = set[str]()
        aggregated_kv_connector_stats = None
        for model_runner_output in outputs:
            output = model_runner_output.kv_connector_output
if not output:
continue
update_finished_set(output.finished_sending,
self._send_remaining_count, finished_sending)
update_finished_set(output.finished_recving,
self._recv_remaining_count, finished_recving)

            # Aggregate kv_connector_stats from all workers.
            if aggregated_kv_connector_stats is None:
                # Use the first worker's kv_connector_stats as accumulator.
                aggregated_kv_connector_stats = output.kv_connector_stats
            elif kv_connector_stats := output.kv_connector_stats:
                assert isinstance(aggregated_kv_connector_stats,
                                  type(kv_connector_stats))
                aggregated_kv_connector_stats = \
                    aggregated_kv_connector_stats.aggregate(kv_connector_stats)

# select output of the worker specified by output_rank
output = outputs[output_rank]

output.kv_connector_output = KVConnectorOutput(
finished_sending=finished_sending or None,
finished_recving=finished_recving or None,
kv_connector_stats=aggregated_kv_connector_stats or None,
)

return output
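
A minimal usage sketch of the aggregation path above, assuming three TP
workers whose per-rank ModelRunnerOutputs have been collected into a
hypothetical `worker_outputs` list (this mirrors what
MultiprocExecutor.execute_model does with the real per-rank outputs):

    aggregator = KVOutputAggregator(world_size=3)
    aggregated = aggregator.aggregate(worker_outputs, output_rank=0)
    # Stats merged across ranks for the last interval, or None if no
    # worker reported any.
    stats = aggregated.kv_connector_output.kv_connector_stats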
22 changes: 21 additions & 1 deletion vllm/distributed/kv_transfer/kv_connector/v1/base.py
@@ -49,6 +49,8 @@
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorStats)
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
@@ -235,6 +237,12 @@ def shutdown(self):
"""
return None

def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]:
"""
Get the KV connector stats collected during the last interval.
"""
return None
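
    # A minimal sketch of how a worker-side connector might implement this
    # hook (`self._stats` below is a hypothetical accumulator attribute of
    # some KVConnectorStats subclass): hand back the collected stats and swap
    # in a fresh accumulator, which yields the reset-after-retrieval behavior
    # exercised by test_kv_connector_stats above.
    #
    #     def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]:
    #         if self._stats.is_empty():
    #             return None
    #         stats, self._stats = self._stats, type(self._stats)()
    #         return stats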

# ==============================
# Scheduler-side methods
# ==============================
@@ -365,4 +373,16 @@ def get_finished_count(self) -> Optional[int]:
int: expected sending or receiving completion count.
"""

        return None

@classmethod
def build_kv_connector_stats(
cls,
data: Optional[dict[str,
Any]] = None) -> Optional["KVConnectorStats"]:
"""
KVConnectorStats resolution method. This method allows dynamically
registered connectors to return their own KVConnectorStats object,
which can implement custom aggregation logic on the data dict.
"""
return None
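
    # A sketch of how a dynamically registered connector could override this
    # hook so aggregation resolves to its own stats type. FooKVConnectorStats
    # is hypothetical here, mirroring the test-only class defined in
    # test_nixl_connector.py above.
    #
    #     @classmethod
    #     def build_kv_connector_stats(
    #             cls, data: Optional[dict[str, Any]] = None
    #     ) -> Optional["KVConnectorStats"]:
    #         return (FooKVConnectorStats(data=data)
    #                 if data is not None else FooKVConnectorStats())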