Skip to content

Commit 82a3a09

Browse files
committed
Signed-off-by: Seiji Eicher <[email protected]>
1 parent dda13e3 commit 82a3a09

File tree

3 files changed

+95
-53
lines changed

3 files changed

+95
-53
lines changed

tests/v1/metrics/test_engine_logger_apis.py

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,27 +3,34 @@
33
import pytest
44

55
from vllm.v1.engine.async_llm import AsyncEngineArgs, AsyncLLM
6-
from vllm.v1.metrics.loggers import PrometheusStatLogger
6+
from vllm.v1.metrics.ray_wrappers import RayPrometheusStatLogger
7+
8+
DEFAULT_ENGINE_ARGS = AsyncEngineArgs(
9+
model="distilbert/distilgpt2",
10+
dtype="half",
11+
disable_log_stats=False,
12+
enforce_eager=True,
13+
)
714

815

916
@pytest.mark.asyncio
10-
async def test_async_llm_add_logger():
11-
# Minimal model config for test
12-
model_name = "distilbert/distilgpt2"
13-
dtype = "half"
14-
engine_args = AsyncEngineArgs(
15-
model=model_name,
16-
dtype=dtype,
17-
disable_log_stats=False,
18-
enforce_eager=True,
19-
)
20-
21-
# Force empty list to avoid default loggers
22-
engine = AsyncLLM.from_engine_args(engine_args, stat_loggers=[])
23-
24-
# Add PrometheusStatLogger and verify no exception is raised
25-
await engine.add_logger(PrometheusStatLogger)
26-
27-
# Verify that logger is present in the first DP rank
28-
assert len(engine.stat_loggers[0]) == 1
29-
assert isinstance(engine.stat_loggers[0][0], PrometheusStatLogger)
17+
async def test_async_llm_replace_default_loggers():
18+
# Empty stat_loggers removes default loggers
19+
engine = AsyncLLM.from_engine_args(DEFAULT_ENGINE_ARGS, stat_loggers=[])
20+
await engine.add_logger(RayPrometheusStatLogger)
21+
22+
# Verify that only this logger is present in shared loggers
23+
assert len(engine.logger_manager.shared_loggers) == 1
24+
assert isinstance(engine.logger_manager.shared_loggers[0],
25+
RayPrometheusStatLogger)
26+
27+
28+
@pytest.mark.asyncio
29+
async def test_async_llm_add_to_default_loggers():
30+
# Start with default loggers, including PrometheusStatLogger
31+
engine = AsyncLLM.from_engine_args(DEFAULT_ENGINE_ARGS)
32+
33+
# Add another PrometheusStatLogger subclass
34+
await engine.add_logger(RayPrometheusStatLogger)
35+
36+
assert len(engine.logger_manager.shared_loggers) == 2

vllm/v1/engine/async_llm.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@
3636
from vllm.v1.engine.parallel_sampling import ParentRequest
3737
from vllm.v1.engine.processor import Processor
3838
from vllm.v1.executor.abstract import Executor
39-
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
39+
from vllm.v1.metrics.loggers import (DpSharedStatLoggerFactory,
40+
StatLoggerFactory, StatLoggerManager)
4041
from vllm.v1.metrics.prometheus import shutdown_prometheus
4142
from vllm.v1.metrics.stats import IterationStats
4243

@@ -55,7 +56,8 @@ def __init__(
5556
use_cached_outputs: bool = False,
5657
log_requests: bool = True,
5758
start_engine_loop: bool = True,
58-
stat_loggers: Optional[list[StatLoggerFactory]] = None,
59+
stat_loggers: Optional[list[Union[StatLoggerFactory,
60+
DpSharedStatLoggerFactory]]] = None,
5961
client_addresses: Optional[dict[str, str]] = None,
6062
client_index: int = 0,
6163
) -> None:
@@ -144,7 +146,8 @@ def from_vllm_config(
144146
vllm_config: VllmConfig,
145147
start_engine_loop: bool = True,
146148
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
147-
stat_loggers: Optional[list[StatLoggerFactory]] = None,
149+
stat_loggers: Optional[list[Union[StatLoggerFactory,
150+
DpSharedStatLoggerFactory]]] = None,
148151
disable_log_requests: bool = False,
149152
disable_log_stats: bool = False,
150153
client_addresses: Optional[dict[str, str]] = None,
@@ -176,7 +179,8 @@ def from_engine_args(
176179
engine_args: AsyncEngineArgs,
177180
start_engine_loop: bool = True,
178181
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
179-
stat_loggers: Optional[list[StatLoggerFactory]] = None,
182+
stat_loggers: Optional[list[Union[StatLoggerFactory,
183+
DpSharedStatLoggerFactory]]] = None,
180184
) -> "AsyncLLM":
181185
"""Create an AsyncLLM from the EngineArgs."""
182186

@@ -596,19 +600,17 @@ async def collective_rpc(self,
596600
return await self.engine_core.collective_rpc_async(
597601
method, timeout, args, kwargs)
598602

599-
async def add_logger(self, logger_factory: StatLoggerFactory) -> None:
600-
if not self.log_stats:
603+
async def add_logger(
604+
self, logger_factory: Union[StatLoggerFactory,
605+
DpSharedStatLoggerFactory]
606+
) -> None:
607+
if self.logger_manager is None:
601608
raise RuntimeError(
602609
"Stat logging is disabled. Set `disable_log_stats=False` "
603-
"argument to enable.")
610+
"engine argument to enable.")
604611

605-
engine_num = self.vllm_config.parallel_config.data_parallel_size
606-
if len(self.stat_loggers) == 0:
607-
self.stat_loggers = [[] for _ in range(engine_num)]
612+
self.logger_manager.add_logger(logger_factory)
608613

609-
for i, logger_list in enumerate(self.stat_loggers):
610-
logger_list.append(logger_factory(self.vllm_config, i))
611-
612614
async def wait_for_requests_to_drain(self, drain_timeout: int = 300):
613615
"""Wait for all requests to be drained."""
614616
start_time = time.time()

vllm/v1/metrics/loggers.py

Lines changed: 52 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
logger = init_logger(__name__)
2121

2222
StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]
23+
DpSharedStatLoggerFactory = Callable[[VllmConfig, Optional[list[int]]],
24+
"PrometheusStatLogger"]
2325

