
Commit b316ac6

luccafong and njhill authored
[V1] Support MP Executor for multi node distributed inference (#23691)
Signed-off-by: Lu Fang <[email protected]>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Signed-off-by: Lucia Fang <[email protected]>
Signed-off-by: Lucia Fang <[email protected]>
Signed-off-by: Nick Hill <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
1 parent a55b646 commit b316ac6

File tree: 10 files changed, +930 -82 lines


tests/distributed/test_multiproc_executor.py

Lines changed: 437 additions & 0 deletions
Large diffs are not rendered by default.

vllm/config/parallel.py

Lines changed: 40 additions & 0 deletions
@@ -210,6 +210,18 @@ class ParallelConfig:
     class is dynamically inherited by the worker class. This is used to inject
     new attributes and methods to the worker class for use in collective_rpc
     calls."""
+    master_addr: str = "127.0.0.1"
+    """distributed master address for multi-node distributed
+    inference when distributed_executor_backend is mp."""
+    master_port: int = 29501
+    """distributed master port for multi-node distributed
+    inference when distributed_executor_backend is mp."""
+    node_rank: int = 0
+    """distributed node rank for multi-node distributed
+    inference when distributed_executor_backend is mp."""
+    nnodes: int = 1
+    """num of nodes for multi-node distributed
+    inference when distributed_executor_backend is mp."""
 
     world_size: int = Field(init=False)
     """world_size is TPxPP, it affects the number of workers we create."""
@@ -387,6 +399,23 @@ def use_sequence_parallel_moe(self) -> bool:
             and self.data_parallel_size > 1
         )
 
+    @property
+    def node_rank_within_dp(self) -> int:
+        return self.node_rank % self.nnodes_within_dp
+
+    @property
+    def nnodes_within_dp(self) -> int:
+        if self.nnodes == 1:
+            return 1
+        data_parallel_node_size = (
+            self.data_parallel_size // self.data_parallel_size_local
+        )
+        return self.nnodes // data_parallel_node_size
+
+    @property
+    def local_world_size(self) -> int:
+        return self.world_size // self.nnodes_within_dp
+
     @staticmethod
     def has_unfinished_dp(dp_group: ProcessGroup, has_unfinished: bool) -> bool:
         tensor = torch.tensor([has_unfinished], dtype=torch.int32, device="cpu")
@@ -528,6 +557,8 @@ def __post_init__(self) -> None:
             ray_found = ray_utils.ray_is_available()
             if current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD:
                 backend = "uni"
+            elif current_platform.is_cuda() and self.nnodes > 1:
+                backend = "mp"
             elif (
                 current_platform.is_cuda()
                 and cuda_device_count_stateless() < self.world_size
@@ -565,6 +596,10 @@ def __post_init__(self) -> None:
                 "max_parallel_loading_workers is currently "
                 "not supported and will be ignored."
             )
+        if self.distributed_executor_backend != "mp" and self.nnodes > 1:
+            raise ValueError(
+                "nnodes > 1 can only be set when distributed executor backend is mp."
+            )
 
     @property
     def use_ray(self) -> bool:
@@ -607,6 +642,11 @@ def _verify_args(self) -> Self:
                 "Disabled the custom all-reduce kernel because it is not "
                 "supported on current platform."
             )
+        if self.nnodes > 1:
+            self.disable_custom_all_reduce = True
+            logger.debug(
+                "Disabled the custom all-reduce since we are running on multi-node."
+            )
         if self.ray_workers_use_nsight and not self.use_ray:
             raise ValueError(
                 "Unable to use nsight profiling unless workers run with Ray."

vllm/distributed/device_communicators/shm_broadcast.py

Lines changed: 94 additions & 16 deletions
@@ -8,7 +8,7 @@
 from multiprocessing import shared_memory
 from pickle import PickleBuffer
 from threading import Event
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, cast
 from unittest.mock import patch
 
 import torch
@@ -602,13 +602,87 @@ def broadcast_object(self, obj=None):
             return obj
         return self.dequeue()
 
+    @staticmethod
+    def create_from_process_group_single_reader(
+        pg: ProcessGroup,
+        max_chunk_bytes,
+        max_chunks,
+        reader_rank: int = 0,
+        blocking: bool = False,
+    ) -> tuple["MessageQueue", list[Handle]]:
+        """
+        Creates a MessageQueue for a process group with a single reader.
+
+        This method is designed for scenarios where only one process (the reader)
+        will consume messages, and all other processes are writers. It sets up
+        the shared memory buffer and communication handles accordingly, and
+        gathers the handles from all processes to the reader.
+
+        Args:
+            pg (ProcessGroup): The torch distributed process group.
+            max_chunk_bytes (int): Maximum size in bytes for each chunk in the buffer.
+            max_chunks (int): Maximum number of chunks in the buffer.
+            reader_rank (int, optional): The global rank that will act as the reader.
+                Defaults to 0.
+            blocking (bool, optional): If True, blocks until all processes are ready.
+                Defaults to False.
+
+        Returns:
+            tuple[MessageQueue, list[Handle]]:
+                The MessageQueue instance for the calling process,
+                and a list of handles (only non-empty for the reader process).
+        """
+        local_size = torch.cuda.device_count()
+        rank = dist.get_rank()
+        same_node = rank // local_size == reader_rank // local_size
+        buffer_io = MessageQueue(
+            n_reader=1,
+            n_local_reader=1 if same_node else 0,
+            max_chunk_bytes=max_chunk_bytes,
+            max_chunks=max_chunks,
+        )
+        handle = buffer_io.export_handle()
+        handles = [None] * dist.get_world_size(pg) if rank == reader_rank else None
+        dist.gather_object(handle, handles, dst=reader_rank, group=pg)
+        if blocking:
+            buffer_io.wait_until_ready()
+        return buffer_io, cast(list[Handle], handles or [])
+
     @staticmethod
     def create_from_process_group(
         pg: ProcessGroup | StatelessProcessGroup,
         max_chunk_bytes,
         max_chunks,
-        writer_rank=0,
+        writer_rank: int = 0,
+        external_writer_handle=None,
+        blocking: bool = True,
     ) -> "MessageQueue":
+        """
+        Creates a MessageQueue for a distributed process group with one writer and
+        multiple readers.
+
+        This method is designed for scenarios where one process (the writer) sends
+        messages, and all other processes (the readers) receive messages. It sets up
+        the shared memory buffer and socket communication handles accordingly, and
+        broadcasts the handle from the writer to all readers.
+
+        Args:
+            pg (ProcessGroup | StatelessProcessGroup): The torch distributed process
+                group.
+            max_chunk_bytes (int): Maximum size in bytes for each chunk in the buffer.
+            max_chunks (int): Maximum number of chunks in the buffer.
+            writer_rank (int, optional): The global rank that will act as the writer.
+                Defaults to 0.
+            external_writer_handle (Handle, optional): Used when there is a handle
+                from an external Message Queue. If provided, use this handle to init
+                PG writer message queue instead of creating a new one. Defaults to None.
+            blocking (bool, optional): If True, blocks until all processes are ready.
+                Defaults to True.
+
+        Returns:
+            MessageQueue: The MessageQueue instance for the calling process.
+
+        """
         if isinstance(pg, ProcessGroup):
             group_rank = dist.get_rank(pg)
             group_world_size = dist.get_world_size(pg)
@@ -617,23 +691,26 @@ def create_from_process_group(
             group_rank = pg.rank
             group_world_size = pg.world_size
             global_ranks = list(range(pg.world_size))
-
         from vllm.distributed.parallel_state import in_the_same_node_as
 
         status = in_the_same_node_as(pg, source_rank=writer_rank)
-        same_node_ranks = [i for i, s in enumerate(status) if s]
-        n_reader = group_world_size - 1
-        n_local_reader = len(same_node_ranks) - 1
-        local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
-        buffer_io: MessageQueue
         if group_rank == writer_rank:
-            buffer_io = MessageQueue(
-                n_reader=n_reader,
-                n_local_reader=n_local_reader,
-                local_reader_ranks=local_reader_ranks,
-                max_chunk_bytes=max_chunk_bytes,
-                max_chunks=max_chunks,
-            )
+            if external_writer_handle is not None:
+                buffer_io = MessageQueue.create_from_handle(
+                    external_writer_handle, group_rank
+                )
+            else:
+                same_node_ranks = [i for i, s in enumerate(status) if s]
+                n_reader = group_world_size - 1
+                n_local_reader = len(same_node_ranks) - 1
+                local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
+                buffer_io = MessageQueue(
+                    n_reader=n_reader,
+                    n_local_reader=n_local_reader,
+                    local_reader_ranks=local_reader_ranks,
+                    max_chunk_bytes=max_chunk_bytes,
+                    max_chunks=max_chunks,
+                )
             handle = buffer_io.export_handle()
             if isinstance(pg, ProcessGroup):
                 dist.broadcast_object_list(
@@ -651,5 +728,6 @@ def create_from_process_group(
         else:
             handle = pg.broadcast_obj(None, writer_rank)
             buffer_io = MessageQueue.create_from_handle(handle, group_rank)
-        buffer_io.wait_until_ready()
+        if blocking:
+            buffer_io.wait_until_ready()
         return buffer_io
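
For context, here is a hedged usage sketch of the new single-reader helper. The import path, method name, argument order, and return value come from the diff above; the wrapper name `build_single_reader_queue`, the gloo backend, and launching via torchrun with one process per GPU are assumptions for illustration, not a vLLM code path.

# Usage sketch of MessageQueue.create_from_process_group_single_reader
# (assumed launch: torchrun with one process per GPU).
import torch.distributed as dist
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue


def build_single_reader_queue(reader_rank: int = 0):
    if not dist.is_initialized():
        dist.init_process_group(backend="gloo")  # CPU group for handle exchange
    pg = dist.group.WORLD
    mq, handles = MessageQueue.create_from_process_group_single_reader(
        pg,
        1 << 22,  # max_chunk_bytes, mirroring the new GroupCoordinator helpers
        6,        # max_chunks
        reader_rank=reader_rank,
        blocking=False,
    )
    # `handles` is populated only on the reader rank; writer ranks get [].
    return mq, handles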

vllm/distributed/parallel_state.py

Lines changed: 73 additions & 4 deletions
@@ -385,6 +385,33 @@ def __init__(
             torch.ops._C, "init_shm_manager"
         )
 
+    def create_mq_broadcaster(
+        self, writer_rank=0, external_writer_handle=None, blocking=True
+    ):
+        from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
+
+        return MessageQueue.create_from_process_group(
+            self.cpu_group,
+            1 << 22,
+            6,
+            writer_rank=writer_rank,
+            external_writer_handle=external_writer_handle,
+            blocking=blocking,
+        )
+
+    def create_single_reader_mq_broadcasters(
+        self, reader_rank_in_group=0, blocking=False
+    ):
+        from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
+
+        return MessageQueue.create_from_process_group_single_reader(
+            self.cpu_group,
+            1 << 22,
+            6,
+            reader_rank=self.ranks[reader_rank_in_group],
+            blocking=blocking,
+        )
+
     @property
     def first_rank(self):
         """Return the global rank of the first process in the group"""
@@ -997,6 +1024,7 @@ def combine(
 
 
 _WORLD: GroupCoordinator | None = None
+_INNER_DP_WORLD: GroupCoordinator | None = None
 _NODE_COUNT: int | None = None
 
 
@@ -1005,6 +1033,11 @@ def get_world_group() -> GroupCoordinator:
     return _WORLD
 
 
+def get_inner_dp_world_group() -> GroupCoordinator:
+    assert _INNER_DP_WORLD is not None, "inner dp world group is not initialized"
+    return _INNER_DP_WORLD
+
+
 def init_world_group(
     ranks: list[int], local_rank: int, backend: str
 ) -> GroupCoordinator:
@@ -1023,12 +1056,13 @@ def init_model_parallel_group(
     backend: str,
     use_message_queue_broadcaster: bool = False,
     group_name: str | None = None,
+    use_device_communicator: bool = True,
 ) -> GroupCoordinator:
     return GroupCoordinator(
         group_ranks=group_ranks,
         local_rank=local_rank,
         torch_distributed_backend=backend,
-        use_device_communicator=True,
+        use_device_communicator=use_device_communicator,
         use_message_queue_broadcaster=use_message_queue_broadcaster,
         group_name=group_name,
     )
@@ -1143,7 +1177,14 @@ def init_distributed_environment(
     from vllm.config import get_current_vllm_config
 
     config = get_current_vllm_config()
-    if (
+    if config is not None and config.parallel_config.nnodes > 1:
+        parallel_config = config.parallel_config
+        ip = parallel_config.master_addr
+        rank = parallel_config.data_parallel_rank * world_size + rank
+        world_size = parallel_config.world_size_across_dp
+        port = parallel_config.master_port
+        distributed_init_method = get_distributed_init_method(ip, port)
+    elif (
         config is not None
         and config.parallel_config.data_parallel_size > 1
         and config.parallel_config.distributed_executor_backend != "external_launcher"
@@ -1164,6 +1205,14 @@ def init_distributed_environment(
         distributed_init_method,
     )
     if not torch.distributed.is_initialized():
+        logger.info(
+            "world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s",
+            world_size,
+            rank,
+            local_rank,
+            distributed_init_method,
+            backend,
+        )
         assert distributed_init_method is not None, (
            "distributed_init_method must be provided when initializing "
            "distributed environment"
@@ -1192,16 +1241,36 @@ def init_distributed_environment(
         # local rank not set, this usually happens in single-node
         # setting, where we can use rank as local rank
         local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank
-    global _WORLD, _NODE_COUNT
+    global _WORLD, _NODE_COUNT, _INNER_DP_WORLD
     if _WORLD is None:
         ranks = list(range(torch.distributed.get_world_size()))
         _WORLD = init_world_group(ranks, local_rank, backend)
-        _NODE_COUNT = _node_count(_WORLD.cpu_group)
+        if config.parallel_config.nnodes > 1:
+            _NODE_COUNT = config.parallel_config.nnodes
+        else:
+            _NODE_COUNT = _node_count(_WORLD.cpu_group)
         logger.debug("Detected %d nodes in the distributed environment", _NODE_COUNT)
     else:
         assert _WORLD.world_size == torch.distributed.get_world_size(), (
             "world group already initialized with a different world size"
         )
+    if config.parallel_config.nnodes_within_dp > 1:
+        if parallel_config.data_parallel_size > 1:
+            world_size_inner_dp = parallel_config.world_size
+            group_ranks = [
+                [dp_rank * world_size_inner_dp + i for i in range(world_size_inner_dp)]
+                for dp_rank in range(parallel_config.data_parallel_size)
+            ]
+            _INNER_DP_WORLD = init_model_parallel_group(
+                group_ranks,
+                get_world_group().local_rank,
+                backend,
+                use_message_queue_broadcaster=True,
+                group_name="inner_dp_world",
+                use_device_communicator=False,
+            )
+        else:
+            _INNER_DP_WORLD = _WORLD
 
 
 def initialize_model_parallel(
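
The inner-DP world group above is a partition of global ranks by DP replica. A small sketch of the group_ranks layout it builds, using assumed sizes and mirroring the list comprehension in the hunk above:

# Sketch of the rank partition used for _INNER_DP_WORLD (assumed sizes).
data_parallel_size = 2
world_size_inner_dp = 8   # TP * PP ranks per DP replica

group_ranks = [
    [dp_rank * world_size_inner_dp + i for i in range(world_size_inner_dp)]
    for dp_rank in range(data_parallel_size)
]
print(group_ranks)
# [[0, 1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14, 15]]
# Each sub-list spans all nodes of one DP replica; per the diff, the group is
# created with use_message_queue_broadcaster=True and use_device_communicator=False,
# so it carries message-queue broadcasts rather than device collectives.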
