Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions vllm/distributed/kv_transfer/kv_connector/v1/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,7 @@ def build_kv_connector_stats(
@classmethod
def build_prom_metrics(
cls,
vllm_config: "VllmConfig",
metric_types: dict[type["PromMetric"], type["PromMetricT"]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
Expand Down
8 changes: 6 additions & 2 deletions vllm/distributed/kv_transfer/kv_connector/v1/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from prometheus_client import Counter, Gauge, Histogram

from vllm.config.kv_transfer import KVTransferConfig
from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_transfer_state import has_kv_transfer_group
from vllm.logger import init_logger
Expand Down Expand Up @@ -117,10 +117,12 @@ class KVConnectorPromMetrics:

def __init__(
self,
vllm_config: VllmConfig,
metric_types: dict[type[PromMetric], type[PromMetricT]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
):
self._kv_transfer_config = vllm_config.kv_transfer_config
self._gauge_cls = metric_types[Gauge]
self._counter_cls = metric_types[Counter]
self._histogram_cls = metric_types[Histogram]
Expand Down Expand Up @@ -161,11 +163,12 @@ class KVConnectorPrometheus:

def __init__(
self,
kv_transfer_config: KVTransferConfig | None,
vllm_config: VllmConfig,
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
):
self.prom_metrics: KVConnectorPromMetrics | None = None
kv_transfer_config = vllm_config.kv_transfer_config
if kv_transfer_config and kv_transfer_config.kv_connector:
connector_cls = KVConnectorFactory.get_connector_class(kv_transfer_config)
metric_types = {
Expand All @@ -174,6 +177,7 @@ def __init__(
Histogram: self._histogram_cls,
}
self.prom_metrics = connector_cls.build_prom_metrics(
vllm_config,
metric_types,
labelnames,
per_engine_labelvalues,
Expand Down
110 changes: 88 additions & 22 deletions vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,19 @@

from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics,
KVConnectorStats,
PromMetric,
PromMetricT,
)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
Expand Down Expand Up @@ -72,6 +78,27 @@ def __setitem__(self, connector_id: str, stats: KVConnectorStats):
self.data[connector_id] = stats


class MultiKVConnectorPromMetrics(KVConnectorPromMetrics):
def __init__(
self,
vllm_config: "VllmConfig",
metric_types: dict[type[PromMetric], type[PromMetricT]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
prom_metrics: dict[str, KVConnectorPromMetrics],
):
super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)
self._prom_metrics = prom_metrics

def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
for connector_id, stats_data in transfer_stats_data.items():
assert connector_id in self._prom_metrics, (
f"{connector_id} is not contained in the list of registered connectors "
f"with Prometheus metrics support: {self._prom_metrics.keys()}"
)
self._prom_metrics[connector_id].observe(stats_data["data"], engine_idx)


class MultiConnector(KVConnectorBase_V1):
"""
A wrapper for using multiple KVConnectors at the same time.
Expand All @@ -84,19 +111,13 @@ class MultiConnector(KVConnectorBase_V1):

def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)

self._connectors: list[KVConnectorBase_V1] = []
self._ktc_kv_transfer_config = []
ktcs = self._kv_transfer_config.kv_connector_extra_config.get("connectors")
assert ktcs is not None
for ktc in ktcs:
temp_config = copy.copy(vllm_config)
engine_id = ktc.get("engine_id", self._kv_transfer_config.engine_id)
temp_config.kv_transfer_config = KVTransferConfig(
**ktc, engine_id=engine_id
)
self._connectors.append(
KVConnectorFactory.create_connector(temp_config, role)
)
for connector_cls, temp_config in self._get_connector_classes_and_configs(
vllm_config
):
self._connectors.append(connector_cls(temp_config, role))
self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config)

# A mapping from request id to the index of the connector chosen to
Expand All @@ -109,6 +130,32 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
# Propagated from scheduler to worker side via the connector metadata.
self._extra_async_saves: dict[str, int] = {}

@classmethod
def _get_connector_classes_and_configs(
cls, vllm_config: "VllmConfig"
) -> list[tuple[type[KVConnectorBaseType], "VllmConfig"]]:
assert vllm_config.kv_transfer_config is not None
ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"connectors"
)
assert ktcs is not None
ret: list[tuple[type[KVConnectorBaseType], VllmConfig]] = []
for ktc in ktcs:
temp_config = copy.copy(vllm_config)
engine_id = ktc.get("engine_id", vllm_config.kv_transfer_config.engine_id)
temp_config.kv_transfer_config = KVTransferConfig(
**ktc, engine_id=engine_id
)
ret.append(
(
KVConnectorFactory.get_connector_class(
temp_config.kv_transfer_config
),
temp_config,
)
)
return ret

def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
for c in self._connectors:
c.register_kv_caches(kv_caches)
Expand Down Expand Up @@ -295,18 +342,12 @@ def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
None if the connector does not require a specific layout.
"""
assert vllm_config.kv_transfer_config is not None
ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"connectors"
)
assert ktcs is not None
layouts: set[str] = set()
temp_vllm_config = copy.copy(vllm_config)
for ktc in ktcs:
kv_transfer_config = KVTransferConfig(**ktc)
temp_vllm_config.kv_transfer_config = kv_transfer_config
connector_cls = KVConnectorFactory.get_connector_class(kv_transfer_config)
for connector_cls, temp_config in cls._get_connector_classes_and_configs(
vllm_config
):
required_kvcache_layout = connector_cls.get_required_kvcache_layout(
temp_vllm_config
temp_config
)
if required_kvcache_layout is not None:
layouts.add(required_kvcache_layout)
Expand Down Expand Up @@ -342,3 +383,28 @@ def get_kv_connector_stats(self) -> MultiKVConnectorStats | None:
stats_by_connector = MultiKVConnectorStats()
stats_by_connector[c.__class__.__name__] = stats
return stats_by_connector

@classmethod
def build_prom_metrics(
cls,
vllm_config: "VllmConfig",
metric_types: dict[type["PromMetric"], type["PromMetricT"]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
) -> KVConnectorPromMetrics:
prom_metrics: dict[str, KVConnectorPromMetrics] = {}
for connector_cls, temp_config in cls._get_connector_classes_and_configs(
vllm_config
):
connector_prom = connector_cls.build_prom_metrics(
temp_config, metric_types, labelnames, per_engine_labelvalues
)
if connector_prom is not None:
prom_metrics[connector_cls.__name__] = connector_prom
return MultiKVConnectorPromMetrics(
vllm_config,
metric_types,
labelnames,
per_engine_labelvalues,
prom_metrics,
)
Original file line number Diff line number Diff line change
Expand Up @@ -262,11 +262,14 @@ def build_kv_connector_stats(
@classmethod
def build_prom_metrics(
cls,
vllm_config: VllmConfig,
metric_types: dict[type[PromMetric], type[PromMetricT]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
) -> KVConnectorPromMetrics:
return NixlPromMetrics(metric_types, labelnames, per_engine_labelvalues)
return NixlPromMetrics(
vllm_config, metric_types, labelnames, per_engine_labelvalues
)

def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
assert self.connector_worker is not None
Expand Down Expand Up @@ -1763,11 +1766,12 @@ def num_successful_transfers(self) -> int:
class NixlPromMetrics(KVConnectorPromMetrics):
def __init__(
self,
vllm_config: VllmConfig,
metric_types: dict[type[PromMetric], type[PromMetricT]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
):
super().__init__(metric_types, labelnames, per_engine_labelvalues)
super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)

buckets = [
0.001,
Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/metrics/loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def __init__(
vllm_config.speculative_config, labelnames, per_engine_labelvalues
)
self.kv_connector_prom = self._kv_connector_cls(
vllm_config.kv_transfer_config, labelnames, per_engine_labelvalues
vllm_config, labelnames, per_engine_labelvalues
)

#
Expand Down