2426

2527
class StatLoggerBase(ABC):
@@ -633,37 +635,67 @@ def __init__(
633635
self,
634636
vllm_config: VllmConfig,
635637
engine_idxs: Optional[list[int]] = None,
636-
custom_stat_loggers: Optional[list[StatLoggerFactory]] = None,
638+
custom_stat_loggers: Optional[list[Union[
639+
StatLoggerFactory, DpSharedStatLoggerFactory]]] = None,
637640
):
641+
"""
642+
Initializes the StatLoggerManager.
643+
644+
Args:
645+
vllm_config (VllmConfig): The configuration object for vLLM.
646+
engine_idxs (Optional[list[int]]): List of engine indices. If None,
647+
defaults to [0].
648+
custom_stat_loggers (Optional[list[Union[
649+
StatLoggerFactory, DpSharedStatLoggerFactory
650+
]]]):
651+
Optional list of custom stat logger factories to use. If None,
652+
default loggers are used.
653+
"""
638654
self.engine_idxs = engine_idxs if engine_idxs else [0]
655+
self.vllm_config = vllm_config
639656

640-
factories: list[StatLoggerFactory]
657+
factories: list[StatLoggerFactory] = []
658+
shared_logger_factories: list[DpSharedStatLoggerFactory] = []
641659
if custom_stat_loggers is not None:
642-
factories = custom_stat_loggers
660+
for factory in custom_stat_loggers:
661+
if isinstance(factory, type) and issubclass(
662+
factory, PrometheusStatLogger):
663+
shared_logger_factories.append(factory) # type: ignore
664+
else:
665+
factories.append(factory) # type: ignore
643666
else:
644-
factories = []
645667
if logger.isEnabledFor(logging.INFO):
646668
factories.append(LoggingStatLogger)
647669

670+
shared_logger_factories.append(PrometheusStatLogger)
671+
672+
self.shared_loggers = []
673+
if len(shared_logger_factories) > 0:
674+
for factory in shared_logger_factories:
675+
self.shared_loggers.append(factory(vllm_config, engine_idxs))
676+
648677
# engine_idx: StatLogger
649678
self.per_engine_logger_dict: dict[int, list[StatLoggerBase]] = {}
650-
prometheus_factory = PrometheusStatLogger
651679
for engine_idx in self.engine_idxs:
652680
loggers: list[StatLoggerBase] = []
653681
for logger_factory in factories:
654-
# If we get a custom prometheus logger, use that
655-
# instead. This is typically used for the ray case.
656-
if (isinstance(logger_factory, type)
657-
and issubclass(logger_factory, PrometheusStatLogger)):
658-
prometheus_factory = logger_factory
659-
continue
660682
loggers.append(logger_factory(vllm_config,
661683
engine_idx)) # type: ignore
662684
self.per_engine_logger_dict[engine_idx] = loggers
663685

664-
# For Prometheus, need to share the metrics between EngineCores.
665-
# Each EngineCore's metrics are expressed as a unique label.
666-
self.prometheus_logger = prometheus_factory(vllm_config, engine_idxs)
686+
def add_logger(
687+
self, logger_factory: Union[StatLoggerFactory,
688+
DpSharedStatLoggerFactory]
689+
) -> None:
690+
if (isinstance(logger_factory, type)
691+
and issubclass(logger_factory, PrometheusStatLogger)):
692+
self.shared_loggers.append(
693+
logger_factory(self.vllm_config,
694+
self.engine_idxs)) # type: ignore
695+
else:
696+
for engine_idx, logger_list in self.per_engine_logger_dict.items():
697+
logger_list.append(logger_factory(self.vllm_config,
698+
engine_idx)) # type: ignore
667699

668700
def record(
669701
self,
@@ -678,17 +710,18 @@ def record(
678710
for logger in per_engine_loggers:
679711
logger.record(scheduler_stats, iteration_stats, engine_idx)
680712

681-
self.prometheus_logger.record(scheduler_stats, iteration_stats,
682-
engine_idx)
713+
for logger in self.shared_loggers:
714+
logger.record(scheduler_stats, iteration_stats, engine_idx)
683715

684716
def log(self):
685717
for per_engine_loggers in self.per_engine_logger_dict.values():
686718
for logger in per_engine_loggers:
687719
logger.log()
688720

689721
def log_engine_initialized(self):
690-
self.prometheus_logger.log_engine_initialized()
722+
for shared_logger in self.shared_loggers:
723+
shared_logger.log_engine_initialized()
691724

692725
for per_engine_loggers in self.per_engine_logger_dict.values():
693-
for logger in per_engine_loggers:
694-
logger.log_engine_initialized()
726+
for per_engine_logger in per_engine_loggers:
727+
per_engine_logger.log_engine_initialized()

0 commit comments

Comments
 (0)