
Commit f7db9ea

youkaichao and andy-neuma authored and committed
[Core] separate distributed_init from worker (vllm-project#3904)
1 parent 88385bb commit f7db9ea

4 files changed, +85 −58 lines

vllm/model_executor/parallel_utils/parallel_state.py
60 additions & 3 deletions
@@ -4,6 +4,7 @@
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 """Tensor and pipeline parallel groups."""
 import contextlib
+from typing import Optional
 
 import torch
 
@@ -14,14 +15,59 @@
 # Pipeline model parallel group that the current rank belongs to.
 _PIPELINE_MODEL_PARALLEL_GROUP = None
 
+# when people blindly call `torch.distributed.all_reduce` etc,
+# it will use this group. It is initialized with the `backend`
+# parameter of `init_distributed_environment` below.
+# Essentially, this is `torch.distributed.group.WORLD`.
+# We leave a line here to note that this is device-specific.
+# Note that this variable is not safe to use, because when users
+# call `init_distributed_environment` first, and then destroy
+# the process group themselves, this variable will keep a reference to the
+# destroyed process group, which is not useful.
+_DEVICE_WORLD_GROUP = None
+
+# during `init_distributed_environment`, we will also initialize a
+# group with `gloo` backend, to allow direct coordination between
+# processes through the CPU.
+_CPU_WORLD_GROUP = None
+
+# In summary, after calling `init_distributed_environment`, we will
+# always have two groups: one for device-specific (and is the default)
+# and one for CPU. All processes will be part of both groups.
+
 # A list of global ranks for each pipeline group to ease calculation of the
 # source rank when broadcasting from the first or last pipeline stage.
 _PIPELINE_GLOBAL_RANKS = None
 
 
+def init_distributed_environment(
+    world_size: int,
+    rank: int,
+    distributed_init_method: Optional[str] = None,
+    local_rank: int = -1,
+    backend: str = "nccl",
+):
+    if not torch.distributed.is_initialized():
+        assert distributed_init_method is not None, (
+            "distributed_init_method must be provided when initializing "
+            "distributed environment")
+        # this backend is used for WORLD
+        torch.distributed.init_process_group(
+            backend=backend,
+            init_method=distributed_init_method,
+            world_size=world_size,
+            rank=rank)
+    global _DEVICE_WORLD_GROUP, _CPU_WORLD_GROUP
+    _DEVICE_WORLD_GROUP = torch.distributed.group.WORLD
+    ranks = list(range(torch.distributed.get_world_size()))
+    _CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks,
+                                                   backend="gloo")
+
+
 def initialize_model_parallel(
     tensor_model_parallel_size: int = 1,
     pipeline_model_parallel_size: int = 1,
+    backend: Optional[str] = None,
 ) -> None:
     """
     Initialize model parallel groups.
@@ -48,6 +94,8 @@ def initialize_model_parallel(
     # Get world size and rank. Ensure some consistencies.
     assert torch.distributed.is_initialized()
     world_size: int = torch.distributed.get_world_size()
+    # get the backend of _DEVICE_WORLD_GROUP
+    backend = backend or torch.distributed.get_backend()
 
     if (world_size !=
             tensor_model_parallel_size * pipeline_model_parallel_size):
@@ -69,7 +117,7 @@
     for i in range(num_tensor_model_parallel_groups):
         ranks = range(i * tensor_model_parallel_size,
                       (i + 1) * tensor_model_parallel_size)
-        group = torch.distributed.new_group(ranks)
+        group = torch.distributed.new_group(ranks, backend=backend)
         if rank in ranks:
             _TENSOR_MODEL_PARALLEL_GROUP = group
 
@@ -80,7 +128,7 @@
         "pipeline model parallel group is already initialized")
     for i in range(num_pipeline_model_parallel_groups):
         ranks = range(i, world_size, num_pipeline_model_parallel_groups)
-        group = torch.distributed.new_group(ranks)
+        group = torch.distributed.new_group(ranks, backend=backend)
         if rank in ranks:
             _PIPELINE_MODEL_PARALLEL_GROUP = group
             _PIPELINE_GLOBAL_RANKS = ranks
@@ -89,14 +137,17 @@
 def ensure_model_parallel_initialized(
     tensor_model_parallel_size: int,
     pipeline_model_parallel_size: int,
+    backend: Optional[str] = None,
 ) -> None:
     """Helper to initialize model parallel groups if they are not initialized,
     or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
     values if the model parallel groups are initialized.
     """
+    # get the backend of _DEVICE_WORLD_GROUP
+    backend = backend or torch.distributed.get_backend()
     if not model_parallel_is_initialized():
         initialize_model_parallel(tensor_model_parallel_size,
-                                  pipeline_model_parallel_size)
+                                  pipeline_model_parallel_size, backend)
         return
 
     assert (
@@ -117,6 +168,12 @@ def model_parallel_is_initialized():
             and _PIPELINE_MODEL_PARALLEL_GROUP is not None)
 
 
+def get_cpu_world_group():
+    """Get the CPU world group."""
+    assert _CPU_WORLD_GROUP is not None, ("CPU world group is not initialized")
+    return _CPU_WORLD_GROUP
+
+
 def get_tensor_model_parallel_group():
     """Get the tensor model parallel group the caller rank belongs to."""
     assert _TENSOR_MODEL_PARALLEL_GROUP is not None, (
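For orientation, here is a minimal single-process sketch of how the entry points added above fit together. It assumes a CPU-only run (hence `backend="gloo"` instead of the default `"nccl"`) and an arbitrary free local port; neither assumption comes from this commit.

```python
import torch

from vllm.model_executor.parallel_utils.parallel_state import (
    ensure_model_parallel_initialized, get_cpu_world_group,
    init_distributed_environment)

# Set up the default (device) world group plus the extra gloo-backed CPU
# group. world_size=1 / rank=0 keeps this runnable as a single process.
init_distributed_environment(
    world_size=1,
    rank=0,
    distributed_init_method="tcp://localhost:29500",  # arbitrary free port
    backend="gloo",  # assumption: no GPU available; the default is "nccl"
)

# Tensor/pipeline groups inherit the backend of the default group unless
# an explicit `backend` argument is passed.
ensure_model_parallel_initialized(tensor_model_parallel_size=1,
                                  pipeline_model_parallel_size=1)

# The CPU world group allows coordination over gloo regardless of the
# device backend; with world_size == 1 this all_reduce is a no-op.
t = torch.ones(1)
torch.distributed.all_reduce(t, group=get_cpu_world_group())
```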

vllm/test_utils.py
6 additions & 7 deletions
@@ -1,8 +1,8 @@
 import ray
 
-from vllm.config import ParallelConfig
+from vllm.model_executor.parallel_utils.parallel_state import (
+    ensure_model_parallel_initialized, init_distributed_environment)
 from vllm.utils import get_open_port
-from vllm.worker.worker import init_distributed_environment
 
 
 def init_test_distributed_environment(
@@ -12,15 +12,14 @@ def init_test_distributed_environment(
     distributed_init_port: str,
     local_rank: int = -1,
 ) -> None:
-    parallel_config = ParallelConfig(pipeline_parallel_size,
-                                     tensor_parallel_size,
-                                     worker_use_ray=True)
     distributed_init_method = f"tcp://localhost:{distributed_init_port}"
     init_distributed_environment(
-        parallel_config,
-        rank,
+        world_size=pipeline_parallel_size * tensor_parallel_size,
+        rank=rank,
         distributed_init_method=distributed_init_method,
         local_rank=local_rank)
+    ensure_model_parallel_initialized(tensor_parallel_size,
+                                      pipeline_parallel_size)
 
 
 def multi_process_tensor_parallel(
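A hedged sketch of how the rewired test helper might be driven from a Ray-based test. The `@ray.remote` wrapper, the 1x2 tensor-parallel layout, and the two-GPU requirement are illustrative assumptions, not part of this commit; only `init_test_distributed_environment`, `get_open_port`, and their keyword parameters come from the code above.

```python
import ray

from vllm.test_utils import init_test_distributed_environment
from vllm.utils import get_open_port


@ray.remote(num_gpus=1)  # assumption: one GPU per test process
def _test_worker(rank: int, port: str) -> None:
    # Every rank joins the same tcp://localhost:{port} rendezvous, then the
    # helper sets up tensor/pipeline groups via
    # ensure_model_parallel_initialized.
    init_test_distributed_environment(
        pipeline_parallel_size=1,
        tensor_parallel_size=2,
        rank=rank,
        distributed_init_port=port,
    )


if __name__ == "__main__":
    ray.init()
    port = str(get_open_port())
    ray.get([_test_worker.remote(rank, port) for rank in range(2)])
```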

vllm/worker/cpu_worker.py
7 additions & 21 deletions
@@ -13,7 +13,7 @@
 from vllm.model_executor.parallel_utils.communication_op import (
     broadcast_tensor_dict)
 from vllm.model_executor.parallel_utils.parallel_state import (
-    ensure_model_parallel_initialized)
+    ensure_model_parallel_initialized, init_distributed_environment)
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.worker.model_runner import ModelRunner
@@ -251,26 +251,12 @@ def init_distributed_environment(self) -> None:
         parallel_config = self.parallel_config
         rank = self.rank
         distributed_init_method = self.distributed_init_method
-
-        if torch.distributed.is_initialized():
-            torch_world_size = torch.distributed.get_world_size()
-            if torch_world_size != parallel_config.world_size:
-                raise RuntimeError(
-                    "torch.distributed is already initialized but the torch "
-                    "world size does not match parallel_config.world_size "
-                    f"({torch_world_size} vs. {parallel_config.world_size}).")
-        elif not distributed_init_method:
-            raise ValueError(
-                "distributed_init_method must be set if torch.distributed "
-                "is not already initialized")
-        else:
-            backend = "gloo"
-            torch.distributed.init_process_group(
-                backend=backend,
-                world_size=parallel_config.world_size,
-                rank=rank,
-                init_method=distributed_init_method,
-            )
+        init_distributed_environment(
+            world_size=parallel_config.world_size,
+            rank=rank,
+            distributed_init_method=distributed_init_method,
+            backend="gloo",
+        )
 
         # A small all_reduce for warmup.
         torch.distributed.all_reduce(torch.zeros(1).cpu())

vllm/worker/worker.py
12 additions & 27 deletions
@@ -15,7 +15,7 @@
     broadcast_tensor_dict)
 from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar
 from vllm.model_executor.parallel_utils.parallel_state import (
-    ensure_model_parallel_initialized)
+    ensure_model_parallel_initialized, init_distributed_environment)
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.model_runner import ModelRunner
@@ -97,9 +97,9 @@ def init_device(self) -> None:
             raise RuntimeError(
                 f"Not support device type: {self.device_config.device}")
         # Initialize the distributed environment.
-        init_distributed_environment(self.parallel_config, self.rank,
-                                     self.distributed_init_method,
-                                     self.local_rank)
+        init_worker_distributed_environment(self.parallel_config, self.rank,
+                                            self.distributed_init_method,
+                                            self.local_rank)
         # Set random seed.
         set_random_seed(self.model_config.seed)
 
@@ -248,31 +248,15 @@ def get_cache_block_size_bytes(self, block_size: int,
                                                 self.parallel_config)
 
 
-def init_distributed_environment(
+def init_worker_distributed_environment(
     parallel_config: ParallelConfig,
     rank: int,
     distributed_init_method: Optional[str] = None,
     local_rank: int = -1,
 ) -> None:
     """Initialize the distributed environment."""
-    if torch.distributed.is_initialized():
-        torch_world_size = torch.distributed.get_world_size()
-        if torch_world_size != parallel_config.world_size:
-            raise RuntimeError(
-                "torch.distributed is already initialized but the torch world "
-                "size does not match parallel_config.world_size "
-                f"({torch_world_size} vs. {parallel_config.world_size}).")
-    elif not distributed_init_method:
-        raise ValueError(
-            "distributed_init_method must be set if torch.distributed "
-            "is not already initialized")
-    else:
-        torch.distributed.init_process_group(
-            backend="nccl",
-            world_size=parallel_config.world_size,
-            rank=rank,
-            init_method=distributed_init_method,
-        )
+    init_distributed_environment(parallel_config.world_size, rank,
+                                 distributed_init_method, local_rank)
 
     if pynccl_utils.is_initialized():
         pynccl_world_size = pynccl_utils.get_world_size()
@@ -291,17 +275,18 @@ def init_distributed_environment(
             init_method=distributed_init_method,
         )
 
-    # A small all_reduce for warmup.
-    torch.distributed.all_reduce(torch.zeros(1).cuda())
-    if pynccl_utils.is_initialized():
-        pynccl_utils.all_reduce(torch.zeros(1).cuda())
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)
 
     # Initialize a custom fast all-reduce implementation.
     if not parallel_config.disable_custom_all_reduce:
         init_custom_ar()
 
+    # A small all_reduce for warmup.
+    torch.distributed.all_reduce(torch.zeros(1).cuda())
+    if pynccl_utils.is_initialized():
+        pynccl_utils.all_reduce(torch.zeros(1).cuda())
+
 
 def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
     # Check if the GPU supports the dtype.
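Putting the two `worker.py` hunks together, the renamed helper now reads roughly as follows. This is a condensed paraphrase for orientation only, not a verbatim copy: the pynccl consistency checks and pynccl process-group setup in the middle are elided with `...`, and the names used are the module's existing imports.

```python
def init_worker_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    local_rank: int = -1,
) -> None:
    """Initialize the distributed environment."""
    # 1. Shared torch.distributed setup: the default WORLD group plus the
    #    gloo-backed CPU group (moved out of the worker by this commit).
    init_distributed_environment(parallel_config.world_size, rank,
                                 distributed_init_method, local_rank)

    # 2. pynccl consistency checks / initialization (unchanged, elided).
    ...

    # 3. Tensor- and pipeline-parallel groups.
    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                      parallel_config.pipeline_parallel_size)

    # 4. Optional custom fast all-reduce.
    if not parallel_config.disable_custom_all_reduce:
        init_custom_ar()

    # 5. The warmup all_reduce now runs last, once every group exists.
    torch.distributed.all_reduce(torch.zeros(1).cuda())
    if pynccl_utils.is_initialized():
        pynccl_utils.all_reduce(torch.zeros(1).cuda())
```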
