2020logger = init_logger (__name__ )
2121
2222StatLoggerFactory = Callable [[VllmConfig , int ], "StatLoggerBase" ]
23+ DpSharedStatLoggerFactory = Callable [[VllmConfig , Optional [list [int ]]],
24+ "PrometheusStatLogger" ]
2325
2426
2527class StatLoggerBase (ABC ):
@@ -633,37 +635,67 @@ def __init__(
633635 self ,
634636 vllm_config : VllmConfig ,
635637 engine_idxs : Optional [list [int ]] = None ,
636- custom_stat_loggers : Optional [list [StatLoggerFactory ]] = None ,
638+ custom_stat_loggers : Optional [list [Union [
639+ StatLoggerFactory , DpSharedStatLoggerFactory ]]] = None ,
637640 ):
641+ """
642+ Initializes the StatLoggerManager.
643+
644+ Args:
645+ vllm_config (VllmConfig): The configuration object for vLLM.
646+ engine_idxs (Optional[list[int]]): List of engine indices. If None,
647+ defaults to [0].
648+ custom_stat_loggers (Optional[list[Union[
649+ StatLoggerFactory, DpSharedStatLoggerFactory
650+ ]]]):
651+ Optional list of custom stat logger factories to use. If None,
652+ default loggers are used.
653+ """
638654 self .engine_idxs = engine_idxs if engine_idxs else [0 ]
655+ self .vllm_config = vllm_config
639656
640- factories : list [StatLoggerFactory ]
657+ factories : list [StatLoggerFactory ] = []
658+ shared_logger_factories : list [DpSharedStatLoggerFactory ] = []
641659 if custom_stat_loggers is not None :
642- factories = custom_stat_loggers
660+ for factory in custom_stat_loggers :
661+ if isinstance (factory , type ) and issubclass (
662+ factory , PrometheusStatLogger ):
663+ shared_logger_factories .append (factory ) # type: ignore
664+ else :
665+ factories .append (factory ) # type: ignore
643666 else :
644- factories = []
645667 if logger .isEnabledFor (logging .INFO ):
646668 factories .append (LoggingStatLogger )
647669
670+ shared_logger_factories .append (PrometheusStatLogger )
671+
672+ self .shared_loggers = []
673+ if len (shared_logger_factories ) > 0 :
674+ for factory in shared_logger_factories :
675+ self .shared_loggers .append (factory (vllm_config , engine_idxs ))
676+
648677 # engine_idx: StatLogger
649678 self .per_engine_logger_dict : dict [int , list [StatLoggerBase ]] = {}
650- prometheus_factory = PrometheusStatLogger
651679 for engine_idx in self .engine_idxs :
652680 loggers : list [StatLoggerBase ] = []
653681 for logger_factory in factories :
654- # If we get a custom prometheus logger, use that
655- # instead. This is typically used for the ray case.
656- if (isinstance (logger_factory , type )
657- and issubclass (logger_factory , PrometheusStatLogger )):
658- prometheus_factory = logger_factory
659- continue
660682 loggers .append (logger_factory (vllm_config ,
661683 engine_idx )) # type: ignore
662684 self .per_engine_logger_dict [engine_idx ] = loggers
663685
664- # For Prometheus, need to share the metrics between EngineCores.
665- # Each EngineCore's metrics are expressed as a unique label.
666- self .prometheus_logger = prometheus_factory (vllm_config , engine_idxs )
686+ def add_logger (
687+ self , logger_factory : Union [StatLoggerFactory ,
688+ DpSharedStatLoggerFactory ]
689+ ) -> None :
690+ if (isinstance (logger_factory , type )
691+ and issubclass (logger_factory , PrometheusStatLogger )):
692+ self .shared_loggers .append (
693+ logger_factory (self .vllm_config ,
694+ self .engine_idxs )) # type: ignore
695+ else :
696+ for engine_idx , logger_list in self .per_engine_logger_dict .items ():
697+ logger_list .append (logger_factory (self .vllm_config ,
698+ engine_idx )) # type: ignore
667699
668700 def record (
669701 self ,
@@ -678,17 +710,18 @@ def record(
678710 for logger in per_engine_loggers :
679711 logger .record (scheduler_stats , iteration_stats , engine_idx )
680712
681- self .prometheus_logger . record ( scheduler_stats , iteration_stats ,
682- engine_idx )
713+ for logger in self .shared_loggers :
714+ logger . record ( scheduler_stats , iteration_stats , engine_idx )
683715
684716 def log (self ):
685717 for per_engine_loggers in self .per_engine_logger_dict .values ():
686718 for logger in per_engine_loggers :
687719 logger .log ()
688720
689721 def log_engine_initialized (self ):
690- self .prometheus_logger .log_engine_initialized ()
722+ for shared_logger in self .shared_loggers :
723+ shared_logger .log_engine_initialized ()
691724
692725 for per_engine_loggers in self .per_engine_logger_dict .values ():
693- for logger in per_engine_loggers :
694- logger .log_engine_initialized ()
726+ for per_engine_logger in per_engine_loggers :
727+ per_engine_logger .log_engine_initialized ()
0 commit comments