From 491f0aa92393b9b66e36fa4c95c83ca6ed6b8d32 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Fri, 13 Jun 2025 10:02:09 -0700
Subject: [PATCH 01/12] [DP] Support external DP Load Balancer mode

wip

Signed-off-by: Nick Hill
---
 vllm/config.py                |  13 +++
 vllm/engine/arg_utils.py      |  28 +++-
 vllm/entrypoints/cli/serve.py | 107 +++++++++++------
 vllm/v1/engine/coordinator.py |  46 ++++++--
 vllm/v1/engine/core.py        | 128 ++++++++++++++++-----
 vllm/v1/engine/core_client.py | 208 +++++++++++++++++++++-------------
 vllm/v1/utils.py              |  36 +++++-
 7 files changed, 402 insertions(+), 164 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index 96ea47a0dce3..e72478b20a1f 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1773,6 +1773,10 @@ class ParallelConfig:
     """Port of the data parallel master."""
     data_parallel_backend: str = "mp"
     """Backend to use for data parallel, either "mp" or "ray"."""
+    data_parallel_external_lb: bool = False
+    """Whether to use "external" DP LB mode. Applies only to online serving
+    and when data_parallel_size > 1. Set implicitly when
+    data_parallel_rank is provided explicitly to vllm serve."""
     enable_expert_parallel: bool = False
     """Use expert parallelism instead of tensor parallelism for MoE layers."""
     max_parallel_loading_workers: Optional[int] = None
@@ -1900,6 +1904,11 @@ def __post_init__(self) -> None:
         if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
             # Data parallel was specified in the engine args.
             self.data_parallel_master_port = get_open_port()
+
+            if not (0 <= self.data_parallel_rank < self.data_parallel_size):
+                raise ValueError(
+                    f"data_parallel_rank ({self.data_parallel_rank})"
+                    f" must be in the range [0, {self.data_parallel_size})")
         else:
             # Otherwise fall back to env vars (e.g. for offline SPMD case).
             self.data_parallel_size = envs.VLLM_DP_SIZE
@@ -1908,6 +1917,10 @@ def __post_init__(self) -> None:
             self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
             self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
 
+            if self.data_parallel_external_lb:
+                raise ValueError("data_parallel_external_lb can only "
+                                 "be set when data_parallel_size > 1")
+
         if self.distributed_executor_backend == "external_launcher":
             import os
             os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 9d1008b6b350..6d0b20117d4a 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -315,6 +315,7 @@ class EngineArgs:
     pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
     tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
     data_parallel_size: int = ParallelConfig.data_parallel_size
+    data_parallel_rank: Optional[int] = None
     data_parallel_size_local: Optional[int] = None
     data_parallel_address: Optional[str] = None
     data_parallel_rpc_port: Optional[int] = None
@@ -642,6 +643,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                                     **parallel_kwargs["tensor_parallel_size"])
         parallel_group.add_argument("--data-parallel-size", "-dp",
                                     **parallel_kwargs["data_parallel_size"])
+        parallel_group.add_argument(
+            '--data-parallel-rank',
+            '-dpn',
+            type=int,
+            help='Data parallel rank of this instance. '
+            'When set, enables external load balancer mode.')
         parallel_group.add_argument('--data-parallel-size-local',
                                     '-dpl',
                                     type=int,
@@ -1096,10 +1103,17 @@ def create_engine_config(
             # but we should not do this here.
             placement_group = ray.util.get_current_placement_group()
 
-        # Local DP size defaults to global DP size if not set.
- data_parallel_size_local = self.data_parallel_size if ( - self.data_parallel_size_local - is None) else self.data_parallel_size_local + data_parallel_external_lb = self.data_parallel_rank is not None + if data_parallel_external_lb: + assert self.data_parallel_size_local in (1, None), ( + "data_parallel_size_local must be 1 when data_parallel_rank " + "is set") + data_parallel_size_local = 1 + elif self.data_parallel_size_local is not None: + data_parallel_size_local = self.data_parallel_size_local + else: + # Local DP size defaults to global DP size if not set. + data_parallel_size_local = self.data_parallel_size # DP address, used in multi-node case for torch distributed group # and ZMQ sockets. @@ -1124,16 +1138,16 @@ def create_engine_config( self.data_parallel_rpc_port is not None) else ParallelConfig.data_parallel_rpc_port - data_parallel_backend = self.data_parallel_backend - parallel_config = ParallelConfig( pipeline_parallel_size=self.pipeline_parallel_size, tensor_parallel_size=self.tensor_parallel_size, data_parallel_size=self.data_parallel_size, + data_parallel_rank=self.data_parallel_rank or 0, + data_parallel_external_lb=data_parallel_external_lb, data_parallel_size_local=data_parallel_size_local, data_parallel_master_ip=data_parallel_address, data_parallel_rpc_port=data_parallel_rpc_port, - data_parallel_backend=data_parallel_backend, + data_parallel_backend=self.data_parallel_backend, enable_expert_parallel=self.enable_expert_parallel, max_parallel_loading_workers=self.max_parallel_loading_workers, disable_custom_all_reduce=self.disable_custom_all_reduce, diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 897c222a3ff5..35b3942d506e 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -5,6 +5,7 @@ import os import signal import sys +from typing import Optional import uvloop import zmq @@ -21,7 +22,8 @@ from vllm.executor.multiproc_worker_utils import _add_prefix from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext -from vllm.utils import FlexibleArgumentParser, get_tcp_uri, zmq_socket_ctx +from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path, + get_tcp_uri, zmq_socket_ctx) from vllm.v1.engine.coordinator import DPCoordinator from vllm.v1.engine.core import EngineCoreProc from vllm.v1.engine.core_client import CoreEngineProcManager @@ -148,7 +150,7 @@ def signal_handler(signum, frame): start_index=args.data_parallel_start_rank, local_start_index=0, vllm_config=vllm_config, - on_head_node=False, + local_client=False, handshake_address=handshake_address, executor_class=Executor.get_class(vllm_config), log_stats=not engine_args.disable_log_stats, @@ -194,20 +196,29 @@ def run_multi_api_server(args: argparse.Namespace): parallel_config = vllm_config.parallel_config - assert parallel_config.data_parallel_rank == 0 - dp_size = parallel_config.data_parallel_size + dp_rank = parallel_config.data_parallel_rank local_engine_count = parallel_config.data_parallel_size_local host = parallel_config.data_parallel_master_ip - local_only = local_engine_count == dp_size + external_dp_lb = parallel_config.data_parallel_external_lb + + # The handshake socket can be IPC if this front-end is colocated + # with all engines. + handshake_local_only = local_engine_count == dp_size + + # client_local_only = True for cases where this front-end + # sends requests only to colocated engines. 
+ client_local_only = handshake_local_only or external_dp_lb + + assert external_dp_lb or dp_rank == 0 # Set up input and output addresses. input_addresses = [ - get_engine_client_zmq_addr(local_only, host) + get_engine_client_zmq_addr(client_local_only, host) for _ in range(num_api_servers) ] output_addresses = [ - get_engine_client_zmq_addr(local_only, host) + get_engine_client_zmq_addr(client_local_only, host) for _ in range(num_api_servers) ] @@ -217,16 +228,32 @@ def run_multi_api_server(args: argparse.Namespace): ) # Set up coordinator for dp > 1. + # This runs only with the rank 0 engine, or when not colocated with any + # engines (in which case the value of dp_rank here will also be 0). coordinator = None stats_update_address = None - if dp_size > 1: + if dp_size > 1 and dp_rank == 0: coordinator = DPCoordinator(parallel_config) addresses.coordinator_input, addresses.coordinator_output = ( coordinator.get_engine_socket_addresses()) + addresses.frontend_stats_publish_address = ( + coordinator.get_stats_publish_address()) stats_update_address = coordinator.get_stats_publish_address() logger.info("Started DP Coordinator process (PID: %d)", coordinator.proc.pid) + api_server_manager: Optional[APIServerProcessManager] = None + # Construct common args for the APIServerProcessManager up-front. + api_server_manager_kwargs = dict( + target_server_fn=run_api_server_worker_proc, + listen_address=listen_address, + sock=sock, + args=args, + num_servers=num_api_servers, + input_addresses=input_addresses, + output_addresses=output_addresses, + stats_update_address=stats_update_address) + if parallel_config.data_parallel_backend == "ray": logger.info("Starting ray-based data parallel backend") @@ -238,14 +265,7 @@ def run_multi_api_server(args: argparse.Namespace): ) # Start API servers using the manager api_server_manager = APIServerProcessManager( - target_server_fn=run_api_server_worker_proc, - listen_address=listen_address, - sock=sock, - args=args, - num_servers=num_api_servers, - input_addresses=input_addresses, - output_addresses=output_addresses, - stats_update_address=stats_update_address) + **api_server_manager_kwargs) wait_for_completion_or_failure(api_server_manager=api_server_manager, engine_manager=engine_actor_manager, @@ -253,9 +273,17 @@ def run_multi_api_server(args: argparse.Namespace): return handshake_address = get_engine_client_zmq_addr( - local_only, host, parallel_config.data_parallel_rpc_port) + handshake_local_only, host, parallel_config.data_parallel_rpc_port) + + if external_dp_lb and dp_rank > 0: + assert not handshake_local_only + local_handshake_address = get_open_zmq_ipc_path() + client_handshake_address = local_handshake_address + else: + local_handshake_address = handshake_address + client_handshake_address = None - with zmq_socket_ctx(handshake_address, zmq.ROUTER, + with zmq_socket_ctx(local_handshake_address, zmq.ROUTER, bind=True) as handshake_socket: # Start local engines. 
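To make the socket-locality rules above concrete: the handshake socket may use IPC only when every engine is colocated with this front-end, while the per-request input/output sockets may additionally use IPC in external LB mode, since each front-end then only ever sends requests to its colocated engine. A minimal standalone sketch of how the two flags combine (illustration only, not part of the patch; names mirror the local variables above):

    def socket_locality(dp_size: int, local_engine_count: int,
                        external_dp_lb: bool) -> tuple[bool, bool]:
        # IPC sockets only work between processes on the same host.
        handshake_local_only = local_engine_count == dp_size
        # In external LB mode, request traffic stays on-host even though
        # the startup handshake may be with a remote rank 0 front-end.
        client_local_only = handshake_local_only or external_dp_lb
        return handshake_local_only, client_local_only

    # Internal LB spanning two nodes: both kinds of socket must be TCP.
    assert socket_locality(4, 2, False) == (False, False)
    # External LB, one engine per node: TCP handshake, IPC request sockets.
    assert socket_locality(4, 1, True) == (False, True)
    # Fully colocated: IPC for everything.
    assert socket_locality(2, 2, False) == (True, True)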
@@ -268,27 +296,31 @@ def run_multi_api_server(args: argparse.Namespace): executor_class=Executor.get_class(vllm_config), log_stats=not engine_args.disable_log_stats, handshake_address=handshake_address, - on_head_node=True, + client_handshake_address=client_handshake_address, + local_client=True, local_engine_count=local_engine_count, - start_index=0, + start_index=dp_rank, local_start_index=0) - # Start API servers using the manager - api_server_manager = APIServerProcessManager( - target_server_fn=run_api_server_worker_proc, - listen_address=listen_address, - sock=sock, - args=args, - num_servers=num_api_servers, - input_addresses=input_addresses, - output_addresses=output_addresses, - stats_update_address=stats_update_address) + # For dp ranks > 0 in external DP LB mode, we must delay the + # start of the API servers until the local engine is started, + # since we get the front-end stats update address from the coordinator + # via the handshake with the local engine. + if dp_rank == 0 or not external_dp_lb: + # Start API servers using the manager. + api_server_manager = APIServerProcessManager( + **api_server_manager_kwargs) # Wait for engine handshakes to complete. - core_engines = [ - CoreEngine(index=i, local=(i < local_engine_count)) - for i in range(dp_size) - ] + if external_dp_lb and dp_rank > 0: + assert local_engine_count == 1 + core_engines = [CoreEngine(index=dp_rank, local=True)] + else: + core_engines = [ + CoreEngine(index=i, local=(i < local_engine_count)) + for i in range(dp_size) + ] + wait_for_engine_startup( handshake_socket, addresses, @@ -299,6 +331,13 @@ def run_multi_api_server(args: argparse.Namespace): coordinator.proc if coordinator else None, ) + # Start API servers now if they weren't already started. + if api_server_manager is None: + api_server_manager_kwargs["stats_update_address"] = ( + addresses.frontend_stats_publish_address) + api_server_manager = APIServerProcessManager( + **api_server_manager_kwargs) + # Wait for API servers wait_for_completion_or_failure(api_server_manager=api_server_manager, engine_manager=local_engine_manager, diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py index 031e9b85f24c..3a2fa00e6ae6 100644 --- a/vllm/v1/engine/coordinator.py +++ b/vllm/v1/engine/coordinator.py @@ -10,7 +10,7 @@ from vllm.config import ParallelConfig from vllm.logger import init_logger -from vllm.utils import get_mp_context, get_open_zmq_ipc_path, make_zmq_socket +from vllm.utils import get_mp_context, make_zmq_socket from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType from vllm.v1.serial_utils import MsgpackDecoder from vllm.v1.utils import get_engine_client_zmq_addr, shutdown @@ -48,20 +48,33 @@ class DPCoordinator: Engines will move into running state when receiving a new request or START_DP_WAVE message. + + Note that when deployed in External LB mode, no stats will be published by + the engines and thus updates will only be sent to front-ends when the + request wave / running state changes. """ def __init__(self, parallel_config: ParallelConfig): - # Assume coordinator is colocated with front-end procs. 
- front_publish_address = get_open_zmq_ipc_path() - dp_size = parallel_config.data_parallel_size assert dp_size > 1, "Coordinator only used for data parallel" - local_only = dp_size == parallel_config.data_parallel_size_local host = parallel_config.data_parallel_master_ip - back_publish_address = get_engine_client_zmq_addr(local_only, host) - back_output_address = get_engine_client_zmq_addr(local_only, host) + external_lb = parallel_config.data_parallel_external_lb + + # Assume coordinator is colocated with front-end procs when not in + # external DP LB mode. + front_publish_address = get_engine_client_zmq_addr( + local_only=not external_lb, host=host) + + local_only_eng = dp_size == parallel_config.data_parallel_size_local + back_publish_address = get_engine_client_zmq_addr(local_only_eng, host) + back_output_address = get_engine_client_zmq_addr(local_only_eng, host) + + # When in external LB mode, load stats aren't published, only changes + # to request wave / running state, so we don't need to rate-limit the + # updates to the front-end proc(s). + min_stats_update_interval_ms = 0 if external_lb else 100 context = get_mp_context() self.proc: multiprocessing.Process = context.Process( @@ -72,6 +85,7 @@ def __init__(self, parallel_config: ParallelConfig): "front_publish_address": front_publish_address, "back_output_address": back_output_address, "back_publish_address": back_publish_address, + "min_stats_update_interval_ms": min_stats_update_interval_ms, }, daemon=True) self.proc.start() @@ -100,12 +114,16 @@ def __init__(self): class CoordinatorProc: - def __init__(self, engine_count: int): + def __init__(self, + engine_count: int, + min_stats_update_interval_ms: int = 100): self.ctx = zmq.Context() self.engines = [EngineState() for _ in range(engine_count)] + self.stats_update_interval_ms = min_stats_update_interval_ms + self.current_wave = 0 self.engines_running = False self.stats_changed = False @@ -116,8 +134,11 @@ def run_coordinator( front_publish_address: str, back_output_address: str, back_publish_address: str, + min_stats_update_interval_ms: int = 100, ): - coordinator = CoordinatorProc(engine_count=engine_count) + coordinator = CoordinatorProc( + engine_count=engine_count, + min_stats_update_interval_ms=min_stats_update_interval_ms) try: coordinator.process_input_socket( front_publish_address, @@ -156,9 +177,10 @@ def process_input_socket(self, front_publish_address: str, last_publish_time = 0 while True: elapsed = int(time.time() * 1000) - last_publish_time - # Send at 100 ms interval if the stats have changed, - # or otherwise every 3 seconds. - wait_for = 100 if self.stats_changed else 3000 + # Send at stats_update_interval_ms interval if the stats have + # changed, or otherwise every 4 seconds. + wait_for = (self.stats_update_interval_ms + if self.stats_changed else 4000) events = poller.poll(timeout=max(0, wait_for - elapsed)) if not events: # Poller timeout - publish current stats to front-ends. 
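The coordinator's publish loop above boils down to: publish promptly once stats change (after min_stats_update_interval_ms), but still publish a keepalive snapshot every 4 seconds so that newly connected front-ends converge. A rough sketch of the timeout computation, assuming the same variable meanings as in process_input_socket():

    import time

    def next_poll_timeout_ms(last_publish_ms: int, stats_changed: bool,
                             min_interval_ms: int) -> int:
        # min_interval_ms is 100 in internal LB mode; 0 in external LB
        # mode, where only wave / running-state changes are published.
        elapsed = int(time.time() * 1000) - last_publish_ms
        wait_for = min_interval_ms if stats_changed else 4000
        return max(0, wait_for - elapsed)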
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 453ed364dc81..4a468680793e 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -367,10 +367,11 @@ class EngineCoreProc(EngineCore):
     def __init__(
         self,
         vllm_config: VllmConfig,
-        on_head_node: bool,
+        local_client: bool,
         handshake_address: str,
         executor_class: type[Executor],
         log_stats: bool,
+        client_handshake_address: Optional[str] = None,
         engine_index: int = 0,
     ):
         self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
@@ -383,12 +384,21 @@ def __init__(
         identity = self.engine_index.to_bytes(length=2, byteorder="little")
         self.engines_running = False
 
-        with self._perform_handshake(handshake_address, identity, on_head_node,
-                                     vllm_config) as addresses:
+        with self._perform_handshakes(handshake_address, identity,
+                                      local_client, vllm_config,
+                                      client_handshake_address) as addresses:
             self.client_count = len(addresses.outputs)
 
             # Set up data parallel environment.
             self.has_coordinator = addresses.coordinator_output is not None
+            self.frontend_stats_publish_address = (
+                addresses.frontend_stats_publish_address)
+            # Only publish request queue stats to coordinator for "internal"
+            # LB mode.
+            self.publish_dp_lb_stats = (
+                self.has_coordinator
+                and not vllm_config.parallel_config.data_parallel_external_lb)
+
             self._init_data_parallel(vllm_config)
 
             super().__init__(vllm_config, executor_class, log_stats,
@@ -414,45 +424,102 @@ def __init__(
         self.output_thread.start()
 
     @contextmanager
-    def _perform_handshake(
-        self, handshake_address: str, identity: bytes, on_head_node: bool,
-        vllm_config: VllmConfig
+    def _perform_handshakes(
+        self,
+        handshake_address: str,
+        identity: bytes,
+        local_client: bool,
+        vllm_config: VllmConfig,
+        client_handshake_address: Optional[str],
     ) -> Generator[EngineZmqAddresses, None, None]:
+        """
+        Perform startup handshakes.
+
+        For DP=1 or offline mode, this is with the colocated front-end
+        process.
+
+        For DP>1 with internal load-balancing, this is with the shared
+        front-end process, which may reside on a different node.
+
+        For DP>1 with external load-balancing, two handshakes are performed,
+        except for the rank 0 engine, which only needs the first:
+        - With the rank 0 front-end process, from which the engine retrieves
+          the DP Coordinator ZMQ addresses and DP process group address.
+        - With the colocated front-end process, from which the engine
+          retrieves the client input/output socket addresses.
+
+        Here, "front-end" process can mean the process containing the engine
+        core client (which is the API server process in the case where the
+        API server is not scaled out), OR the launcher process running the
+        run_multi_api_server() function in serve.py.
+ """ input_ctx = zmq.Context() - with make_zmq_socket(input_ctx, + is_local = local_client and client_handshake_address is None + handshake = self._perform_handshake(input_ctx, handshake_address, + identity, is_local, vllm_config, + vllm_config.parallel_config) + if client_handshake_address is None: + with handshake as addresses: + yield addresses + else: + local_handshake = self._perform_handshake( + input_ctx, client_handshake_address, identity, local_client, + vllm_config) + with handshake as addresses, local_handshake as client_addresses: + addresses.inputs = client_addresses.inputs + addresses.outputs = client_addresses.outputs + yield addresses + + # Update config which may have changed from the handshake + vllm_config.__post_init__() + + @contextmanager + def _perform_handshake( + self, + ctx: zmq.Context, + handshake_address: str, + identity: bytes, + local_client: bool, + vllm_config: VllmConfig, + parallel_config_to_update: Optional[ParallelConfig] = None, + ) -> Generator[EngineZmqAddresses, None, None]: + with make_zmq_socket(ctx, handshake_address, zmq.DEALER, identity=identity, linger=5000, bind=False) as handshake_socket: # Register engine with front-end. - addresses = self.startup_handshake(handshake_socket, on_head_node, - vllm_config.parallel_config) - - # Update config which may have changed from the handshake - vllm_config.__post_init__() - + addresses = self.startup_handshake(handshake_socket, local_client, + parallel_config_to_update) yield addresses # Send ready message. num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks + # We pass back the coordinator stats update address here for the + # external LB case for our colocated front-end to use (coordinator + # only runs with rank 0). + dp_stats_address = self.frontend_stats_publish_address handshake_socket.send( msgspec.msgpack.encode({ "status": "READY", - "local": on_head_node, + "local": local_client, "num_gpu_blocks": num_gpu_blocks, + "dp_stats_address": dp_stats_address, })) @staticmethod def startup_handshake( - handshake_socket: zmq.Socket, on_head_node: bool, - parallel_config: ParallelConfig) -> EngineZmqAddresses: + handshake_socket: zmq.Socket, + local_client: bool, + parallel_config: Optional[ParallelConfig] = None, + ) -> EngineZmqAddresses: # Send registration message. handshake_socket.send( msgspec.msgpack.encode({ "status": "HELLO", - "local": on_head_node, + "local": local_client, })) # Receive initialization message. @@ -466,9 +533,9 @@ def startup_handshake( init_bytes, type=EngineHandshakeMetadata) logger.debug("Received init message: %s", init_message) - received_parallel_config = init_message.parallel_config - for key, value in received_parallel_config.items(): - setattr(parallel_config, key, value) + if parallel_config is not None: + for key, value in init_message.parallel_config.items(): + setattr(parallel_config, key, value) return init_message.addresses @@ -749,12 +816,12 @@ class DPEngineCoreProc(EngineCoreProc): def __init__( self, vllm_config: VllmConfig, - on_head_node: bool, + local_client: bool, handshake_address: str, executor_class: type[Executor], log_stats: bool, + client_handshake_address: Optional[str] = None, ): - self._decorate_logs() # Counts forward-passes of the model so that we can synchronize @@ -765,8 +832,9 @@ def __init__( # Initialize the engine. 
dp_rank = vllm_config.parallel_config.data_parallel_rank - super().__init__(vllm_config, on_head_node, handshake_address, - executor_class, log_stats, dp_rank) + super().__init__(vllm_config, local_client, handshake_address, + executor_class, log_stats, client_handshake_address, + dp_rank) def _decorate_logs(self): # Add process-specific prefix to stdout and stderr before @@ -799,6 +867,7 @@ def _init_data_parallel(self, vllm_config: VllmConfig): from vllm.platforms import current_platform device_control_env_var = current_platform.device_control_env_var world_size = vllm_config.parallel_config.world_size + # Set CUDA_VISIBLE_DEVICES or equivalent. os.environ[device_control_env_var] = ",".join( str(current_platform.device_id_to_physical_device_id(i)) for i in range(local_dp_rank * world_size, (local_dp_rank + 1) * @@ -839,7 +908,7 @@ def _handle_client_request(self, request_type: EngineCoreRequestType, super()._handle_client_request(request_type, request) def _maybe_publish_request_counts(self): - if not self.has_coordinator: + if not self.publish_dp_lb_stats: return # Publish our request counts (if they've changed). @@ -910,7 +979,7 @@ class DPEngineCoreActor(DPEngineCoreProc): def __init__( self, vllm_config: VllmConfig, - on_head_node: bool, + local_client: bool, addresses: EngineZmqAddresses, executor_class: type[Executor], log_stats: bool, @@ -927,15 +996,16 @@ def __init__( # data parallel groups. del os.environ['CUDA_VISIBLE_DEVICES'] - super().__init__(vllm_config, on_head_node, "", executor_class, + super().__init__(vllm_config, local_client, "", executor_class, log_stats) def _decorate_logs(self): pass @contextmanager - def _perform_handshake(self, handshake_address: str, identity: bytes, - on_head_node: bool, vllm_config: VllmConfig): + def _perform_handshakes(self, handshake_address: str, identity: bytes, + local_client: bool, vllm_config: VllmConfig, + client_handshake_address: Optional[str]): """ For Ray, we don't need to actually perform handshake. All addresses information is known before the actor creation. 
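The handshake topology described in the _perform_handshakes() docstring above can be summarized mode by mode. The helper below is an illustrative restatement (not code from the patch) of which handshakes a given DP engine performs:

    def handshakes_for(dp_size: int, dp_rank: int,
                       external_dp_lb: bool) -> list[str]:
        if dp_size == 1 or not external_dp_lb:
            # DP=1, offline mode, or internal LB: a single handshake with
            # the (possibly remote) front-end that launched the engines.
            return ["launching front-end"]
        if dp_rank == 0:
            # Rank 0's colocated front-end also supplies the coordinator
            # addresses, so one handshake suffices.
            return ["rank-0 front-end"]
        # Other ranks handshake twice: coordinator / process group info
        # comes from the rank 0 front-end, client input/output socket
        # addresses from the colocated front-end.
        return ["rank-0 front-end", "colocated front-end"]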
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 856310df5888..cc410dab19d0 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -7,7 +7,7 @@ import uuid import weakref from abc import ABC, abstractmethod -from collections import deque +from collections import defaultdict, deque from collections.abc import Awaitable, Sequence from concurrent.futures import Future from dataclasses import dataclass @@ -21,8 +21,8 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.utils import (get_open_zmq_inproc_path, make_zmq_socket, - zmq_socket_ctx) +from vllm.utils import (get_open_zmq_inproc_path, get_open_zmq_ipc_path, + make_zmq_socket, zmq_socket_ctx) from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, UtilityOutput) from vllm.v1.engine.coordinator import DPCoordinator @@ -84,14 +84,20 @@ def make_async_mp_client( client_addresses: Optional[dict[str, str]] = None, client_index: int = 0, ) -> "MPClient": - if vllm_config.parallel_config.data_parallel_size > 1: - if vllm_config.parallel_config.data_parallel_backend == "ray": - return RayDPClient(vllm_config, executor_class, log_stats, - client_addresses, client_index) - return DPAsyncMPClient(vllm_config, executor_class, log_stats, - client_addresses, client_index) - return AsyncMPClient(vllm_config, executor_class, log_stats, - client_addresses, client_index) + parallel_config = vllm_config.parallel_config + client_args = (vllm_config, executor_class, log_stats, + client_addresses, client_index) + if parallel_config.data_parallel_size > 1: + external_lb = parallel_config.data_parallel_external_lb + if parallel_config.data_parallel_backend == "ray": + assert not external_lb, "External DP LB not supported with ray" + return RayDPClient(*client_args) + if external_lb: + # External load balancer - client per DP rank. + return DPAsyncMPClient(*client_args) + # Internal load balancer - client balances to all DP ranks. + return DPLBAsyncMPClient(*client_args) + return AsyncMPClient(*client_args) @abstractmethod def shutdown(self): @@ -392,6 +398,8 @@ def __init__( dp_size = parallel_config.data_parallel_size dp_rank = parallel_config.data_parallel_rank + external_dp_lb = parallel_config.data_parallel_external_lb + # State used for data parallel. self.engines_running = False @@ -399,18 +407,19 @@ def __init__( # one core engine per LLM, see # examples/offline_inference/data_parallel.py. 
spmd_mode = local_start_index is not None - if spmd_mode: + if not spmd_mode: + local_start_index = 0 + + if spmd_mode or dp_rank > 0: assert local_engine_count == 1 self.core_engines = [CoreEngine(index=dp_rank, local=True)] else: - assert dp_rank == 0 - local_start_index = 0 self.core_engines = [ CoreEngine(index=i, local=(i < local_engine_count)) for i in range(dp_size) ] - local_only = spmd_mode or local_engine_count == dp_size + handshake_local_only = spmd_mode or local_engine_count == dp_size self.stats_update_address: Optional[str] = None if client_addresses is not None: @@ -419,9 +428,12 @@ def __init__( self.stats_update_address = client_addresses.get( "stats_update_address") else: + client_local_only = handshake_local_only or external_dp_lb host = parallel_config.data_parallel_master_ip - input_address = get_engine_client_zmq_addr(local_only, host) - output_address = get_engine_client_zmq_addr(local_only, host) + input_address = get_engine_client_zmq_addr( + client_local_only, host) + output_address = get_engine_client_zmq_addr( + client_local_only, host) # Create input and output sockets. self.input_socket = self.resources.input_socket = make_zmq_socket( @@ -430,15 +442,18 @@ def __init__( self.ctx, output_address, zmq.PULL) if client_addresses is None: - self._init_engines_direct(vllm_config, local_only, + self._init_engines_direct(vllm_config, handshake_local_only, local_start_index, input_address, output_address, executor_class, log_stats) coordinator = self.resources.coordinator if coordinator: - self.stats_update_address = ( + assert self.stats_update_address == ( coordinator.get_stats_publish_address()) + if external_dp_lb and dp_rank == 0: + del self.core_engines[1:] + # Wait for ready messages from each engine on the input socket. identities = set(e.identity for e in self.core_engines) sync_input_socket = zmq.Socket.shadow(self.input_socket) @@ -473,6 +488,7 @@ def _init_engines_direct(self, vllm_config: VllmConfig, local_only: bool, local_engine_count = parallel_config.data_parallel_size_local start_index = parallel_config.data_parallel_rank host = parallel_config.data_parallel_master_ip + external_dp_lb = parallel_config.data_parallel_external_lb if len(self.core_engines) > 1: self.resources.coordinator = DPCoordinator(parallel_config) @@ -480,7 +496,15 @@ def _init_engines_direct(self, vllm_config: VllmConfig, local_only: bool, handshake_address = get_engine_client_zmq_addr( local_only, host, parallel_config.data_parallel_rpc_port) - with zmq_socket_ctx(handshake_address, zmq.ROUTER, + if external_dp_lb and start_index > 0: + assert not local_only + local_handshake_address = get_open_zmq_ipc_path() + client_handshake_address = local_handshake_address + else: + local_handshake_address = handshake_address + client_handshake_address = None + + with zmq_socket_ctx(local_handshake_address, zmq.ROUTER, bind=True) as handshake_socket: # Start local engines. @@ -493,17 +517,25 @@ def _init_engines_direct(self, vllm_config: VllmConfig, local_only: bool, executor_class=executor_class, log_stats=log_stats, handshake_address=handshake_address, - on_head_node=True, + client_handshake_address=client_handshake_address, + local_client=True, local_engine_count=local_engine_count, start_index=start_index, local_start_index=local_start_index) # Wait for engine core process(es) to start. 
- self._wait_for_engine_startup(handshake_socket, input_address, - output_address) + return self._wait_for_engine_startup(handshake_socket, + input_address, output_address) + + def _wait_for_engine_startup( + self, + handshake_socket: zmq.Socket, + input_address: str, + output_address: str, + ) -> Optional[str]: + """Returns zmq address for front-end to subscribe to DP events from the + DP coordinator, if applicable.""" - def _wait_for_engine_startup(self, handshake_socket: zmq.Socket, - input_address: str, output_address: str): addresses = EngineZmqAddresses( inputs=[input_address], outputs=[output_address], @@ -513,6 +545,8 @@ def _wait_for_engine_startup(self, handshake_socket: zmq.Socket, if coordinator is not None: addresses.coordinator_input, addresses.coordinator_output = ( coordinator.get_engine_socket_addresses()) + addresses.frontend_stats_publish_address = ( + coordinator.get_stats_publish_address()) proc_manager = self.resources.engine_manager assert isinstance(proc_manager, (type(None), CoreEngineProcManager)), ( @@ -529,6 +563,8 @@ def _wait_for_engine_startup(self, handshake_socket: zmq.Socket, coordinator.proc if coordinator else None, ) + return addresses.frontend_stats_publish_address + def shutdown(self): # Terminate background resources. self._finalizer() @@ -583,7 +619,6 @@ def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], # a ref to the client which prevents gc. ctx = self.ctx out_socket = self.resources.output_socket - assert out_socket is not None decoder = self.decoder utility_results = self.utility_results outputs_queue = self.outputs_queue @@ -593,6 +628,7 @@ def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], resources.shutdown_path = shutdown_path def process_outputs_socket(): + assert isinstance(out_socket, zmq.Socket) shutdown_socket = ctx.socket(zmq.PAIR) try: shutdown_socket.bind(shutdown_path) @@ -609,7 +645,7 @@ def process_outputs_socket(): frames = out_socket.recv_multipart(copy=False) resources.validate_alive(frames) - outputs = decoder.decode(frames) + outputs: EngineCoreOutputs = decoder.decode(frames) if outputs.utility_output: _process_utility_output(outputs.utility_output, utility_results) @@ -921,7 +957,7 @@ async def collective_rpc_async( class DPAsyncMPClient(AsyncMPClient): """Asyncio-compatible client for multi-proc, multi-engine (data parallel) - EngineCore.""" + EngineCore. Assumes external load-balancing by default.""" def __init__(self, vllm_config: VllmConfig, @@ -930,15 +966,12 @@ def __init__(self, client_addresses: Optional[dict[str, str]] = None, client_index: int = 0): self.current_wave = 0 - # To route aborts to the correct engine. - self.reqs_in_flight: dict[str, CoreEngine] = {} super().__init__(vllm_config, executor_class, log_stats, client_addresses, client_index) - assert len(self.core_engines) > 1 - # List of [waiting, running] pair per engine. + # Used only by DPLBAsyncMPClient subclass. self.lb_engines: list[list[int]] = [] self.first_req_sock_addr = get_open_zmq_inproc_path() @@ -969,6 +1002,8 @@ async def run_engine_stats_update_task(): self.first_req_sock_addr, zmq.PAIR, bind=False) as first_req_rcv_socket: + assert isinstance(socket, zmq.asyncio.Socket) + assert isinstance(first_req_rcv_socket, zmq.asyncio.Socket) # Send subscription message. 
await socket.send(b'\x01') @@ -1012,52 +1047,13 @@ async def run_engine_stats_update_task(): resources.stats_update_task = asyncio.create_task( run_engine_stats_update_task()) - def get_core_engine_for_request(self, - dp_rank: Optional[int] = None - ) -> CoreEngine: - if dp_rank is not None: - # engines are already in rank order - return self.core_engines[dp_rank] - - if not self.lb_engines: - return self.core_engines[0] - # TODO use P2C alg for larger DP sizes - num_engines = len(self.lb_engines) - min_counts = [sys.maxsize, sys.maxsize] - eng_index = 0 - for i in range(num_engines): - # Start from client_index to help with balancing when engines - # are empty. - idx = (self.client_index + i) % num_engines - counts = self.lb_engines[idx] - if counts < min_counts: - min_counts = counts - eng_index = idx - # Adjust local counts for better balancing between stats updates - # from the coordinator (which happen every 100ms). - if min_counts[0]: - min_counts[0] += 1 - else: - min_counts[1] += 1 - return self.core_engines[eng_index] - - async def call_utility_async(self, method: str, *args) -> Any: - # Only the result from the first engine is returned. - return (await asyncio.gather(*[ - self._call_utility_async(method, *args, engine=engine) - for engine in self.core_engines - ]))[0] - async def add_request_async(self, request: EngineCoreRequest) -> None: self._ensure_stats_update_task() request.current_wave = self.current_wave request.client_index = self.client_index - chosen_engine = self.get_core_engine_for_request( - request.data_parallel_rank) - self.reqs_in_flight[request.request_id] = chosen_engine - + chosen_engine = self.get_core_engine_for_request(request) to_await = self._send_input(EngineCoreRequestType.ADD, request, chosen_engine) if not self.engines_running: @@ -1068,8 +1064,68 @@ async def add_request_async(self, request: EngineCoreRequest) -> None: self._ensure_output_queue_task() + def get_core_engine_for_request(self, request: EngineCoreRequest): + return self.core_engine + + +class DPLBAsyncMPClient(DPAsyncMPClient): + """Asyncio-compatible client for multi-proc, multi-engine (data parallel) + EngineCore. Load-balances between multiple engine processes.""" + + def __init__(self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + client_addresses: Optional[dict[str, str]] = None, + client_index: int = 0): + + # To route aborts to the correct engine. + self.reqs_in_flight: dict[str, CoreEngine] = {} + + super().__init__(vllm_config, executor_class, log_stats, + client_addresses, client_index) + + assert len(self.core_engines) > 1 + + def get_core_engine_for_request(self, + request: EngineCoreRequest) -> CoreEngine: + # Engines are in rank order. + if (eng_index := request.data_parallel_rank) is None: + if not self.lb_engines: + return self.core_engine + # TODO use P2C alg for larger DP sizes + num_engines = len(self.lb_engines) + min_counts = [sys.maxsize, sys.maxsize] + eng_index = 0 + for i in range(num_engines): + # Start from client_index to help with balancing when engines + # are empty. + idx = (self.client_index + i) % num_engines + counts = self.lb_engines[idx] + if counts < min_counts: + min_counts = counts + eng_index = idx + # Adjust local counts for better balancing between stats updates + # from the coordinator (which happen every 100ms). + if min_counts[0]: + min_counts[0] += 1 + else: + min_counts[1] += 1 + + chosen_engine = self.core_engines[eng_index] + # Record which engine is chosen for this request, to handle aborts. 
+ self.reqs_in_flight[request.request_id] = chosen_engine + return chosen_engine + + async def call_utility_async(self, method: str, *args) -> Any: + # Only the result from the first engine is returned. + return (await asyncio.gather(*[ + self._call_utility_async(method, *args, engine=engine) + for engine in self.core_engines + ]))[0] + @staticmethod - async def process_engine_outputs(self: "DPAsyncMPClient", + async def process_engine_outputs(self: "DPLBAsyncMPClient", outputs: EngineCoreOutputs): if outputs.finished_requests and self.reqs_in_flight: for req_id in outputs.finished_requests: @@ -1085,10 +1141,10 @@ async def abort_requests_async(self, request_ids: list[str]) -> None: await self._abort_requests(request_ids, engine) return - by_engine: dict[CoreEngine, list[str]] = {} + by_engine = defaultdict[CoreEngine, list[str]](list) for req_id in request_ids: if engine := self.reqs_in_flight.get(req_id): - by_engine.setdefault(engine, []).append(req_id) + by_engine[engine].append(req_id) for engine, req_ids in by_engine.items(): await self._abort_requests(req_ids, engine) @@ -1098,7 +1154,7 @@ async def _abort_requests(self, request_ids: list[str], engine) -class RayDPClient(DPAsyncMPClient): +class RayDPClient(DPLBAsyncMPClient): """ Ray-based client for multi-proc, multi-engine (data parallel) EngineCore. diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 192c9067740c..56d2e3d975d2 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -111,6 +111,14 @@ def __repr__(self): def get_engine_client_zmq_addr(local_only: bool, host: str, port: int = 0) -> str: + """Assign a new ZMQ socket address. + + If local_only is True, participants are colocated and so a unique IPC + address will be returned. + + Otherwise, the provided host and port will be used to construct a TCP + address (port == 0 means assign an available port).""" + return get_open_zmq_ipc_path() if local_only else (get_tcp_uri( host, port or get_open_port())) @@ -142,6 +150,10 @@ class EngineZmqAddresses: coordinator_input: Optional[str] = None # ZMQ output socket address of DP coordinator if applicable coordinator_output: Optional[str] = None + # ZMQ socket for front-end to connect to DP coordinator. + # Not used by engine, just relayed to front-end in handshake response. + # Only required for external DP LB case. 
+ frontend_stats_publish_address: Optional[str] = None @dataclass @@ -232,20 +244,25 @@ def __init__( start_index: int, local_start_index: int, vllm_config: VllmConfig, - on_head_node: bool, + local_client: bool, handshake_address: str, executor_class: type[Executor], log_stats: bool, + client_handshake_address: Optional[str] = None, ): context = get_mp_context() common_kwargs = { "vllm_config": vllm_config, - "on_head_node": on_head_node, + "local_client": local_client, "handshake_address": handshake_address, "executor_class": executor_class, "log_stats": log_stats, } + if client_handshake_address: + common_kwargs[ + "client_handshake_address"] = client_handshake_address + self.processes: list[BaseProcess] = [] for index in range(local_engine_count): local_index = local_start_index + index @@ -349,7 +366,7 @@ def __init__( dp_vllm_config = copy.deepcopy(vllm_config) pg = placement_groups[index] dp_vllm_config.parallel_config.placement_group = pg - on_head_node = index < local_engine_count + local_client = index < local_engine_count actor = ray.remote(DPEngineCoreActor).options( scheduling_strategy=PlacementGroupSchedulingStrategy( placement_group=pg, @@ -357,11 +374,11 @@ def __init__( )).remote(vllm_config=dp_vllm_config, executor_class=executor_class, log_stats=log_stats, - on_head_node=on_head_node, + local_client=local_client, addresses=addresses, dp_rank=index, local_dp_rank=local_index) - if on_head_node: + if local_client: self.local_engine_actors.append(actor) else: self.remote_engine_actors.append(actor) @@ -460,7 +477,6 @@ def wait_for_engine_startup( proc_manager: Optional[CoreEngineProcManager], coord_process: Optional[Process], ): - # Wait for engine core process(es) to send ready messages. local_count = parallel_config.data_parallel_size_local remote_count = len(core_engines) - local_count @@ -537,6 +553,14 @@ def wait_for_engine_startup( num_gpu_blocks += msg["num_gpu_blocks"] cache_config.num_gpu_blocks = num_gpu_blocks + # In external DP LB mode, the coordinator address that the + # front-end procs connect to is obtained from rank 0 via + # one of the engine handshakes, and passed to the local + # front-end process in the response from the other. 
+ if addresses.frontend_stats_publish_address is None: + addresses.frontend_stats_publish_address = msg.get( + "dp_stats_address") + start_pending[0 if local else 1] -= 1 engine.state = CoreEngineState.READY else: From 666c197b7f1e00f411e4de58bb2ac9589e0eaa72 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 24 Jun 2025 14:36:14 -0700 Subject: [PATCH 02/12] deduplicate/unify engine launching logic Signed-off-by: Nick Hill --- tests/v1/engine/test_engine_core_client.py | 2 +- vllm/entrypoints/cli/serve.py | 175 +----- vllm/v1/engine/core.py | 6 +- vllm/v1/engine/core_client.py | 250 ++------ vllm/v1/engine/utils.py | 684 +++++++++++++++++++++ vllm/v1/utils.py | 543 +--------------- 6 files changed, 774 insertions(+), 886 deletions(-) create mode 100644 vllm/v1/engine/utils.py diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index d5ff78c1449a..79ce5b126db0 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -26,8 +26,8 @@ from vllm.v1.engine.core import EngineCore from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient, SyncMPClient) +from vllm.v1.engine.utils import CoreEngineProcManager from vllm.v1.executor.abstract import Executor -from vllm.v1.utils import CoreEngineProcManager from ...distributed.conftest import MockSubscriber from ...utils import create_new_process_for_each_test diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 35b3942d506e..ec413748020a 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -8,7 +8,6 @@ from typing import Optional import uvloop -import zmq import vllm import vllm.envs as envs @@ -22,18 +21,13 @@ from vllm.executor.multiproc_worker_utils import _add_prefix from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext -from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path, - get_tcp_uri, zmq_socket_ctx) -from vllm.v1.engine.coordinator import DPCoordinator +from vllm.utils import FlexibleArgumentParser, get_tcp_uri from vllm.v1.engine.core import EngineCoreProc -from vllm.v1.engine.core_client import CoreEngineProcManager +from vllm.v1.engine.utils import (APIServerProcessManager, + CoreEngineProcManager, launch_core_engines, + wait_for_completion_or_failure) from vllm.v1.executor.abstract import Executor from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus -from vllm.v1.utils import (APIServerProcessManager, CoreEngine, - CoreEngineActorManager, EngineZmqAddresses, - get_engine_client_zmq_addr, - wait_for_completion_or_failure, - wait_for_engine_startup) logger = init_logger(__name__) @@ -194,116 +188,35 @@ def run_multi_api_server(args: argparse.Namespace): " api_server_count > 1") model_config.disable_mm_preprocessor_cache = True - parallel_config = vllm_config.parallel_config + executor_class = Executor.get_class(vllm_config) + log_stats = not engine_args.disable_log_stats - dp_size = parallel_config.data_parallel_size + parallel_config = vllm_config.parallel_config dp_rank = parallel_config.data_parallel_rank - local_engine_count = parallel_config.data_parallel_size_local - host = parallel_config.data_parallel_master_ip external_dp_lb = parallel_config.data_parallel_external_lb - - # The handshake socket can be IPC if this front-end is colocated - # with all engines. 
- handshake_local_only = local_engine_count == dp_size - - # client_local_only = True for cases where this front-end - # sends requests only to colocated engines. - client_local_only = handshake_local_only or external_dp_lb - assert external_dp_lb or dp_rank == 0 - # Set up input and output addresses. - input_addresses = [ - get_engine_client_zmq_addr(client_local_only, host) - for _ in range(num_api_servers) - ] - output_addresses = [ - get_engine_client_zmq_addr(client_local_only, host) - for _ in range(num_api_servers) - ] - - addresses = EngineZmqAddresses( - inputs=input_addresses, - outputs=output_addresses, - ) - - # Set up coordinator for dp > 1. - # This runs only with the rank 0 engine, or when not colocated with any - # engines (in which case the value of dp_rank here will also be 0). - coordinator = None - stats_update_address = None - if dp_size > 1 and dp_rank == 0: - coordinator = DPCoordinator(parallel_config) - addresses.coordinator_input, addresses.coordinator_output = ( - coordinator.get_engine_socket_addresses()) - addresses.frontend_stats_publish_address = ( - coordinator.get_stats_publish_address()) - stats_update_address = coordinator.get_stats_publish_address() - logger.info("Started DP Coordinator process (PID: %d)", - coordinator.proc.pid) - api_server_manager: Optional[APIServerProcessManager] = None - # Construct common args for the APIServerProcessManager up-front. - api_server_manager_kwargs = dict( - target_server_fn=run_api_server_worker_proc, - listen_address=listen_address, - sock=sock, - args=args, - num_servers=num_api_servers, - input_addresses=input_addresses, - output_addresses=output_addresses, - stats_update_address=stats_update_address) - - if parallel_config.data_parallel_backend == "ray": - logger.info("Starting ray-based data parallel backend") - - engine_actor_manager = CoreEngineActorManager( - vllm_config=vllm_config, - addresses=addresses, - executor_class=Executor.get_class(vllm_config), - log_stats=not engine_args.disable_log_stats, - ) - # Start API servers using the manager - api_server_manager = APIServerProcessManager( - **api_server_manager_kwargs) - - wait_for_completion_or_failure(api_server_manager=api_server_manager, - engine_manager=engine_actor_manager, - coordinator=coordinator) - return - - handshake_address = get_engine_client_zmq_addr( - handshake_local_only, host, parallel_config.data_parallel_rpc_port) - if external_dp_lb and dp_rank > 0: - assert not handshake_local_only - local_handshake_address = get_open_zmq_ipc_path() - client_handshake_address = local_handshake_address - else: - local_handshake_address = handshake_address - client_handshake_address = None - - with zmq_socket_ctx(local_handshake_address, zmq.ROUTER, - bind=True) as handshake_socket: - - # Start local engines. - if not local_engine_count: - local_engine_manager = None - else: - local_engine_manager = CoreEngineProcManager( - EngineCoreProc.run_engine_core, - vllm_config=vllm_config, - executor_class=Executor.get_class(vllm_config), - log_stats=not engine_args.disable_log_stats, - handshake_address=handshake_address, - client_handshake_address=client_handshake_address, - local_client=True, - local_engine_count=local_engine_count, - start_index=dp_rank, - local_start_index=0) + with launch_core_engines(vllm_config, executor_class, log_stats, + num_api_servers) as (local_engine_manager, + coordinator, addresses): + + # Construct common args for the APIServerProcessManager up-front. 
+ api_server_manager_kwargs = dict( + target_server_fn=run_api_server_worker_proc, + listen_address=listen_address, + sock=sock, + args=args, + num_servers=num_api_servers, + input_addresses=addresses.inputs, + output_addresses=addresses.outputs, + stats_update_address=coordinator.get_stats_publish_address() + if coordinator else None) # For dp ranks > 0 in external DP LB mode, we must delay the - # start of the API servers until the local engine is started, + # start of the API servers until the local engine is started + # (after the launcher context manager exits), # since we get the front-end stats update address from the coordinator # via the handshake with the local engine. if dp_rank == 0 or not external_dp_lb: @@ -311,37 +224,17 @@ def run_multi_api_server(args: argparse.Namespace): api_server_manager = APIServerProcessManager( **api_server_manager_kwargs) - # Wait for engine handshakes to complete. - if external_dp_lb and dp_rank > 0: - assert local_engine_count == 1 - core_engines = [CoreEngine(index=dp_rank, local=True)] - else: - core_engines = [ - CoreEngine(index=i, local=(i < local_engine_count)) - for i in range(dp_size) - ] - - wait_for_engine_startup( - handshake_socket, - addresses, - core_engines, - parallel_config, - vllm_config.cache_config, - local_engine_manager, - coordinator.proc if coordinator else None, - ) - - # Start API servers now if they weren't already started. - if api_server_manager is None: - api_server_manager_kwargs["stats_update_address"] = ( - addresses.frontend_stats_publish_address) - api_server_manager = APIServerProcessManager( - **api_server_manager_kwargs) + # Start API servers now if they weren't already started. + if api_server_manager is None: + api_server_manager_kwargs["stats_update_address"] = ( + addresses.frontend_stats_publish_address) + api_server_manager = APIServerProcessManager( + **api_server_manager_kwargs) - # Wait for API servers - wait_for_completion_or_failure(api_server_manager=api_server_manager, - engine_manager=local_engine_manager, - coordinator=coordinator) + # Wait for API servers + wait_for_completion_or_failure(api_server_manager=api_server_manager, + engine_manager=local_engine_manager, + coordinator=coordinator) def run_api_server_worker_proc(listen_address, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 4a468680793e..d89e6ae7433c 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -34,6 +34,7 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, UtilityOutput) from vllm.v1.engine.mm_input_cache import MirroredProcessingCache +from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses from vllm.v1.executor.abstract import Executor from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import SchedulerStats @@ -41,7 +42,6 @@ from vllm.v1.request import Request, RequestStatus from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder from vllm.v1.structured_output import StructuredOutputManager -from vllm.v1.utils import EngineHandshakeMetadata, EngineZmqAddresses from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -961,9 +961,9 @@ def run_busy_loop(self): def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: - # Optimization - only perform finish-sync all-reduce every 24 steps. + # Optimization - only perform finish-sync all-reduce every 32 steps. 
self.counter += 1 - if self.counter != 24: + if self.counter != 32: return True self.counter = 0 diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index cc410dab19d0..dafaa15f777d 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -21,18 +21,16 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.utils import (get_open_zmq_inproc_path, get_open_zmq_ipc_path, - make_zmq_socket, zmq_socket_ctx) +from vllm.utils import get_open_zmq_inproc_path, make_zmq_socket from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, UtilityOutput) from vllm.v1.engine.coordinator import DPCoordinator from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.engine.exceptions import EngineDeadError +from vllm.v1.engine.utils import (CoreEngineActorManager, + CoreEngineProcManager, launch_core_engines) from vllm.v1.executor.abstract import Executor from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr -from vllm.v1.utils import (CoreEngine, CoreEngineActorManager, - CoreEngineProcManager, EngineZmqAddresses, - get_engine_client_zmq_addr, wait_for_engine_startup) logger = init_logger(__name__) @@ -40,6 +38,8 @@ _R = TypeVar('_R') # Return type for collective_rpc +EngineIdentity = bytes + class EngineCoreClient(ABC): """ @@ -88,11 +88,7 @@ def make_async_mp_client( client_args = (vllm_config, executor_class, log_stats, client_addresses, client_index) if parallel_config.data_parallel_size > 1: - external_lb = parallel_config.data_parallel_external_lb - if parallel_config.data_parallel_backend == "ray": - assert not external_lb, "External DP LB not supported with ray" - return RayDPClient(*client_args) - if external_lb: + if parallel_config.data_parallel_external_lb: # External load balancer - client per DP rank. return DPAsyncMPClient(*client_args) # Internal load balancer - client balances to all DP ranks. @@ -392,48 +388,32 @@ def __init__( self._finalizer = weakref.finalize(self, self.resources) success = False try: - parallel_config = vllm_config.parallel_config - local_engine_count = parallel_config.data_parallel_size_local - local_start_index = parallel_config.data_parallel_rank_local - dp_size = parallel_config.data_parallel_size - dp_rank = parallel_config.data_parallel_rank - - external_dp_lb = parallel_config.data_parallel_external_lb - # State used for data parallel. self.engines_running = False - # SPMD mode is where there is an LLM instance per DP rank and - # one core engine per LLM, see - # examples/offline_inference/data_parallel.py. - spmd_mode = local_start_index is not None - if not spmd_mode: - local_start_index = 0 - - if spmd_mode or dp_rank > 0: - assert local_engine_count == 1 - self.core_engines = [CoreEngine(index=dp_rank, local=True)] - else: - self.core_engines = [ - CoreEngine(index=i, local=(i < local_engine_count)) - for i in range(dp_size) - ] - - handshake_local_only = spmd_mode or local_engine_count == dp_size - self.stats_update_address: Optional[str] = None if client_addresses is not None: + # Engines are managed externally to this client. 
input_address = client_addresses["input_address"] output_address = client_addresses["output_address"] self.stats_update_address = client_addresses.get( "stats_update_address") else: - client_local_only = handshake_local_only or external_dp_lb - host = parallel_config.data_parallel_master_ip - input_address = get_engine_client_zmq_addr( - client_local_only, host) - output_address = get_engine_client_zmq_addr( - client_local_only, host) + # Engines are managed by this client. + with launch_core_engines(vllm_config, executor_class, + log_stats) as (engine_manager, + coordinator, + addresses): + self.resources.coordinator = coordinator + self.resources.engine_manager = engine_manager + + (input_address, ) = addresses.inputs + (output_address, ) = addresses.outputs + self.stats_update_address = ( + addresses.frontend_stats_publish_address) + if coordinator is not None: + assert self.stats_update_address == ( + coordinator.get_stats_publish_address()) # Create input and output sockets. self.input_socket = self.resources.input_socket = make_zmq_socket( @@ -441,21 +421,24 @@ def __init__( self.resources.output_socket = make_zmq_socket( self.ctx, output_address, zmq.PULL) - if client_addresses is None: - self._init_engines_direct(vllm_config, handshake_local_only, - local_start_index, input_address, - output_address, executor_class, - log_stats) - coordinator = self.resources.coordinator - if coordinator: - assert self.stats_update_address == ( - coordinator.get_stats_publish_address()) + parallel_config = vllm_config.parallel_config + dp_size = parallel_config.data_parallel_size + dp_rank = parallel_config.data_parallel_rank + external_dp_lb = parallel_config.data_parallel_external_lb + + offline_mode = parallel_config.data_parallel_rank_local is not None + engine_ranks = [dp_rank] if (offline_mode + or external_dp_lb) else range(dp_size) + assert parallel_config.data_parallel_size_local <= len( + engine_ranks) - if external_dp_lb and dp_rank == 0: - del self.core_engines[1:] + # ZMQ identity of each engine that this client will talk to. + self.core_engines: list[EngineIdentity] = [ + index.to_bytes(2, "little") for index in engine_ranks + ] # Wait for ready messages from each engine on the input socket. 
- identities = set(e.identity for e in self.core_engines) + identities = set(self.core_engines) sync_input_socket = zmq.Socket.shadow(self.input_socket) while identities: if not sync_input_socket.poll(timeout=600_000): @@ -464,7 +447,7 @@ def __init__( identity, _ = sync_input_socket.recv_multipart() identities.remove(identity) - self.core_engine = self.core_engines[0] + self.core_engine: EngineIdentity = self.core_engines[0] self.utility_results: dict[int, AnyFuture] = {} # Request objects which may contain pytorch-allocated tensors @@ -477,94 +460,6 @@ def __init__( if not success: self._finalizer() - def _init_engines_direct(self, vllm_config: VllmConfig, local_only: bool, - local_start_index: int, input_address: str, - output_address: str, - executor_class: type[Executor], log_stats: bool): - """Self-contained client mode, launch engine and coordinator process - as needed.""" - - parallel_config = vllm_config.parallel_config - local_engine_count = parallel_config.data_parallel_size_local - start_index = parallel_config.data_parallel_rank - host = parallel_config.data_parallel_master_ip - external_dp_lb = parallel_config.data_parallel_external_lb - - if len(self.core_engines) > 1: - self.resources.coordinator = DPCoordinator(parallel_config) - - handshake_address = get_engine_client_zmq_addr( - local_only, host, parallel_config.data_parallel_rpc_port) - - if external_dp_lb and start_index > 0: - assert not local_only - local_handshake_address = get_open_zmq_ipc_path() - client_handshake_address = local_handshake_address - else: - local_handshake_address = handshake_address - client_handshake_address = None - - with zmq_socket_ctx(local_handshake_address, zmq.ROUTER, - bind=True) as handshake_socket: - - # Start local engines. - if local_engine_count: - # In server mode, start_index and local_start_index will - # both be 0. - self.resources.engine_manager = CoreEngineProcManager( - EngineCoreProc.run_engine_core, - vllm_config=vllm_config, - executor_class=executor_class, - log_stats=log_stats, - handshake_address=handshake_address, - client_handshake_address=client_handshake_address, - local_client=True, - local_engine_count=local_engine_count, - start_index=start_index, - local_start_index=local_start_index) - - # Wait for engine core process(es) to start. - return self._wait_for_engine_startup(handshake_socket, - input_address, output_address) - - def _wait_for_engine_startup( - self, - handshake_socket: zmq.Socket, - input_address: str, - output_address: str, - ) -> Optional[str]: - """Returns zmq address for front-end to subscribe to DP events from the - DP coordinator, if applicable.""" - - addresses = EngineZmqAddresses( - inputs=[input_address], - outputs=[output_address], - ) - - coordinator = self.resources.coordinator - if coordinator is not None: - addresses.coordinator_input, addresses.coordinator_output = ( - coordinator.get_engine_socket_addresses()) - addresses.frontend_stats_publish_address = ( - coordinator.get_stats_publish_address()) - - proc_manager = self.resources.engine_manager - assert isinstance(proc_manager, (type(None), CoreEngineProcManager)), ( - "_wait_for_engine_startup should only be " - "called with CoreEngineProcManager") - - wait_for_engine_startup( - handshake_socket, - addresses, - self.core_engines, - self.vllm_config.parallel_config, - self.vllm_config.cache_config, - proc_manager, - coordinator.proc if coordinator else None, - ) - - return addresses.frontend_stats_publish_address - def shutdown(self): # Terminate background resources. 
self._finalizer() @@ -682,7 +577,7 @@ def _send_input(self, request_type: EngineCoreRequestType, request: Any): self.ensure_alive() self.free_pending_messages() # (Identity, RequestType, SerializedRequest) - msg = (self.core_engine.identity, request_type.value, + msg = (self.core_engine, request_type.value, *self.encoder.encode(request)) if len(msg) <= 3: @@ -848,7 +743,7 @@ async def get_output_async(self) -> EngineCoreOutputs: def _send_input(self, request_type: EngineCoreRequestType, request: Any, - engine: Optional[CoreEngine] = None) -> Awaitable[Any]: + engine: Optional[EngineIdentity] = None) -> Awaitable[Any]: if engine is None: engine = self.core_engine @@ -856,7 +751,7 @@ def _send_input(self, return self._send_input_message(message, engine, request) def _send_input_message(self, message: tuple[bytestr, - ...], engine: CoreEngine, + ...], engine: EngineIdentity, objects: Any) -> Awaitable[Any]: """ objects is a reference to retain until zmq is finished with the @@ -865,7 +760,7 @@ def _send_input_message(self, message: tuple[bytestr, self.ensure_alive() self.free_pending_messages() - msg = (engine.identity, ) + message + msg = (engine, ) + message if not objects or len(msg) <= 3: # No auxiliary buffers => no tensor backing buffers in request. return self.input_socket.send_multipart(msg, copy=False) @@ -886,7 +781,7 @@ async def call_utility_async(self, method: str, *args) -> Any: engine=self.core_engine) async def _call_utility_async(self, method: str, *args, - engine: CoreEngine) -> Any: + engine: EngineIdentity) -> Any: call_id = uuid.uuid1().int >> 64 future = asyncio.get_running_loop().create_future() self.utility_results[call_id] = future @@ -1058,7 +953,7 @@ async def add_request_async(self, request: EngineCoreRequest) -> None: chosen_engine) if not self.engines_running: # Notify coordinator that we're sending a request - await self.first_req_send_socket.send(chosen_engine.identity) + await self.first_req_send_socket.send(chosen_engine) await to_await @@ -1080,15 +975,15 @@ def __init__(self, client_index: int = 0): # To route aborts to the correct engine. - self.reqs_in_flight: dict[str, CoreEngine] = {} + self.reqs_in_flight: dict[str, EngineIdentity] = {} super().__init__(vllm_config, executor_class, log_stats, client_addresses, client_index) assert len(self.core_engines) > 1 - def get_core_engine_for_request(self, - request: EngineCoreRequest) -> CoreEngine: + def get_core_engine_for_request( + self, request: EngineCoreRequest) -> EngineIdentity: # Engines are in rank order. if (eng_index := request.data_parallel_rank) is None: if not self.lb_engines: @@ -1141,7 +1036,7 @@ async def abort_requests_async(self, request_ids: list[str]) -> None: await self._abort_requests(request_ids, engine) return - by_engine = defaultdict[CoreEngine, list[str]](list) + by_engine = defaultdict[EngineIdentity, list[str]](list) for req_id in request_ids: if engine := self.reqs_in_flight.get(req_id): by_engine[engine].append(req_id) @@ -1149,53 +1044,6 @@ async def abort_requests_async(self, request_ids: list[str]) -> None: await self._abort_requests(req_ids, engine) async def _abort_requests(self, request_ids: list[str], - engine: CoreEngine) -> None: + engine: EngineIdentity) -> None: await self._send_input(EngineCoreRequestType.ABORT, request_ids, engine) - - -class RayDPClient(DPLBAsyncMPClient): - """ - Ray-based client for multi-proc, multi-engine (data parallel) - EngineCore. 
- """ - - def __init__( - self, - vllm_config: VllmConfig, - executor_class: type[Executor], - log_stats: bool, - client_addresses: Optional[dict[str, str]] = None, - client_index: int = 0, - ): - super().__init__(vllm_config, executor_class, log_stats, - client_addresses, client_index) - - def _init_engines_direct(self, vllm_config: VllmConfig, local_only: bool, - local_start_index: int, input_address: str, - output_address: str, - executor_class: type[Executor], log_stats: bool): - """Self-contained client mode, launch engine and coordinator process - as needed.""" - - parallel_config = vllm_config.parallel_config - assert parallel_config.data_parallel_rank == 0 - assert local_start_index == 0 - - addresses = EngineZmqAddresses( - inputs=[input_address], - outputs=[output_address], - ) - - if len(self.core_engines) > 1: - coordinator = DPCoordinator(parallel_config) - self.resources.coordinator = coordinator - addresses.coordinator_input, addresses.coordinator_output = ( - coordinator.get_engine_socket_addresses()) - - # Start all engines. - self.resources.engine_manager = CoreEngineActorManager( - vllm_config=vllm_config, - addresses=addresses, - executor_class=executor_class, - log_stats=log_stats) diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py new file mode 100644 index 000000000000..4ab5527bfcb1 --- /dev/null +++ b/vllm/v1/engine/utils.py @@ -0,0 +1,684 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import contextlib +import multiprocessing +import weakref +from collections.abc import Iterator +from dataclasses import dataclass +from enum import Enum, auto +from multiprocessing import Process, connection +from multiprocessing.process import BaseProcess +from typing import TYPE_CHECKING, Any, Callable, Optional, Union + +import msgspec +import zmq + +from vllm.config import CacheConfig, ParallelConfig, VllmConfig +from vllm.logger import init_logger +from vllm.utils import get_mp_context, get_open_zmq_ipc_path, zmq_socket_ctx +from vllm.v1.engine.coordinator import DPCoordinator +from vllm.v1.executor.abstract import Executor +from vllm.v1.utils import get_engine_client_zmq_addr, shutdown + +if TYPE_CHECKING: + from ray.util.placement_group import PlacementGroup + +logger = init_logger(__name__) + +STARTUP_POLL_PERIOD_MS = 10000 + + +class CoreEngineState(Enum): + NEW = auto() + CONNECTED = auto() + READY = auto() + + +class CoreEngine: + """One per data parallel rank, used to track state during handshaking.""" + + def __init__(self, index: int = 0, local: bool = True): + self.local = local + self.identity = index.to_bytes(2, "little") + + self.state = CoreEngineState.NEW + + +@dataclass +class EngineZmqAddresses: + # ZMQ input socket addresses for each front-end client (requests) + inputs: list[str] + # ZMQ output socket addresses for each front-end client (responses) + outputs: list[str] + # ZMQ input socket address of DP coordinator if applicable + coordinator_input: Optional[str] = None + # ZMQ output socket address of DP coordinator if applicable + coordinator_output: Optional[str] = None + # ZMQ socket for front-end to connect to DP coordinator. + # Not used by engine, just relayed to front-end in handshake response. + # Only required for external DP LB case. 
+ frontend_stats_publish_address: Optional[str] = None + + +@dataclass +class EngineHandshakeMetadata: + """Metadata sent to each engine process during startup handshake, + including addresses of the front-end ZMQ queues that they should + connect to. + """ + addresses: EngineZmqAddresses + parallel_config: dict[str, Union[int, str]] + + +class APIServerProcessManager: + """Manages a group of API server processes. + + Handles creation, monitoring, and termination of API server worker + processes. Also monitors extra processes to check if they are healthy. + """ + + def __init__( + self, + target_server_fn: Callable, + listen_address: str, + sock: Any, + args: argparse.Namespace, + num_servers: int, + input_addresses: list[str], + output_addresses: list[str], + stats_update_address: Optional[str] = None, + ): + """Initialize and start API server worker processes. + + Args: + target_server_fn: Function to call for each API server process + listen_address: Address to listen for client connections + sock: Socket for client connections + args: Command line arguments + num_servers: Number of API server processes to start + input_addresses: Input addresses for each API server + output_addresses: Output addresses for each API server + stats_update_address: Optional stats update address + """ + self.listen_address = listen_address + self.sock = sock + self.args = args + + # Start API servers + spawn_context = multiprocessing.get_context("spawn") + self.processes: list[BaseProcess] = [] + + for i, in_addr, out_addr in zip(range(num_servers), input_addresses, + output_addresses): + client_config = { + "input_address": in_addr, + "output_address": out_addr, + "client_index": i + } + if stats_update_address is not None: + client_config["stats_update_address"] = stats_update_address + + proc = spawn_context.Process(target=target_server_fn, + name=f"ApiServer_{i}", + args=(listen_address, sock, args, + client_config)) + self.processes.append(proc) + proc.start() + + logger.info("Started %d API server processes", len(self.processes)) + + # Shutdown only the API server processes on garbage collection + # The extra processes are managed by their owners + self._finalizer = weakref.finalize(self, shutdown, self.processes) + + def close(self) -> None: + self._finalizer() + + +class CoreEngineProcManager: + """ + Utility class to handle creation, readiness, and shutdown + of background processes used by the AsyncLLM and LLMEngine. + """ + + def __init__( + self, + target_fn: Callable, + local_engine_count: int, + start_index: int, + local_start_index: int, + vllm_config: VllmConfig, + local_client: bool, + handshake_address: str, + executor_class: type[Executor], + log_stats: bool, + client_handshake_address: Optional[str] = None, + ): + context = get_mp_context() + common_kwargs = { + "vllm_config": vllm_config, + "local_client": local_client, + "handshake_address": handshake_address, + "executor_class": executor_class, + "log_stats": log_stats, + } + + if client_handshake_address: + common_kwargs[ + "client_handshake_address"] = client_handshake_address + + self.processes: list[BaseProcess] = [] + for index in range(local_engine_count): + local_index = local_start_index + index + global_index = start_index + index + # Start EngineCore in background process. 
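To make the rank arithmetic in this loop concrete, assume a hypothetical second node of a dp_size=4 deployment started with start_index=2, local_start_index=0 and local_engine_count=2; it spawns processes named EngineCore_2 and EngineCore_3 with local ranks 0 and 1:

    start_index, local_start_index = 2, 0
    for index in range(2):                          # local_engine_count
        local_index = local_start_index + index     # 0, 1
        global_index = start_index + index          # 2, 3 -> EngineCore_2, EngineCore_3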
+ self.processes.append( + context.Process(target=target_fn, + name=f"EngineCore_{global_index}", + kwargs=common_kwargs | { + "dp_rank": global_index, + "local_dp_rank": local_index, + })) + + self._finalizer = weakref.finalize(self, shutdown, self.processes) + try: + for proc in self.processes: + proc.start() + finally: + # Kill other procs if not all are running. + if self.finished_procs(): + self.close() + + def close(self): + """Shutdown all procs.""" + self._finalizer() + + def join_first(self): + """Wait for any process to exit.""" + connection.wait(proc.sentinel for proc in self.processes) + + def sentinels(self) -> list: + return [proc.sentinel for proc in self.processes] + + def finished_procs(self) -> dict[str, int]: + """Returns dict of proc name -> exit code for any finished procs.""" + return { + proc.name: proc.exitcode + for proc in self.processes if proc.exitcode is not None + } + + +class CoreEngineActorManager: + """ + Utility class to handle creation, readiness, and shutdown + of core engine Ray actors used by the AsyncLLM and LLMEngine. + + Different from CoreEngineProcManager, this class manages + core engines for both local and remote nodes. + """ + + def __init__( + self, + vllm_config: VllmConfig, + addresses: EngineZmqAddresses, + executor_class: type[Executor], + log_stats: bool, + placement_groups: Optional[list["PlacementGroup"]] = None, + local_dp_ranks: Optional[list[int]] = None, + ): + import copy + + import ray + from ray.util.scheduling_strategies import ( + PlacementGroupSchedulingStrategy) + + from vllm.v1.engine.core import DPEngineCoreActor + + self.local_engine_actors: list[ray.ActorHandle] = [] + self.remote_engine_actors: list[ray.ActorHandle] = [] + dp_size = vllm_config.parallel_config.data_parallel_size + local_engine_count = \ + vllm_config.parallel_config.data_parallel_size_local + world_size = vllm_config.parallel_config.world_size + + if ray.is_initialized(): + logger.info( + "Ray is already initialized. 
Skipping Ray initialization.") + else: + ray.init() + + if placement_groups is not None: + assert local_dp_ranks is not None, ( + "local_dp_ranks must be provided if " + "placement_groups is provided") + assert len(placement_groups) == len(local_dp_ranks), ( + "placement_groups and local_dp_ranks must " + "have the same length") + logger.info("Using provided placement groups") + # TODO(rui): validate passed-in placement groups + self.created_placement_groups = [] + else: + placement_groups, local_dp_ranks = \ + CoreEngineActorManager.create_dp_placement_groups(vllm_config) + self.created_placement_groups = placement_groups + assert len(placement_groups) == dp_size, ( + "Number of placement groups must match data parallel size") + + refs = [] + for index in range(dp_size): + local_index = local_dp_ranks[index] + dp_vllm_config = copy.deepcopy(vllm_config) + pg = placement_groups[index] + dp_vllm_config.parallel_config.placement_group = pg + local_client = index < local_engine_count + actor = ray.remote(DPEngineCoreActor).options( + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=world_size, + )).remote(vllm_config=dp_vllm_config, + executor_class=executor_class, + log_stats=log_stats, + local_client=local_client, + addresses=addresses, + dp_rank=index, + local_dp_rank=local_index) + if local_client: + self.local_engine_actors.append(actor) + else: + self.remote_engine_actors.append(actor) + refs.append(actor.wait_for_init.remote()) + + ray.get(refs) + self.run_refs = [] + for actor in self.local_engine_actors + self.remote_engine_actors: + self.run_refs.append(actor.run.remote()) + + @staticmethod + def create_dp_placement_groups( + vllm_config: VllmConfig + ) -> tuple[list["PlacementGroup"], list[int]]: + + import ray + from ray._private.state import available_resources_per_node + from ray.util.state import list_nodes + + logger.info("Creating placement groups for data parallel") + dp_master_ip = \ + vllm_config.parallel_config.data_parallel_master_ip + dp_size = vllm_config.parallel_config.data_parallel_size + local_engine_count = \ + vllm_config.parallel_config.data_parallel_size_local + + nodes = sorted(list_nodes(), + key=lambda node: node.node_ip != dp_master_ip) + assert nodes[0].node_ip == dp_master_ip, ( + "The first node must be the head node") + assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, ( + "There can only be one head node") + + available_resources = available_resources_per_node() + world_size = vllm_config.parallel_config.world_size + placement_groups: list[PlacementGroup] = [] + local_dp_ranks: list[int] = [] + + for node in nodes: + node_ip = node.node_ip + node_resources = available_resources[node.node_id] + # For now, each DP rank can only be assigned to one node + # TODO(rui): support allocating a single DP rank + # to multiple nodes + available_engine_count = int(node_resources["GPU"]) // world_size + if node_ip == dp_master_ip: + assert available_engine_count >= local_engine_count, ( + "Not enough resources to allocate DP ranks " + f"on DP master node {node_ip}") + for i in range(local_engine_count): + bundles = [{ + "GPU": 1.0, + "node:" + dp_master_ip: 0.001 + }] * world_size + [{ + "CPU": 1.0 + }] + pg = ray.util.placement_group( + name=f"dp_rank_{len(placement_groups)}", + strategy="STRICT_PACK", + bundles=bundles, + ) + placement_groups.append(pg) + local_dp_ranks.append(i) + else: + for i in range(available_engine_count): + if len(placement_groups) == dp_size: + break + bundles = [{"GPU": 
1.0}] * world_size + [{"CPU": 1.0}] + pg = ray.util.placement_group( + name=f"dp_rank_{len(placement_groups)}", + strategy="STRICT_PACK", + bundles=bundles, + ) + placement_groups.append(pg) + local_dp_ranks.append(i) + return placement_groups, local_dp_ranks + + def get_run_refs(self): + return self.run_refs + + def close(self): + import ray + for actor in self.local_engine_actors + self.remote_engine_actors: + ray.kill(actor) + for pg in self.created_placement_groups: + ray.util.remove_placement_group(pg) + + +@contextlib.contextmanager +def launch_core_engines( + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + num_api_servers: int = 1, +) -> Iterator[tuple[ + Optional[Union[CoreEngineProcManager, CoreEngineActorManager]], + Optional[DPCoordinator], + EngineZmqAddresses, +]]: + """Launch engine and DP coordinator processes as needed.""" + + parallel_config = vllm_config.parallel_config + dp_size = parallel_config.data_parallel_size + local_engine_count = parallel_config.data_parallel_size_local + local_start_index = parallel_config.data_parallel_rank_local + dp_rank = parallel_config.data_parallel_rank + host = parallel_config.data_parallel_master_ip + external_dp_lb = parallel_config.data_parallel_external_lb + + # In offline mode there is an LLM instance per DP rank and + # one core engine per LLM, see + # examples/offline_inference/data_parallel.py. + offline_mode = local_start_index is not None + + # client_local_only = True for cases where this front-end + # sends requests only to colocated engines. + client_local_only = offline_mode or external_dp_lb or (local_engine_count + == dp_size) + + # Set up input and output addresses. + addresses = EngineZmqAddresses( + inputs=[ + get_engine_client_zmq_addr(client_local_only, host) + for _ in range(num_api_servers) + ], + outputs=[ + get_engine_client_zmq_addr(client_local_only, host) + for _ in range(num_api_servers) + ], + ) + + # Run the DP Coordinator process with rank 0 when in + # online DP mode. + run_coordinator = dp_size > 1 and not offline_mode and dp_rank == 0 + + if run_coordinator: + coordinator = DPCoordinator(parallel_config) + + addresses.coordinator_input, addresses.coordinator_output = ( + coordinator.get_engine_socket_addresses()) + addresses.frontend_stats_publish_address = ( + coordinator.get_stats_publish_address()) + + logger.info("Started DP Coordinator process (PID: %d)", + coordinator.proc.pid) + else: + coordinator = None + + if parallel_config.data_parallel_backend == "ray": + logger.info("Starting ray-based data parallel backend") + + engine_actor_manager = CoreEngineActorManager( + vllm_config=vllm_config, + addresses=addresses, + executor_class=executor_class, + log_stats=log_stats, + ) + + yield engine_actor_manager, coordinator, addresses + return + + if offline_mode or (external_dp_lb and dp_rank > 0): + assert local_engine_count == 1 + engines_to_handshake = [CoreEngine(index=dp_rank, local=True)] + else: + engines_to_handshake = [ + CoreEngine(index=i, local=(i < local_engine_count)) + for i in range(dp_size) + ] + + # Whether the started engines will handshake only with co-located + # front-end processes. In external_dp_lb mode, ranks > 0 handshake with + # their co-located frontend and also the rank 0 front-end, and hence this + # will be False. 
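As a cross-check of the two locality flags, here is a small standalone sketch that mirrors the expressions used in this function (the scenarios themselves are illustrative):

    def locality(offline_mode, external_dp_lb, local_engine_count, dp_size):
        client_local_only = (offline_mode or external_dp_lb
                             or local_engine_count == dp_size)
        handshake_local_only = offline_mode or local_engine_count == dp_size
        return client_local_only, handshake_local_only

    # All engines colocated with the front-end, internal LB: IPC for both.
    assert locality(False, False, 4, 4) == (True, True)
    # External DP LB across nodes: requests stay local, but the handshake
    # address must also be reachable from the rank 0 front-end over TCP.
    assert locality(False, True, 1, 4) == (True, False)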
+ handshake_local_only = offline_mode or local_engine_count == dp_size + + handshake_address = get_engine_client_zmq_addr( + handshake_local_only, host, parallel_config.data_parallel_rpc_port) + + if external_dp_lb and dp_rank > 0: + assert not handshake_local_only + local_handshake_address = get_open_zmq_ipc_path() + client_handshake_address = local_handshake_address + else: + local_handshake_address = handshake_address + client_handshake_address = None + + with zmq_socket_ctx(local_handshake_address, zmq.ROUTER, + bind=True) as handshake_socket: + + from vllm.v1.engine.core import EngineCoreProc + + # Start local engines. + if local_engine_count: + # In server mode, start_index and local_start_index will + # both be 0. + local_engine_manager = CoreEngineProcManager( + EngineCoreProc.run_engine_core, + vllm_config=vllm_config, + executor_class=executor_class, + log_stats=log_stats, + handshake_address=handshake_address, + client_handshake_address=client_handshake_address, + local_client=True, + local_engine_count=local_engine_count, + start_index=dp_rank, + local_start_index=local_start_index or 0) + else: + local_engine_manager = None + + yield local_engine_manager, coordinator, addresses + + # Now wait for engines to start. + wait_for_engine_startup( + handshake_socket, + addresses, + engines_to_handshake, + parallel_config, + vllm_config.cache_config, + local_engine_manager, + coordinator.proc if coordinator else None, + ) + + +def wait_for_engine_startup( + handshake_socket: zmq.Socket, + addresses: EngineZmqAddresses, + core_engines: list[CoreEngine], + parallel_config: ParallelConfig, + cache_config: CacheConfig, + proc_manager: Optional[CoreEngineProcManager], + coord_process: Optional[Process], +): + # Wait for engine core process(es) to send ready messages. + local_count = parallel_config.data_parallel_size_local + remote_count = len(core_engines) - local_count + # [local, remote] counts + conn_pending, start_pending = [local_count, remote_count], [0, 0] + poller = zmq.Poller() + poller.register(handshake_socket, zmq.POLLIN) + + if proc_manager is not None: + for sentinel in proc_manager.sentinels(): + poller.register(sentinel, zmq.POLLIN) + if coord_process is not None: + poller.register(coord_process.sentinel, zmq.POLLIN) + while any(conn_pending) or any(start_pending): + events = poller.poll(STARTUP_POLL_PERIOD_MS) + if not events: + if any(conn_pending): + logger.debug( + "Waiting for %d local, %d remote core engine proc(s) " + "to connect.", *conn_pending) + if any(start_pending): + logger.debug( + "Waiting for %d local, %d remote core engine proc(s) " + "to start.", *start_pending) + continue + if len(events) > 1 or events[0][0] != handshake_socket: + # One of the local core processes exited. + finished = proc_manager.finished_procs() if proc_manager else {} + if coord_process is not None and coord_process.exitcode is not None: + finished[coord_process.name] = coord_process.exitcode + raise RuntimeError("Engine core initialization failed. " + "See root cause above. " + f"Failed core proc(s): {finished}") + + # Receive HELLO and READY messages from the input socket. 
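The handshake payloads decoded in this loop are plain msgpack dicts; based on the fields read below, they look roughly like the following (the values here are hypothetical):

    import msgspec

    hello = msgspec.msgpack.encode({"status": "HELLO", "local": True})
    ready = msgspec.msgpack.encode({
        "status": "READY",
        "local": True,
        "num_gpu_blocks": 8192,     # summed into cache_config below
        "dp_stats_address": None,   # coordinator publish address, relayed
    })                              # only in external DP LB mode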
+ eng_identity, ready_msg_bytes = handshake_socket.recv_multipart() + eng_index = int.from_bytes(eng_identity, "little") + engine = next((e for e in core_engines if e.identity == eng_identity), + None) + if engine is None: + raise RuntimeError(f"Message from engine with unexpected data " + f"parallel rank: {eng_index}") + msg = msgspec.msgpack.decode(ready_msg_bytes) + status, local = msg["status"], msg["local"] + if local != engine.local: + raise RuntimeError(f"{status} message from " + f"{'local' if local else 'remote'} " + f"engine {eng_index}, expected it to be " + f"{'local' if engine.local else 'remote'}") + + if status == "HELLO" and engine.state == CoreEngineState.NEW: + + # Send init message with DP config info. + init_message = msgspec.msgpack.encode( + EngineHandshakeMetadata( + addresses=addresses, + parallel_config={ + "data_parallel_master_ip": + parallel_config.data_parallel_master_ip, + "data_parallel_master_port": + parallel_config.data_parallel_master_port, + "data_parallel_size": + parallel_config.data_parallel_size, + })) + handshake_socket.send_multipart((eng_identity, init_message), + copy=False) + conn_pending[0 if local else 1] -= 1 + start_pending[0 if local else 1] += 1 + engine.state = CoreEngineState.CONNECTED + elif status == "READY" and engine.state == CoreEngineState.CONNECTED: + # Setup KV cache config with initialization state from + # engine core process. Sum values from all engines in DP case. + num_gpu_blocks = cache_config.num_gpu_blocks or 0 + num_gpu_blocks += msg["num_gpu_blocks"] + cache_config.num_gpu_blocks = num_gpu_blocks + + # In external DP LB mode, the coordinator address that the + # front-end procs connect to is obtained from rank 0 via + # one of the engine handshakes, and passed to the local + # front-end process in the response from the other. + if addresses.frontend_stats_publish_address is None: + addresses.frontend_stats_publish_address = msg.get( + "dp_stats_address") + + start_pending[0 if local else 1] -= 1 + engine.state = CoreEngineState.READY + else: + raise RuntimeError(f"Unexpected {status} message for " + f"{'local' if local else 'remote'} engine " + f"{eng_index} in {engine.state} state.") + + logger.debug("%s from %s core engine process %s.", status, + "local" if local else "remote", eng_index) + + +def wait_for_completion_or_failure( + api_server_manager: APIServerProcessManager, + engine_manager: Optional[Union[CoreEngineProcManager, + CoreEngineActorManager]] = None, + coordinator: Optional["DPCoordinator"] = None) -> None: + """Wait for all processes to complete or detect if any fail. + + Raises an exception if any process exits with a non-zero status. + + Args: + api_server_manager: The manager for API servers. + engine_manager: The manager for engine processes. + If CoreEngineProcManager, it manages local engines; + if CoreEngineActorManager, it manages all engines. + coordinator: The coordinator for data parallel. 
+ """ + + try: + logger.info("Waiting for API servers to complete ...") + # Create a mapping of sentinels to their corresponding processes + # for efficient lookup + sentinel_to_proc: dict[Any, BaseProcess] = { + proc.sentinel: proc + for proc in api_server_manager.processes + } + + if coordinator: + sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc + + actor_run_refs = [] + if isinstance(engine_manager, CoreEngineProcManager): + for proc in engine_manager.processes: + sentinel_to_proc[proc.sentinel] = proc + elif isinstance(engine_manager, CoreEngineActorManager): + actor_run_refs = engine_manager.get_run_refs() + + # Check if any process terminates + while sentinel_to_proc or actor_run_refs: + # Wait for any process to terminate + ready_sentinels: list[Any] = connection.wait(sentinel_to_proc, + timeout=5) + + # Process any terminated processes + for sentinel in ready_sentinels: + proc = sentinel_to_proc.pop(sentinel) + + # Check if process exited with error + if proc.exitcode != 0: + raise RuntimeError( + f"Process {proc.name} (PID: {proc.pid}) " + f"died with exit code {proc.exitcode}") + + if actor_run_refs: + import ray + _, actor_run_refs = ray.wait(actor_run_refs, timeout=5) + + except KeyboardInterrupt: + logger.info("Received KeyboardInterrupt, shutting down API servers...") + except Exception as e: + logger.exception("Exception occurred while running API servers: %s", + str(e)) + raise + finally: + logger.info("Terminating remaining processes ...") + api_server_manager.close() + if coordinator: + coordinator.close() + if engine_manager: + engine_manager.close() diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 56d2e3d975d2..7d3de6d18a3f 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -1,44 +1,27 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import argparse -import multiprocessing import time -import weakref from collections import defaultdict from collections.abc import Sequence -from dataclasses import dataclass -from enum import Enum, auto -from multiprocessing import Process, connection from multiprocessing.process import BaseProcess -from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, - Union, overload) +from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union, overload -import msgspec import torch -import zmq -from vllm.config import CacheConfig, ParallelConfig, VllmConfig from vllm.logger import init_logger from vllm.model_executor.models.utils import extract_layer_index from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) -from vllm.utils import (get_mp_context, get_open_port, get_open_zmq_ipc_path, - get_tcp_uri, kill_process_tree) -from vllm.v1.executor.abstract import Executor +from vllm.utils import (get_open_port, get_open_zmq_ipc_path, get_tcp_uri, + kill_process_tree) if TYPE_CHECKING: - from ray.util.placement_group import PlacementGroup - from vllm.attention.layer import Attention - from vllm.v1.engine.coordinator import DPCoordinator logger = init_logger(__name__) T = TypeVar("T") -STARTUP_POLL_PERIOD_MS = 10000 - class ConstantList(Generic[T], Sequence): @@ -123,526 +106,6 @@ def get_engine_client_zmq_addr(local_only: bool, host, port or get_open_port())) -class CoreEngineState(Enum): - NEW = auto() - CONNECTED = auto() - READY = auto() - - -class CoreEngine: - """One per data parallel rank.""" - - def __init__(self, index: int = 0, local: bool = True): - self.local = local - self.index = index - 
self.identity = index.to_bytes(2, "little") - - self.state = CoreEngineState.NEW - - -@dataclass -class EngineZmqAddresses: - # ZMQ input socket addresses for each front-end client (requests) - inputs: list[str] - # ZMQ output socket addresses for each front-end client (responses) - outputs: list[str] - # ZMQ input socket address of DP coordinator if applicable - coordinator_input: Optional[str] = None - # ZMQ output socket address of DP coordinator if applicable - coordinator_output: Optional[str] = None - # ZMQ socket for front-end to connect to DP coordinator. - # Not used by engine, just relayed to front-end in handshake response. - # Only required for external DP LB case. - frontend_stats_publish_address: Optional[str] = None - - -@dataclass -class EngineHandshakeMetadata: - """Metadata sent to each engine process during startup handshake, - including addresses of the front-end ZMQ queues that they should - connect to. - """ - addresses: EngineZmqAddresses - parallel_config: dict[str, Union[int, str]] - - -class APIServerProcessManager: - """Manages a group of API server processes. - - Handles creation, monitoring, and termination of API server worker - processes. Also monitors extra processes to check if they are healthy. - """ - - def __init__( - self, - target_server_fn: Callable, - listen_address: str, - sock: Any, - args: argparse.Namespace, - num_servers: int, - input_addresses: list[str], - output_addresses: list[str], - stats_update_address: Optional[str] = None, - ): - """Initialize and start API server worker processes. - - Args: - target_server_fn: Function to call for each API server process - listen_address: Address to listen for client connections - sock: Socket for client connections - args: Command line arguments - num_servers: Number of API server processes to start - input_addresses: Input addresses for each API server - output_addresses: Output addresses for each API server - stats_update_address: Optional stats update address - """ - self.listen_address = listen_address - self.sock = sock - self.args = args - - # Start API servers - spawn_context = multiprocessing.get_context("spawn") - self.processes: list[BaseProcess] = [] - - for i, in_addr, out_addr in zip(range(num_servers), input_addresses, - output_addresses): - client_config = { - "input_address": in_addr, - "output_address": out_addr, - "client_index": i - } - if stats_update_address is not None: - client_config["stats_update_address"] = stats_update_address - - proc = spawn_context.Process(target=target_server_fn, - name=f"ApiServer_{i}", - args=(listen_address, sock, args, - client_config)) - self.processes.append(proc) - proc.start() - - logger.info("Started %d API server processes", len(self.processes)) - - # Shutdown only the API server processes on garbage collection - # The extra processes are managed by their owners - self._finalizer = weakref.finalize(self, shutdown, self.processes) - - def close(self) -> None: - self._finalizer() - - -class CoreEngineProcManager: - """ - Utility class to handle creation, readiness, and shutdown - of background processes used by the AsyncLLM and LLMEngine. 
- """ - - def __init__( - self, - target_fn: Callable, - local_engine_count: int, - start_index: int, - local_start_index: int, - vllm_config: VllmConfig, - local_client: bool, - handshake_address: str, - executor_class: type[Executor], - log_stats: bool, - client_handshake_address: Optional[str] = None, - ): - context = get_mp_context() - common_kwargs = { - "vllm_config": vllm_config, - "local_client": local_client, - "handshake_address": handshake_address, - "executor_class": executor_class, - "log_stats": log_stats, - } - - if client_handshake_address: - common_kwargs[ - "client_handshake_address"] = client_handshake_address - - self.processes: list[BaseProcess] = [] - for index in range(local_engine_count): - local_index = local_start_index + index - global_index = start_index + index - # Start EngineCore in background process. - self.processes.append( - context.Process(target=target_fn, - name=f"EngineCore_{global_index}", - kwargs=common_kwargs | { - "dp_rank": global_index, - "local_dp_rank": local_index, - })) - - self._finalizer = weakref.finalize(self, shutdown, self.processes) - try: - for proc in self.processes: - proc.start() - finally: - # Kill other procs if not all are running. - if self.finished_procs(): - self.close() - - def close(self): - """Shutdown all procs.""" - self._finalizer() - - def join_first(self): - """Wait for any process to exit.""" - connection.wait(proc.sentinel for proc in self.processes) - - def sentinels(self) -> list: - return [proc.sentinel for proc in self.processes] - - def finished_procs(self) -> dict[str, int]: - """Returns dict of proc name -> exit code for any finished procs.""" - return { - proc.name: proc.exitcode - for proc in self.processes if proc.exitcode is not None - } - - -class CoreEngineActorManager: - """ - Utility class to handle creation, readiness, and shutdown - of core engine Ray actors used by the AsyncLLM and LLMEngine. - - Different from CoreEngineProcManager, this class manages - core engines for both local and remote nodes. - """ - - def __init__( - self, - vllm_config: VllmConfig, - addresses: EngineZmqAddresses, - executor_class: type[Executor], - log_stats: bool, - placement_groups: Optional[list["PlacementGroup"]] = None, - local_dp_ranks: Optional[list[int]] = None, - ): - import copy - - import ray - from ray.util.scheduling_strategies import ( - PlacementGroupSchedulingStrategy) - - from vllm.v1.engine.core import DPEngineCoreActor - - self.local_engine_actors: list[ray.ActorHandle] = [] - self.remote_engine_actors: list[ray.ActorHandle] = [] - dp_size = vllm_config.parallel_config.data_parallel_size - local_engine_count = \ - vllm_config.parallel_config.data_parallel_size_local - world_size = vllm_config.parallel_config.world_size - - if ray.is_initialized(): - logger.info( - "Ray is already initialized. 
Skipping Ray initialization.") - else: - ray.init() - - if placement_groups is not None: - assert local_dp_ranks is not None, ( - "local_dp_ranks must be provided if " - "placement_groups is provided") - assert len(placement_groups) == len(local_dp_ranks), ( - "placement_groups and local_dp_ranks must " - "have the same length") - logger.info("Using provided placement groups") - # TODO(rui): validate passed-in placement groups - self.created_placement_groups = [] - else: - placement_groups, local_dp_ranks = \ - CoreEngineActorManager.create_dp_placement_groups(vllm_config) - self.created_placement_groups = placement_groups - assert len(placement_groups) == dp_size, ( - "Number of placement groups must match data parallel size") - - refs = [] - for index in range(dp_size): - local_index = local_dp_ranks[index] - dp_vllm_config = copy.deepcopy(vllm_config) - pg = placement_groups[index] - dp_vllm_config.parallel_config.placement_group = pg - local_client = index < local_engine_count - actor = ray.remote(DPEngineCoreActor).options( - scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=pg, - placement_group_bundle_index=world_size, - )).remote(vllm_config=dp_vllm_config, - executor_class=executor_class, - log_stats=log_stats, - local_client=local_client, - addresses=addresses, - dp_rank=index, - local_dp_rank=local_index) - if local_client: - self.local_engine_actors.append(actor) - else: - self.remote_engine_actors.append(actor) - refs.append(actor.wait_for_init.remote()) - - ray.get(refs) - self.run_refs = [] - for actor in self.local_engine_actors + self.remote_engine_actors: - self.run_refs.append(actor.run.remote()) - - @staticmethod - def create_dp_placement_groups( - vllm_config: VllmConfig - ) -> tuple[list["PlacementGroup"], list[int]]: - - import ray - from ray._private.state import available_resources_per_node - from ray.util.state import list_nodes - - logger.info("Creating placement groups for data parallel") - dp_master_ip = \ - vllm_config.parallel_config.data_parallel_master_ip - dp_size = vllm_config.parallel_config.data_parallel_size - local_engine_count = \ - vllm_config.parallel_config.data_parallel_size_local - - nodes = list_nodes() - nodes = sorted(list_nodes(), - key=lambda node: node.node_ip != dp_master_ip) - assert nodes[0].node_ip == dp_master_ip, ( - "The first node must be the head node") - assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, ( - "There can only be one head node") - - available_resources = available_resources_per_node() - world_size = vllm_config.parallel_config.world_size - placement_groups: list[PlacementGroup] = [] - local_dp_ranks: list[int] = [] - - for node in nodes: - node_ip = node.node_ip - node_resources = available_resources[node.node_id] - # For now, each DP rank can only be assigned to one node - # TODO(rui): support allocating a single DP rank - # to multiple nodes - available_engine_count = int(node_resources["GPU"]) // world_size - if node_ip == dp_master_ip: - assert available_engine_count >= local_engine_count, ( - "Not enough resources to allocate DP ranks " - f"on DP master node {node_ip}") - for i in range(local_engine_count): - bundles = [{ - "GPU": 1.0, - "node:" + dp_master_ip: 0.001 - }] * world_size + [{ - "CPU": 1.0 - }] - pg = ray.util.placement_group( - name=f"dp_rank_{len(placement_groups)}", - strategy="STRICT_PACK", - bundles=bundles, - ) - placement_groups.append(pg) - local_dp_ranks.append(i) - else: - for i in range(available_engine_count): - if len(placement_groups) == dp_size: - break 
- bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}] - pg = ray.util.placement_group( - name=f"dp_rank_{len(placement_groups)}", - strategy="STRICT_PACK", - bundles=bundles, - ) - placement_groups.append(pg) - local_dp_ranks.append(i) - return placement_groups, local_dp_ranks - - def get_run_refs(self): - return self.run_refs - - def close(self): - import ray - for actor in self.local_engine_actors + self.remote_engine_actors: - ray.kill(actor) - for pg in self.created_placement_groups: - ray.util.remove_placement_group(pg) - - -def wait_for_engine_startup( - handshake_socket: zmq.Socket, - addresses: EngineZmqAddresses, - core_engines: list[CoreEngine], - parallel_config: ParallelConfig, - cache_config: CacheConfig, - proc_manager: Optional[CoreEngineProcManager], - coord_process: Optional[Process], -): - # Wait for engine core process(es) to send ready messages. - local_count = parallel_config.data_parallel_size_local - remote_count = len(core_engines) - local_count - # [local, remote] counts - conn_pending, start_pending = [local_count, remote_count], [0, 0] - poller = zmq.Poller() - poller.register(handshake_socket, zmq.POLLIN) - - if proc_manager is not None: - for sentinel in proc_manager.sentinels(): - poller.register(sentinel, zmq.POLLIN) - if coord_process is not None: - poller.register(coord_process.sentinel, zmq.POLLIN) - while any(conn_pending) or any(start_pending): - events = poller.poll(STARTUP_POLL_PERIOD_MS) - if not events: - if any(conn_pending): - logger.debug( - "Waiting for %d local, %d remote core engine proc(s) " - "to connect.", *conn_pending) - if any(start_pending): - logger.debug( - "Waiting for %d local, %d remote core engine proc(s) " - "to start.", *start_pending) - continue - if len(events) > 1 or events[0][0] != handshake_socket: - # One of the local core processes exited. - finished = proc_manager.finished_procs() if proc_manager else {} - if coord_process is not None and coord_process.exitcode is not None: - finished[coord_process.name] = coord_process.exitcode - raise RuntimeError("Engine core initialization failed. " - "See root cause above. " - f"Failed core proc(s): {finished}") - - # Receive HELLO and READY messages from the input socket. - eng_identity, ready_msg_bytes = handshake_socket.recv_multipart() - eng_index = int.from_bytes(eng_identity, "little") - engine = next((e for e in core_engines if e.identity == eng_identity), - None) - if engine is None: - raise RuntimeError(f"Message from engine with unexpected data " - f"parallel rank: {eng_index}") - msg = msgspec.msgpack.decode(ready_msg_bytes) - status, local = msg["status"], msg["local"] - if local != engine.local: - raise RuntimeError(f"{status} message from " - f"{'local' if local else 'remote'} " - f"engine {eng_index}, expected it to be " - f"{'local' if engine.local else 'remote'}") - - if status == "HELLO" and engine.state == CoreEngineState.NEW: - - # Send init message with DP config info. 
- init_message = msgspec.msgpack.encode( - EngineHandshakeMetadata( - addresses=addresses, - parallel_config={ - "data_parallel_master_ip": - parallel_config.data_parallel_master_ip, - "data_parallel_master_port": - parallel_config.data_parallel_master_port, - "data_parallel_size": - parallel_config.data_parallel_size, - })) - handshake_socket.send_multipart((eng_identity, init_message), - copy=False) - conn_pending[0 if local else 1] -= 1 - start_pending[0 if local else 1] += 1 - engine.state = CoreEngineState.CONNECTED - elif status == "READY" and (engine.state == CoreEngineState.CONNECTED): - # Setup KV cache config with initialization state from - # engine core process. Sum values from all engines in DP case. - num_gpu_blocks = cache_config.num_gpu_blocks or 0 - num_gpu_blocks += msg["num_gpu_blocks"] - cache_config.num_gpu_blocks = num_gpu_blocks - - # In external DP LB mode, the coordinator address that the - # front-end procs connect to is obtained from rank 0 via - # one of the engine handshakes, and passed to the local - # front-end process in the response from the other. - if addresses.frontend_stats_publish_address is None: - addresses.frontend_stats_publish_address = msg.get( - "dp_stats_address") - - start_pending[0 if local else 1] -= 1 - engine.state = CoreEngineState.READY - else: - raise RuntimeError(f"Unexpected {status} message for " - f"{'local' if local else 'remote'} engine " - f"{eng_index} in {engine.state} state.") - - logger.debug("%s from %s core engine process %s.", status, - "local" if local else "remote", eng_index) - - -def wait_for_completion_or_failure( - api_server_manager: APIServerProcessManager, - engine_manager: Optional[Union[CoreEngineProcManager, - CoreEngineActorManager]] = None, - coordinator: Optional["DPCoordinator"] = None) -> None: - """Wait for all processes to complete or detect if any fail. - - Raises an exception if any process exits with a non-zero status. - - Args: - api_server_manager: The manager for API servers. - engine_manager: The manager for engine processes. - If CoreEngineProcManager, it manages local engines; - if CoreEngineActorManager, it manages all engines. - coordinator: The coordinator for data parallel. 
- """ - - try: - logger.info("Waiting for API servers to complete ...") - # Create a mapping of sentinels to their corresponding processes - # for efficient lookup - sentinel_to_proc: dict[Any, BaseProcess] = { - proc.sentinel: proc - for proc in api_server_manager.processes - } - - if coordinator: - sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc - - actor_run_refs = [] - if isinstance(engine_manager, CoreEngineProcManager): - for proc in engine_manager.processes: - sentinel_to_proc[proc.sentinel] = proc - elif isinstance(engine_manager, CoreEngineActorManager): - actor_run_refs = engine_manager.get_run_refs() - - # Check if any process terminates - while sentinel_to_proc or actor_run_refs: - # Wait for any process to terminate - ready_sentinels: list[Any] = connection.wait(sentinel_to_proc, - timeout=5) - - # Process any terminated processes - for sentinel in ready_sentinels: - proc = sentinel_to_proc.pop(sentinel) - - # Check if process exited with error - if proc.exitcode != 0: - raise RuntimeError( - f"Process {proc.name} (PID: {proc.pid}) " - f"died with exit code {proc.exitcode}") - - if actor_run_refs: - import ray - _, actor_run_refs = ray.wait(actor_run_refs, timeout=5) - - except KeyboardInterrupt: - logger.info("Received KeyboardInterrupt, shutting down API servers...") - except Exception as e: - logger.exception("Exception occurred while running API servers: %s", - str(e)) - raise - finally: - logger.info("Terminating remaining processes ...") - api_server_manager.close() - if coordinator: - coordinator.close() - if engine_manager: - engine_manager.close() - - # Note(rob): shutdown function cannot be a bound method, # else the gc cannot collect the object. def shutdown(procs: list[BaseProcess]): From 85dfe7e6656e846d2bf63b29ea9e65776d5bb6f9 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 24 Jun 2025 22:13:33 -0700 Subject: [PATCH 03/12] move utils not directly related to engines back to v1/utils.py Signed-off-by: Nick Hill --- vllm/entrypoints/cli/serve.py | 6 +- vllm/v1/engine/utils.py | 140 +------------------------------- vllm/v1/utils.py | 149 +++++++++++++++++++++++++++++++++- 3 files changed, 152 insertions(+), 143 deletions(-) diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index ec413748020a..8ff1792edb8f 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -23,11 +23,11 @@ from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser, get_tcp_uri from vllm.v1.engine.core import EngineCoreProc -from vllm.v1.engine.utils import (APIServerProcessManager, - CoreEngineProcManager, launch_core_engines, - wait_for_completion_or_failure) +from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines from vllm.v1.executor.abstract import Executor from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus +from vllm.v1.utils import (APIServerProcessManager, + wait_for_completion_or_failure) logger = init_logger(__name__) diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index 4ab5527bfcb1..c4012419411a 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -1,16 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import argparse import contextlib -import multiprocessing import weakref from collections.abc import Iterator from dataclasses import dataclass from enum import Enum, auto from multiprocessing import Process, connection from multiprocessing.process 
import BaseProcess -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Callable, Optional, Union import msgspec import zmq @@ -72,71 +70,6 @@ class EngineHandshakeMetadata: parallel_config: dict[str, Union[int, str]] -class APIServerProcessManager: - """Manages a group of API server processes. - - Handles creation, monitoring, and termination of API server worker - processes. Also monitors extra processes to check if they are healthy. - """ - - def __init__( - self, - target_server_fn: Callable, - listen_address: str, - sock: Any, - args: argparse.Namespace, - num_servers: int, - input_addresses: list[str], - output_addresses: list[str], - stats_update_address: Optional[str] = None, - ): - """Initialize and start API server worker processes. - - Args: - target_server_fn: Function to call for each API server process - listen_address: Address to listen for client connections - sock: Socket for client connections - args: Command line arguments - num_servers: Number of API server processes to start - input_addresses: Input addresses for each API server - output_addresses: Output addresses for each API server - stats_update_address: Optional stats update address - """ - self.listen_address = listen_address - self.sock = sock - self.args = args - - # Start API servers - spawn_context = multiprocessing.get_context("spawn") - self.processes: list[BaseProcess] = [] - - for i, in_addr, out_addr in zip(range(num_servers), input_addresses, - output_addresses): - client_config = { - "input_address": in_addr, - "output_address": out_addr, - "client_index": i - } - if stats_update_address is not None: - client_config["stats_update_address"] = stats_update_address - - proc = spawn_context.Process(target=target_server_fn, - name=f"ApiServer_{i}", - args=(listen_address, sock, args, - client_config)) - self.processes.append(proc) - proc.start() - - logger.info("Started %d API server processes", len(self.processes)) - - # Shutdown only the API server processes on garbage collection - # The extra processes are managed by their owners - self._finalizer = weakref.finalize(self, shutdown, self.processes) - - def close(self) -> None: - self._finalizer() - - class CoreEngineProcManager: """ Utility class to handle creation, readiness, and shutdown @@ -611,74 +544,3 @@ def wait_for_engine_startup( logger.debug("%s from %s core engine process %s.", status, "local" if local else "remote", eng_index) - - -def wait_for_completion_or_failure( - api_server_manager: APIServerProcessManager, - engine_manager: Optional[Union[CoreEngineProcManager, - CoreEngineActorManager]] = None, - coordinator: Optional["DPCoordinator"] = None) -> None: - """Wait for all processes to complete or detect if any fail. - - Raises an exception if any process exits with a non-zero status. - - Args: - api_server_manager: The manager for API servers. - engine_manager: The manager for engine processes. - If CoreEngineProcManager, it manages local engines; - if CoreEngineActorManager, it manages all engines. - coordinator: The coordinator for data parallel. 
- """ - - try: - logger.info("Waiting for API servers to complete ...") - # Create a mapping of sentinels to their corresponding processes - # for efficient lookup - sentinel_to_proc: dict[Any, BaseProcess] = { - proc.sentinel: proc - for proc in api_server_manager.processes - } - - if coordinator: - sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc - - actor_run_refs = [] - if isinstance(engine_manager, CoreEngineProcManager): - for proc in engine_manager.processes: - sentinel_to_proc[proc.sentinel] = proc - elif isinstance(engine_manager, CoreEngineActorManager): - actor_run_refs = engine_manager.get_run_refs() - - # Check if any process terminates - while sentinel_to_proc or actor_run_refs: - # Wait for any process to terminate - ready_sentinels: list[Any] = connection.wait(sentinel_to_proc, - timeout=5) - - # Process any terminated processes - for sentinel in ready_sentinels: - proc = sentinel_to_proc.pop(sentinel) - - # Check if process exited with error - if proc.exitcode != 0: - raise RuntimeError( - f"Process {proc.name} (PID: {proc.pid}) " - f"died with exit code {proc.exitcode}") - - if actor_run_refs: - import ray - _, actor_run_refs = ray.wait(actor_run_refs, timeout=5) - - except KeyboardInterrupt: - logger.info("Received KeyboardInterrupt, shutting down API servers...") - except Exception as e: - logger.exception("Exception occurred while running API servers: %s", - str(e)) - raise - finally: - logger.info("Terminating remaining processes ...") - api_server_manager.close() - if coordinator: - coordinator.close() - if engine_manager: - engine_manager.close() diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 7d3de6d18a3f..6b40cf6fd36d 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -1,10 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse +import multiprocessing import time +import weakref from collections import defaultdict from collections.abc import Sequence +from multiprocessing import connection from multiprocessing.process import BaseProcess -from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union, overload +from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, + Union, overload) import torch @@ -17,6 +22,9 @@ if TYPE_CHECKING: from vllm.attention.layer import Attention + from vllm.v1.engine.coordinator import DPCoordinator + from vllm.v1.engine.utils import (CoreEngineActorManager, + CoreEngineProcManager) logger = init_logger(__name__) @@ -106,6 +114,145 @@ def get_engine_client_zmq_addr(local_only: bool, host, port or get_open_port())) +class APIServerProcessManager: + """Manages a group of API server processes. + + Handles creation, monitoring, and termination of API server worker + processes. Also monitors extra processes to check if they are healthy. + """ + + def __init__( + self, + target_server_fn: Callable, + listen_address: str, + sock: Any, + args: argparse.Namespace, + num_servers: int, + input_addresses: list[str], + output_addresses: list[str], + stats_update_address: Optional[str] = None, + ): + """Initialize and start API server worker processes. 
+ + Args: + target_server_fn: Function to call for each API server process + listen_address: Address to listen for client connections + sock: Socket for client connections + args: Command line arguments + num_servers: Number of API server processes to start + input_addresses: Input addresses for each API server + output_addresses: Output addresses for each API server + stats_update_address: Optional stats update address + """ + self.listen_address = listen_address + self.sock = sock + self.args = args + + # Start API servers + spawn_context = multiprocessing.get_context("spawn") + self.processes: list[BaseProcess] = [] + + for i, in_addr, out_addr in zip(range(num_servers), input_addresses, + output_addresses): + client_config = { + "input_address": in_addr, + "output_address": out_addr, + "client_index": i + } + if stats_update_address is not None: + client_config["stats_update_address"] = stats_update_address + + proc = spawn_context.Process(target=target_server_fn, + name=f"ApiServer_{i}", + args=(listen_address, sock, args, + client_config)) + self.processes.append(proc) + proc.start() + + logger.info("Started %d API server processes", len(self.processes)) + + # Shutdown only the API server processes on garbage collection + # The extra processes are managed by their owners + self._finalizer = weakref.finalize(self, shutdown, self.processes) + + def close(self) -> None: + self._finalizer() + + +def wait_for_completion_or_failure( + api_server_manager: APIServerProcessManager, + engine_manager: Optional[Union["CoreEngineProcManager", + "CoreEngineActorManager"]] = None, + coordinator: Optional["DPCoordinator"] = None) -> None: + """Wait for all processes to complete or detect if any fail. + + Raises an exception if any process exits with a non-zero status. + + Args: + api_server_manager: The manager for API servers. + engine_manager: The manager for engine processes. + If CoreEngineProcManager, it manages local engines; + if CoreEngineActorManager, it manages all engines. + coordinator: The coordinator for data parallel. 
+ """ + + from vllm.v1.engine.utils import (CoreEngineActorManager, + CoreEngineProcManager) + + try: + logger.info("Waiting for API servers to complete ...") + # Create a mapping of sentinels to their corresponding processes + # for efficient lookup + sentinel_to_proc: dict[Any, BaseProcess] = { + proc.sentinel: proc + for proc in api_server_manager.processes + } + + if coordinator: + sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc + + actor_run_refs = [] + if isinstance(engine_manager, CoreEngineProcManager): + for proc in engine_manager.processes: + sentinel_to_proc[proc.sentinel] = proc + elif isinstance(engine_manager, CoreEngineActorManager): + actor_run_refs = engine_manager.get_run_refs() + + # Check if any process terminates + while sentinel_to_proc or actor_run_refs: + # Wait for any process to terminate + ready_sentinels: list[Any] = connection.wait(sentinel_to_proc, + timeout=5) + + # Process any terminated processes + for sentinel in ready_sentinels: + proc = sentinel_to_proc.pop(sentinel) + + # Check if process exited with error + if proc.exitcode != 0: + raise RuntimeError( + f"Process {proc.name} (PID: {proc.pid}) " + f"died with exit code {proc.exitcode}") + + if actor_run_refs: + import ray + _, actor_run_refs = ray.wait(actor_run_refs, timeout=5) + + except KeyboardInterrupt: + logger.info("Received KeyboardInterrupt, shutting down API servers...") + except Exception as e: + logger.exception("Exception occurred while running API servers: %s", + str(e)) + raise + finally: + logger.info("Terminating remaining processes ...") + api_server_manager.close() + if coordinator: + coordinator.close() + if engine_manager: + engine_manager.close() + + # Note(rob): shutdown function cannot be a bound method, # else the gc cannot collect the object. def shutdown(procs: list[BaseProcess]): From a73c2c0f925410d60850e3bad2536cdefaba6553 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 26 Jun 2025 13:30:29 -0700 Subject: [PATCH 04/12] fix: ignore coordinator XPUB unsubscribe messages Signed-off-by: Nick Hill --- vllm/v1/engine/coordinator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py index 3a2fa00e6ae6..b3e7a2e85b80 100644 --- a/vllm/v1/engine/coordinator.py +++ b/vllm/v1/engine/coordinator.py @@ -196,7 +196,7 @@ def process_input_socket(self, front_publish_address: str, if publish_front in events: buffer = publish_front.recv() - if buffer == b'\x01': + if buffer in (b'\x01', b'\x00'): # Ignore subscription messages. 
continue

From a195cc5041d499e9bb4b9a2c6ad894bc2ca7f1c0 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Thu, 26 Jun 2025 09:33:30 -0700
Subject: [PATCH 05/12] add CI test

Signed-off-by: Nick Hill
---
 .buildkite/test-pipeline.yaml   |   4 +
 tests/v1/test_external_lb_dp.py | 311 ++++++++++++++++++++++++++++++++
 2 files changed, 315 insertions(+)
 create mode 100644 tests/v1/test_external_lb_dp.py

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 1536759c06bd..fccfdbee8d74 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -145,6 +145,7 @@ steps:
   - examples/offline_inference/rlhf_colocate.py
   - tests/examples/offline_inference/data_parallel.py
   - tests/v1/test_async_llm_dp.py
+  - tests/v1/test_external_lb_dp.py
   - tests/v1/engine/test_engine_core_client.py
   commands:
   # test with tp=2 and external_dp=2
@@ -155,6 +156,7 @@ steps:
   # test with internal dp
   - python3 ../examples/offline_inference/data_parallel.py
   - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
+  - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
   - pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp
   - pytest -v -s distributed/test_utils.py
   - pytest -v -s compile/test_basic_correctness.py
@@ -644,10 +646,12 @@ steps:
   - vllm/worker/model_runner.py
   - entrypoints/llm/test_collective_rpc.py
   - tests/v1/test_async_llm_dp.py
+  - tests/v1/test_external_lb_dp.py
   - tests/v1/entrypoints/openai/test_multi_api_servers.py
   - vllm/v1/engine/
   commands:
   - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
+  - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
   - DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
   - pytest -v -s entrypoints/llm/test_collective_rpc.py
   - pytest -v -s ./compile/test_basic_correctness.py
diff --git a/tests/v1/test_external_lb_dp.py b/tests/v1/test_external_lb_dp.py
new file mode 100644
index 000000000000..c90559627bb1
--- /dev/null
+++ b/tests/v1/test_external_lb_dp.py
@@ -0,0 +1,311 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import asyncio
+import os
+import threading
+import time
+
+import openai  # use the official client for correctness check
+import pytest
+import pytest_asyncio
+
+from tests.utils import RemoteOpenAIServer
+
+MODEL_NAME = "ibm-research/PowerMoE-3b"
+
+# Number of data parallel ranks for external LB testing
+DP_SIZE = int(os.getenv("DP_SIZE", "2"))
+# Default tensor parallel size to use
+TP_SIZE = int(os.getenv("TP_SIZE", "1"))
+
+
+class ExternalLBServerManager:
+    """Manages data parallel vLLM server instances for external
+    load balancer testing."""
+
+    def __init__(self, model_name: str, dp_size: int, api_server_count: int,
+                 base_server_args: list, tp_size: int = TP_SIZE):
+        self.model_name = model_name
+        self.dp_size = dp_size
+        self.tp_size = tp_size
+        self.api_server_count = api_server_count
+        self.base_server_args = base_server_args
+        self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = []
+        self.server_threads = []
+
+    def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]:
+        """Start all server instances for external LB mode."""
+        for rank in range(self.dp_size):
+            # Create server args for this specific rank
+            server_args = self.base_server_args.copy()
+
+            # Add external LB specific arguments
+            server_args.extend([
+                "--data-parallel-size",
+                str(self.dp_size),
+                "--data-parallel-rank",
+                str(rank),
+                "--data-parallel-size-local",
+                "1",
+                "--tensor-parallel-size",
+            
str(self.tp_size), + "--port", + str(8000 + rank), # Different port for each rank + "--api-server-count", + str(self.api_server_count), + ]) + + # Use a thread to start each server to allow parallel initialization + def start_server(r, sargs): + try: + # Start the server + server = RemoteOpenAIServer( + self.model_name, sargs, auto_port=False, + env_dict={"CUDA_VISIBLE_DEVICES": str(r)}) + server.__enter__() + print( + f"Server rank {r} started successfully with " + f"{self.api_server_count} API servers" + ) + self.servers.append((server, sargs)) + except Exception as e: + print(f"Failed to start server rank {r}: {e}") + raise + + thread = threading.Thread(target=start_server, + args=(rank, server_args)) + thread.start() + + self.server_threads.append(thread) + + # Wait for all servers to start + for thread in self.server_threads: + thread.join() + + # Give servers additional time to fully initialize and coordinate + time.sleep(2) + + if len(self.servers) != self.dp_size: + raise Exception("Servers failed to start") + + return self.servers + + def __exit__(self, exc_type, exc_val, exc_tb): + """Stop all server instances.""" + while self.servers: + try: + self.servers.pop()[0].__exit__(exc_type, exc_val, exc_tb) + except Exception as e: + print(f"Error stopping server: {e}") + + +@pytest.fixture(scope="module") +def default_server_args(): + return [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--max-num-seqs", + "128", + "--enforce-eager", + ] + + +@pytest.fixture(scope="module", params=[1 , 4]) +def servers(request, default_server_args): + api_server_count = request.param + with ExternalLBServerManager(MODEL_NAME, DP_SIZE, api_server_count, + default_server_args) as server_list: + yield server_list + + +@pytest_asyncio.fixture +async def clients(servers: list[tuple[RemoteOpenAIServer, list[str]]]): + # Create a client for each server + async_clients = [] + for server, _ in servers: + client = await server.get_async_client().__aenter__() + async_clients.append(client) + + try: + yield async_clients + finally: + # Clean up all clients + for client in async_clients: + try: + await client.__aexit__(None, None, None) + except Exception as e: + print(f"Error closing client: {e}") + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_external_lb_single_completion( + clients: list[openai.AsyncOpenAI], + servers: list[tuple[RemoteOpenAIServer, list[str]]], + model_name: str) -> None: + + async def make_request(client: openai.AsyncOpenAI): + completion = await client.completions.create( + model=model_name, + prompt="Hello, my name is", + max_tokens=10, + temperature=1.0) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 1 + + choice = completion.choices[0] + # The exact number of tokens can vary slightly with temperature=1.0, + # so we check for a reasonable minimum length. + assert len(choice.text) >= 1 + # Finish reason might not always be 'length' if the model finishes early + # or due to other reasons, especially with high temperature. + # So, we'll accept 'length' or 'stop'. + assert choice.finish_reason in ("length", "stop") + + # Token counts can also vary, so we check they are positive. 
+ assert completion.usage.completion_tokens > 0 + assert completion.usage.prompt_tokens > 0 + assert completion.usage.total_tokens > 0 + return completion + + # Test single request to each server + for i, client in enumerate(clients): + result = await make_request(client) + assert result is not None + print(f"Server {i} handled single completion request successfully") + + await asyncio.sleep(0.5) + + # Send requests to all servers in round-robin fashion + num_requests_per_server = 25 # Total 50 requests across 2 servers + all_tasks = [] + + for i, client in enumerate(clients): + tasks = [make_request(client) for _ in range(num_requests_per_server)] + all_tasks.extend(tasks) + + results = await asyncio.gather(*all_tasks) + assert len(results) == num_requests_per_server * len(clients) + assert all(completion is not None for completion in results) + + await asyncio.sleep(0.5) + + # Second burst of requests + all_tasks = [] + for i, client in enumerate(clients): + tasks = [make_request(client) for _ in range(num_requests_per_server)] + all_tasks.extend(tasks) + + results = await asyncio.gather(*all_tasks) + assert len(results) == num_requests_per_server * len(clients) + assert all(completion is not None for completion in results) + + _, server_args = servers[0] + print( + f"Successfully completed external LB test with {len(clients)} servers " + f"(API server count: {server_args.count('--api-server-count') + and server_args[server_args.index( + '--api-server-count') + 1] or '1'})" + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_external_lb_completion_streaming(clients: list[ + openai.AsyncOpenAI], servers: list[tuple[RemoteOpenAIServer, list[str]]], + model_name: str) -> None: + prompt = "What is an LLM?" + + async def make_streaming_request(client: openai.AsyncOpenAI): + # Perform a non-streaming request to get the expected full output + single_completion = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + ) + single_output = single_completion.choices[0].text + + # Perform the streaming request + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True) + chunks: list[str] = [] + finish_reason_count = 0 + last_chunk = None + async for chunk in stream: + chunks.append(chunk.choices[0].text) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + last_chunk = chunk # Keep track of the last chunk + + # finish reason should only return in the last block for OpenAI API + assert finish_reason_count == 1, ( + "Finish reason should appear exactly once.") + assert last_chunk is not None, ( + "Stream should have yielded at least one chunk.") + assert last_chunk.choices[ + 0].finish_reason == "length", "Finish reason should be 'length'." + # Check that the combined text matches the non-streamed version. + assert "".join( + chunks + ) == single_output, "Streamed output should match non-streamed output." 
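+        # (Both calls use temperature=0.0, i.e. greedy decoding, so the
+        # streamed and non-streamed outputs are expected to be identical.)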
+ return True # Indicate success for this request + + # Test single request to each server + for i, client in enumerate(clients): + result = await make_streaming_request(client) + assert result is not None + print(f"Server {i} handled single streaming request successfully") + + await asyncio.sleep(0.5) + + # Send streaming requests to all servers in round-robin fashion + num_requests_per_server = 25 # Total 50 requests across 2 servers + all_tasks = [] + + for i, client in enumerate(clients): + tasks = [ + make_streaming_request(client) + for _ in range(num_requests_per_server) + ] + all_tasks.extend(tasks) + + results = await asyncio.gather(*all_tasks) + assert len(results) == num_requests_per_server * len(clients) + assert all(results), "Not all streaming requests completed successfully." + + await asyncio.sleep(0.5) + + # Second burst of streaming requests + all_tasks = [] + for i, client in enumerate(clients): + tasks = [ + make_streaming_request(client) + for _ in range(num_requests_per_server) + ] + all_tasks.extend(tasks) + + results = await asyncio.gather(*all_tasks) + assert len(results) == num_requests_per_server * len(clients) + assert all(results), "Not all streaming requests completed successfully." + + _, server_args = servers[0] + print( + f"Successfully completed external LB streaming test with " + f"(API server count: {server_args.count('--api-server-count') + and server_args[server_args.index( + '--api-server-count') + 1] or '1'})" + ) From 04f2f97044b1db6ca0fea20c55eaa2d6f35289fc Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 26 Jun 2025 15:10:38 -0700 Subject: [PATCH 06/12] fix typing in test Signed-off-by: Nick Hill --- tests/v1/test_external_lb_dp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v1/test_external_lb_dp.py b/tests/v1/test_external_lb_dp.py index c90559627bb1..6551e351e7dc 100644 --- a/tests/v1/test_external_lb_dp.py +++ b/tests/v1/test_external_lb_dp.py @@ -31,7 +31,7 @@ def __init__(self, model_name: str, dp_size: int, api_server_count: int, self.api_server_count = api_server_count self.base_server_args = base_server_args self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = [] - self.server_threads = [] + self.server_threads: list[threading.Thread] = [] def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]: """Start all server instances for external LB mode.""" From 2025b49239dd8c99a108075c6e2ebe7b7ef892b6 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 26 Jun 2025 21:13:34 -0700 Subject: [PATCH 07/12] fix test_engine_core_proc_instantiation_cuda_empty Signed-off-by: Nick Hill --- tests/v1/engine/test_engine_core_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 79ce5b126db0..65f1da803fb2 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -563,7 +563,7 @@ def create_mock_executor(vllm_config): m.setenv("VLLM_USE_V1", "1") m.setenv("CUDA_VISIBLE_DEVICES", "") # No CUDA devices - from vllm.v1.utils import EngineZmqAddresses + from vllm.v1.engine.utils import EngineZmqAddresses def mock_startup_handshake(self, handshake_socket, on_head_node, parallel_config): @@ -580,7 +580,7 @@ def mock_startup_handshake(self, handshake_socket, on_head_node, trust_remote_code=True).create_engine_config() engine_core_proc = EngineCoreProc( vllm_config=vllm_config, - on_head_node=True, + local_client=True, 
handshake_address="tcp://127.0.0.1:12345", executor_class=mock_executor_class, log_stats=False, From 13de01486b938e816019ad917bc608a8bfdb6461 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 26 Jun 2025 21:38:17 -0700 Subject: [PATCH 08/12] fix test formatting Signed-off-by: Nick Hill --- tests/v1/test_external_lb_dp.py | 66 +++++++++++++++------------------ 1 file changed, 30 insertions(+), 36 deletions(-) diff --git a/tests/v1/test_external_lb_dp.py b/tests/v1/test_external_lb_dp.py index 6551e351e7dc..b7d944e92e27 100644 --- a/tests/v1/test_external_lb_dp.py +++ b/tests/v1/test_external_lb_dp.py @@ -4,6 +4,7 @@ import os import threading import time +from contextlib import AsyncExitStack import openai # use the official client for correctness check import pytest @@ -23,8 +24,12 @@ class ExternalLBServerManager: """Manages data parallel vLLM server instances for external load balancer testing.""" - def __init__(self, model_name: str, dp_size: int, api_server_count: int, - base_server_args: list, tp_size: int = TP_SIZE): + def __init__(self, + model_name: str, + dp_size: int, + api_server_count: int, + base_server_args: list, + tp_size: int = TP_SIZE): self.model_name = model_name self.dp_size = dp_size self.tp_size = tp_size @@ -60,13 +65,13 @@ def start_server(r, sargs): try: # Start the server server = RemoteOpenAIServer( - self.model_name, sargs, auto_port=False, + self.model_name, + sargs, + auto_port=False, env_dict={"CUDA_VISIBLE_DEVICES": str(r)}) server.__enter__() - print( - f"Server rank {r} started successfully with " - f"{self.api_server_count} API servers" - ) + print(f"Server rank {r} started successfully with " + f"{self.api_server_count} API servers") self.servers.append((server, sargs)) except Exception as e: print(f"Failed to start server rank {r}: {e}") @@ -113,7 +118,7 @@ def default_server_args(): ] -@pytest.fixture(scope="module", params=[1 , 4]) +@pytest.fixture(scope="module", params=[1, 4]) def servers(request, default_server_args): api_server_count = request.param with ExternalLBServerManager(MODEL_NAME, DP_SIZE, api_server_count, @@ -124,20 +129,11 @@ def servers(request, default_server_args): @pytest_asyncio.fixture async def clients(servers: list[tuple[RemoteOpenAIServer, list[str]]]): # Create a client for each server - async_clients = [] - for server, _ in servers: - client = await server.get_async_client().__aenter__() - async_clients.append(client) - - try: - yield async_clients - finally: - # Clean up all clients - for client in async_clients: - try: - await client.__aexit__(None, None, None) - except Exception as e: - print(f"Error closing client: {e}") + async with AsyncExitStack() as stack: + yield [ + await stack.enter_context(server.get_async_client()) + for server, _ in servers + ] @pytest.mark.asyncio @@ -145,10 +141,9 @@ async def clients(servers: list[tuple[RemoteOpenAIServer, list[str]]]): "model_name", [MODEL_NAME], ) -async def test_external_lb_single_completion( - clients: list[openai.AsyncOpenAI], - servers: list[tuple[RemoteOpenAIServer, list[str]]], - model_name: str) -> None: +async def test_external_lb_single_completion(clients: list[ + openai.AsyncOpenAI], servers: list[tuple[RemoteOpenAIServer, list[str]]], + model_name: str) -> None: async def make_request(client: openai.AsyncOpenAI): completion = await client.completions.create( @@ -208,12 +203,12 @@ async def make_request(client: openai.AsyncOpenAI): assert all(completion is not None for completion in results) _, server_args = servers[0] + api_server_count = ( + 
server_args.count('--api-server-count') + and server_args[server_args.index('--api-server-count') + 1] or 1) print( f"Successfully completed external LB test with {len(clients)} servers " - f"(API server count: {server_args.count('--api-server-count') - and server_args[server_args.index( - '--api-server-count') + 1] or '1'})" - ) + f"(API server count: {api_server_count})") @pytest.mark.asyncio @@ -303,9 +298,8 @@ async def make_streaming_request(client: openai.AsyncOpenAI): assert all(results), "Not all streaming requests completed successfully." _, server_args = servers[0] - print( - f"Successfully completed external LB streaming test with " - f"(API server count: {server_args.count('--api-server-count') - and server_args[server_args.index( - '--api-server-count') + 1] or '1'})" - ) + api_server_count = ( + server_args.count('--api-server-count') + and server_args[server_args.index('--api-server-count') + 1] or 1) + print(f"Successfully completed external LB streaming test with " + f"{len(clients)} servers (API server count: {api_server_count})") From 3cc7bd1a5c238ce690a68f31e112912391f5417f Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 1 Jul 2025 16:52:53 +0100 Subject: [PATCH 09/12] test fixes Signed-off-by: Nick Hill --- tests/v1/test_external_lb_dp.py | 12 +++++++++--- vllm/v1/engine/core.py | 15 +++++++++++---- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/tests/v1/test_external_lb_dp.py b/tests/v1/test_external_lb_dp.py index b7d944e92e27..07a0a9ad13b5 100644 --- a/tests/v1/test_external_lb_dp.py +++ b/tests/v1/test_external_lb_dp.py @@ -9,6 +9,7 @@ import openai # use the official client for correctness check import pytest import pytest_asyncio +from platforms import Platform from tests.utils import RemoteOpenAIServer @@ -61,14 +62,19 @@ def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]: ]) # Use a thread to start each server to allow parallel initialization - def start_server(r, sargs): + def start_server(r: int, sargs: list[str]): try: # Start the server server = RemoteOpenAIServer( self.model_name, sargs, auto_port=False, - env_dict={"CUDA_VISIBLE_DEVICES": str(r)}) + env_dict={ + "CUDA_VISIBLE_DEVICES": + ",".join( + str(Platform.device_id_to_physical_device_id( + i)) for i in range(DP_SIZE * TP_SIZE)) + }) server.__enter__() print(f"Server rank {r} started successfully with " f"{self.api_server_count} API servers") @@ -131,7 +137,7 @@ async def clients(servers: list[tuple[RemoteOpenAIServer, list[str]]]): # Create a client for each server async with AsyncExitStack() as stack: yield [ - await stack.enter_context(server.get_async_client()) + await stack.enter_async_context(server.get_async_client()) for server, _ in servers ] diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index d89e6ae7433c..e2fdf6f8a11c 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -868,10 +868,17 @@ def _init_data_parallel(self, vllm_config: VllmConfig): device_control_env_var = current_platform.device_control_env_var world_size = vllm_config.parallel_config.world_size # Set CUDA_VISIBLE_DEVICES or equivalent. 
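         # Each local DP rank is assigned its own contiguous block of
         # world_size device IDs: [local_dp_rank * world_size,
         # (local_dp_rank + 1) * world_size).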
- os.environ[device_control_env_var] = ",".join( - str(current_platform.device_id_to_physical_device_id(i)) - for i in range(local_dp_rank * world_size, (local_dp_rank + 1) * - world_size)) + try: + os.environ[device_control_env_var] = ",".join( + str(current_platform.device_id_to_physical_device_id(i)) + for i in range(local_dp_rank * + world_size, (local_dp_rank + 1) * world_size)) + except IndexError as e: + raise Exception( + f"Error setting {device_control_env_var}: " + f"local range: [{local_dp_rank * world_size}, " + f"{(local_dp_rank + 1) * world_size}) " + f"base value: \"{os.getenv(device_control_env_var)}\"") from e self.dp_rank = dp_rank self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() From 410ee4da171d7039bc99fef046fe07e4da4d7d3b Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 1 Jul 2025 17:59:45 +0100 Subject: [PATCH 10/12] fix test import Signed-off-by: Nick Hill --- .buildkite/test-pipeline.yaml | 2 +- tests/v1/test_external_lb_dp.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 7310a5ed1f3c..175269e857e0 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -164,7 +164,7 @@ steps: # test with tp=2 and pp=2 - PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py # test with internal dp - - python3 ../examples/offline_inference/data_parallel.py + - python3 ../examples/offline_inference/data_parallel.py --enforce-eager - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py - pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp diff --git a/tests/v1/test_external_lb_dp.py b/tests/v1/test_external_lb_dp.py index 07a0a9ad13b5..dcc7c5c4fce9 100644 --- a/tests/v1/test_external_lb_dp.py +++ b/tests/v1/test_external_lb_dp.py @@ -9,9 +9,9 @@ import openai # use the official client for correctness check import pytest import pytest_asyncio -from platforms import Platform from tests.utils import RemoteOpenAIServer +from vllm.platforms import Platform MODEL_NAME = "ibm-research/PowerMoE-3b" From d93b903ff5573e5f8928502875def4e60154a106 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 1 Jul 2025 19:23:54 +0100 Subject: [PATCH 11/12] fix test properly hopefully Signed-off-by: Nick Hill --- tests/v1/test_external_lb_dp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/v1/test_external_lb_dp.py b/tests/v1/test_external_lb_dp.py index dcc7c5c4fce9..17952dfb0d91 100644 --- a/tests/v1/test_external_lb_dp.py +++ b/tests/v1/test_external_lb_dp.py @@ -73,7 +73,8 @@ def start_server(r: int, sargs: list[str]): "CUDA_VISIBLE_DEVICES": ",".join( str(Platform.device_id_to_physical_device_id( - i)) for i in range(DP_SIZE * TP_SIZE)) + i)) + for i in range(r * TP_SIZE, (r + 1) * TP_SIZE)) }) server.__enter__() print(f"Server rank {r} started successfully with " From d0b5450403cddb447cde026a281f3bae511ddb1d Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 2 Jul 2025 09:04:17 +0100 Subject: [PATCH 12/12] validate mutually exclusive command line args Signed-off-by: Nick Hill --- vllm/entrypoints/cli/serve.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 8ff1792edb8f..2ca31510208d 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -44,11 +44,15 @@ def cmd(args: argparse.Namespace) -> 
None:
         if args.headless or args.api_server_count < 1:
             run_headless(args)
-        elif args.api_server_count > 1:
-            run_multi_api_server(args)
         else:
-            # Single API server (this process).
-            uvloop.run(run_server(args))
+            if args.data_parallel_start_rank:
+                raise ValueError("data_parallel_start_rank is only "
+                                 "applicable in headless mode")
+            if args.api_server_count > 1:
+                run_multi_api_server(args)
+            else:
+                # Single API server (this process).
+                uvloop.run(run_server(args))
 
     def validate(self, args: argparse.Namespace) -> None:
         validate_parsed_serve_args(args)
@@ -117,14 +121,19 @@ def run_headless(args: argparse.Namespace):
 
     parallel_config = vllm_config.parallel_config
     local_engine_count = parallel_config.data_parallel_size_local
-    host = parallel_config.data_parallel_master_ip
-    port = engine_args.data_parallel_rpc_port  # add to config too
-    handshake_address = get_tcp_uri(host, port)
 
     if local_engine_count <= 0:
         raise ValueError("data_parallel_size_local must be > 0 in "
                          "headless mode")
 
+    if engine_args.data_parallel_rank is not None:
+        raise ValueError("data_parallel_rank is not applicable in "
+                         "headless mode")
+
+    host = parallel_config.data_parallel_master_ip
+    port = engine_args.data_parallel_rpc_port  # add to config too
+    handshake_address = get_tcp_uri(host, port)
+
     # Catch SIGTERM and SIGINT to allow graceful shutdown.
     def signal_handler(signum, frame):
         logger.debug("Received %d signal.", signum)
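
For reference, each rank that the new test launches is equivalent to a
standalone invocation along these lines (model, ports, and GPU assignment
mirror the test's defaults; the external load balancer itself is assumed to
be provided separately and to spread requests across the ranks' ports):

    # Rank 0 on GPU 0
    CUDA_VISIBLE_DEVICES=0 vllm serve ibm-research/PowerMoE-3b \
        --data-parallel-size 2 --data-parallel-rank 0 \
        --data-parallel-size-local 1 --port 8000

    # Rank 1 on GPU 1
    CUDA_VISIBLE_DEVICES=1 vllm serve ibm-research/PowerMoE-3b \
        --data-parallel-size 2 --data-parallel-rank 1 \
        --data-parallel-size-local 1 --port 8001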