Commit 936aba0

EEP phase: stateless group + CUDA graph support

Support request serving during scaling up/down.

Signed-off-by: Yongji Wu <[email protected]>

1 parent 577d498

26 files changed: +2451 -568 lines

tools/ep_kernels/elastic_ep/install_eep_libraries.sh
Lines changed: 8 additions & 1 deletion

@@ -52,6 +52,12 @@ if [ -z "$CUDA_HOME" ]; then
     exit 1
 fi
 
+# assume TORCH_CUDA_ARCH_LIST is set correctly
+if [ -z "$TORCH_CUDA_ARCH_LIST" ]; then
+    echo "TORCH_CUDA_ARCH_LIST is not set, please set it to your desired architecture."
+    exit 1
+fi
+
 # disable all features except IBGDA
 export NVSHMEM_IBGDA_SUPPORT=1
@@ -82,5 +88,6 @@ git clone https://github.com/ppl-ai/pplx-kernels
 cd pplx-kernels
 # see https://github.com/pypa/pip/issues/9955#issuecomment-838065925
 # PIP_NO_BUILD_ISOLATION=0 disables build isolation
-PIP_NO_BUILD_ISOLATION=0 TORCH_CUDA_ARCH_LIST=9.0a+PTX pip install . --no-deps -v
+git checkout 12cecfd
+PIP_NO_BUILD_ISOLATION=0 pip install . --no-deps -v
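With the hardcoded architecture removed, the script now aborts unless TORCH_CUDA_ARCH_LIST is exported by the caller. One possible invocation, reusing the Hopper value the deleted line used to hardcode:

    TORCH_CUDA_ARCH_LIST=9.0a+PTX bash tools/ep_kernels/elastic_ep/install_eep_libraries.sh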

vllm/config/parallel.py
Lines changed: 59 additions & 1 deletion

@@ -157,6 +157,9 @@ class ParallelConfig:
     disable_custom_all_reduce: bool = False
     """Disable the custom all-reduce kernel and fall back to NCCL."""
 
+    enable_elastic_ep: bool = False
+    """Enable elastic expert parallelism with stateless NCCL groups for DP/EP."""
+
     enable_dbo: bool = False
     """Enable dual batch overlap for the model executor."""
@@ -218,6 +221,21 @@ class is dynamically inherited by the worker class. This is used to inject
     Set to be private as it's not intended to be configured by users.
     """
 
+    _stateless_world_group_port_list: list[int] = Field(default_factory=list)
+    """List of open ports for stateless world group when enable_elastic_ep is True.
+    Set to be private as it's not intended to be configured by users.
+    """
+
+    _stateless_dp_group_port_list: list[int] = Field(default_factory=list)
+    """List of open ports for stateless DP groups when enable_elastic_ep is True.
+    Set to be private as it's not intended to be configured by users.
+    """
+
+    _stateless_ep_group_port_list: list[int] = Field(default_factory=list)
+    """List of open ports for stateless EP groups when enable_elastic_ep is True.
+    Set to be private as it's not intended to be configured by users.
+    """
+
     decode_context_parallel_size: int = 1
     """Number of decode context parallel groups, because the world size does
     not change by dcp, it simply reuse the GPUs of TP group, and tp_size
@@ -310,7 +328,16 @@ def get_next_dp_init_port(self) -> int:
 
         return answer
 
-    def stateless_init_dp_group(self) -> ProcessGroup:
+    def get_next_stateless_world_group_port(self) -> list[int]:
+        return self._stateless_world_group_port_list.pop(0)
+
+    def get_next_stateless_dp_group_port(self) -> list[int]:
+        return self._stateless_dp_group_port_list.pop(0)
+
+    def get_next_stateless_ep_group_port(self) -> list[int]:
+        return self._stateless_ep_group_port_list.pop(0)
+
+    def stateless_init_dp_group(self, return_store: bool = False) -> ProcessGroup:
         # NOTE: In high-concurrency scenarios multiple processes
         # can pick the same (currently free) port through a race
         # condition when calling `get_open_port()`. When the first
@@ -335,6 +362,7 @@ def stateless_init_dp_group(self) -> ProcessGroup:
                 self.data_parallel_rank,
                 self.data_parallel_size,
                 backend="gloo",
+                return_store=return_store,
             )
         except DistNetworkError as e:
             # We only want to retry when the root cause is EADDRINUSE.
@@ -470,6 +498,36 @@ def __post_init__(self) -> None:
             logger.info("Using external launcher for distributed inference.")
             self.world_size *= self.data_parallel_size
 
+        # Initialize stateless group ports for elastic EP
+        if self.enable_elastic_ep:
+            num_world_groups = 1
+            num_dp_groups = max(1, self.world_size_across_dp // self.data_parallel_size)
+            num_ep_groups = max(
+                1,
+                self.world_size_across_dp
+                // (self.data_parallel_size * self.tensor_parallel_size),
+            )
+
+            total_ports_needed = (num_world_groups + num_dp_groups + num_ep_groups) * 3
+
+            if not self._stateless_world_group_port_list:
+                all_ports = get_open_ports_list(total_ports_needed + 5)
+                self._data_parallel_master_port_list = all_ports[-5:]
+                all_ports = all_ports[:-5]
+                self._stateless_world_group_port_list = [
+                    all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3)
+                ]
+                start_idx = num_world_groups * 3
+                self._stateless_dp_group_port_list = [
+                    all_ports[i : i + 3]
+                    for i in range(start_idx, start_idx + num_dp_groups * 3, 3)
+                ]
+                start_idx += num_dp_groups * 3
+                self._stateless_ep_group_port_list = [
+                    all_ports[i : i + 3]
+                    for i in range(start_idx, start_idx + num_ep_groups * 3, 3)
+                ]
+
         if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
             # Data parallel was specified in the engine args.
             if self.distributed_executor_backend == "external_launcher":
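The port bookkeeping in the __post_init__ hunk above is easier to follow in isolation. Below is a minimal standalone sketch (not vLLM code; the function name and port values are made up) of how the flat list returned by get_open_ports_list is split into one three-port group per stateless world/DP/EP group, with the last five ports set aside for _data_parallel_master_port_list:

```python
def partition_ports(all_ports: list[int], num_world: int, num_dp: int, num_ep: int):
    """Split a flat port list into per-group triplets, mirroring __post_init__."""
    assert len(all_ports) == (num_world + num_dp + num_ep) * 3 + 5
    dp_master_ports = all_ports[-5:]  # kept for _data_parallel_master_port_list
    ports = all_ports[:-5]
    world = [ports[i : i + 3] for i in range(0, num_world * 3, 3)]
    start = num_world * 3
    dp = [ports[i : i + 3] for i in range(start, start + num_dp * 3, 3)]
    start += num_dp * 3
    ep = [ports[i : i + 3] for i in range(start, start + num_ep * 3, 3)]
    return world, dp, ep, dp_master_ports


# Example: world_size_across_dp=8, data_parallel_size=2, tensor_parallel_size=2
# -> 1 world group, 4 DP groups, 2 EP groups -> 7 triplets + 5 spares = 26 ports.
fake_ports = list(range(20000, 20026))
world, dp, ep, spares = partition_ports(fake_ports, num_world=1, num_dp=4, num_ep=2)
print(world[0])  # [20000, 20001, 20002]
```

Each get_next_stateless_*_group_port() call then pops the next triplet from its list, so every stateless group created later receives its own pre-reserved, non-overlapping set of ports.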

vllm/distributed/device_communicators/all2all.py
Lines changed: 26 additions & 22 deletions

@@ -32,8 +32,8 @@ class NaiveAll2AllManager(All2AllManagerBase):
     debugging.
     """
 
-    def __init__(self, cpu_group):
-        super().__init__(cpu_group)
+    def __init__(self, cpu_group, tcp_store_group=None):
+        super().__init__(cpu_group, tcp_store_group)
 
     def naive_multicast(
         self,
@@ -105,8 +105,8 @@ class AgRsAll2AllManager(All2AllManagerBase):
     all-gather (dispatch) and reduce-scatter (combine).
     """
 
-    def __init__(self, cpu_group):
-        super().__init__(cpu_group)
+    def __init__(self, cpu_group, tcp_store_group=None):
+        super().__init__(cpu_group, tcp_store_group)
 
     def dispatch(
         self,
@@ -155,13 +155,16 @@ class PPLXAll2AllManager(All2AllManagerBase):
     All2All communication based on PPLX kernels.
     """
 
-    def __init__(self, cpu_group):
+    def __init__(self, cpu_group, tcp_store_group=None):
         assert has_pplx(), (
             "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
            " to install pplx_kernels."
         )
-        super().__init__(cpu_group)
+        super().__init__(cpu_group, tcp_store_group)
+        self.nvshmem_initialized = False
+        self.handle_cache = Cache()
 
+    def get_handle(self, kwargs):
         if self.internode:
             # inter-node communication needs nvshmem,
             # intra-node communication uses p2p mapping directly
@@ -181,17 +184,18 @@ def __init__(self, cpu_group):
                 if self.rank == 0
                 else nvshmem_alloc_empty_unique_id()
             )
-            dist.broadcast(
-                uid,
-                src=dist.get_process_group_ranks(self.cpu_group)[0],
-                group=self.cpu_group,
-            )
+            if self.tcp_store_group is not None:
+                uid = self.tcp_store_group.broadcast_obj(uid, src=0)
+            else:
+                dist.broadcast(
+                    uid,
+                    src=dist.get_process_group_ranks(self.cpu_group)[0],
+                    group=self.cpu_group,
+                )
             logger.debug("PPLX NVSHMEM UID = %s", uid)
             nvshmem_init(uid, self.rank, self.world_size)
+            self.nvshmem_initialized = True
 
-        self.handle_cache = Cache()
-
-    def get_handle(self, kwargs):
         import pplx_kernels as pplx  # type: ignore[import-not-found]
 
         return self.handle_cache.get_or_create(
@@ -231,12 +235,12 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
     All2All communication based on DeepEP High-Throughput kernels.
     """
 
-    def __init__(self, cpu_group):
+    def __init__(self, cpu_group, tcp_store_group=None):
         assert has_deep_ep(), (
             "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
             " to install DeepEP kernels."
         )  # noqa
-        super().__init__(cpu_group)
+        super().__init__(cpu_group, tcp_store_group)
         self.handle_cache = Cache()
 
         # This is the DeepEP default. Stick to it till we can establish
@@ -268,8 +272,8 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
     All2All communication based on DeepEP High-Throughput kernels.
     """
 
-    def __init__(self, cpu_group):
-        super().__init__(cpu_group)
+    def __init__(self, cpu_group, tcp_store_group=None):
+        super().__init__(cpu_group, tcp_store_group)
 
     def _make_all2all_kwargs(self) -> dict[Any, Any]:
         # Defaults for internode and intranode are taken from DeepEP tests.
@@ -325,8 +329,8 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
     All2All communication based on DeepEP Low-Latency kernels.
    """
 
-    def __init__(self, cpu_group):
-        super().__init__(cpu_group)
+    def __init__(self, cpu_group, tcp_store_group=None):
+        super().__init__(cpu_group, tcp_store_group)
 
     def _make_all2all_kwargs(
         self,
@@ -394,11 +398,11 @@ class FlashInferAllToAllManager(All2AllManagerBase):
     rank: int
     world_size: int
 
-    def __init__(self, cpu_group):
+    def __init__(self, cpu_group, tcp_store_group=None):
         assert has_flashinfer_all2all(), (
             "flashinfer all2all module not found. Please install/check flashinfer"
         )  # noqa
-        super().__init__(cpu_group)
+        super().__init__(cpu_group, tcp_store_group)
         logger.debug(
             "Initialize for flashinfer All2All rank=%d, world size=%d",
             self.rank,
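The key behavioral change in PPLXAll2AllManager is how the NVSHMEM unique id reaches the non-zero ranks. A simplified sketch of that fallback (illustrative only, not the vLLM classes; it assumes tcp_store_group exposes the broadcast_obj method used in the hunk above):

```python
import torch.distributed as dist


def broadcast_bootstrap_uid(uid, cpu_group, tcp_store_group=None):
    """Ship the NVSHMEM unique id from rank 0 to all ranks."""
    if tcp_store_group is not None:
        # Stateless path: no torch.distributed group exists for this set of
        # ranks, so the id travels as a Python object through the TCP store.
        return tcp_store_group.broadcast_obj(uid, src=0)
    # Stateful path: regular in-place tensor broadcast over the gloo cpu_group.
    dist.broadcast(
        uid, src=dist.get_process_group_ranks(cpu_group)[0], group=cpu_group
    )
    return uid
```

The diff also moves this bootstrap from __init__ into get_handle and records completion in nvshmem_initialized, so NVSHMEM is only set up once a handle is actually requested.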

vllm/distributed/device_communicators/base_device_communicator.py
Lines changed: 45 additions & 10 deletions

@@ -29,8 +29,9 @@ class All2AllManagerBase:
     rank: int
     world_size: int
 
-    def __init__(self, cpu_group):
+    def __init__(self, cpu_group, tcp_store_group=None):
         self.cpu_group = cpu_group
+        self.tcp_store_group = tcp_store_group
 
         # compute some common properties
         from vllm.distributed.parallel_state import (
@@ -47,12 +48,17 @@ def __init__(self, cpu_group):
         # when we create this object
         self.dp_rank = self.dp_group.rank_in_group
         self.dp_world_size = self.dp_group.world_size
-        self.rank = dist.get_rank(cpu_group)
-        self.world_size = dist.get_world_size(cpu_group)
+        self.rank = cpu_group.rank()
+        self.world_size = cpu_group.size()
 
         # all2all communication often has separate implementations for
         # intra-node and inter-node communication
-        self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
+        if tcp_store_group is None:
+            self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
+        else:
+            self.internode = not all(
+                in_the_same_node_as(tcp_store_group, source_rank=0)
+            )
 
     def get_handle(self, kwargs):
         # get a handle for the all2all communication,
@@ -98,17 +104,36 @@ def __init__(
         device: torch.device | None = None,
         device_group: ProcessGroup | None = None,
        unique_name: str = "",
+        global_ranks: list[int] | None = None,
+        global_world_size: int | None = None,
     ):
         self.device = device or torch.device("cpu")
         self.cpu_group = cpu_group
         self.device_group = device_group
         self.unique_name = unique_name
-        self.rank = dist.get_rank(cpu_group)
-        self.world_size = dist.get_world_size(cpu_group)
-        self.ranks = dist.get_process_group_ranks(cpu_group)
-        self.global_rank = dist.get_rank()
-        self.global_world_size = dist.get_world_size()
-        self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank)
+
+        # Check if this is a stateless process group
+        from torch.distributed.distributed_c10d import _world
+
+        is_stateless = _world.pg_map.get(cpu_group, None) is None
+
+        if is_stateless:
+            # For stateless groups, we can't use torch.distributed methods
+            self.rank = cpu_group.rank()
+            self.world_size = cpu_group.size()
+            assert global_ranks is not None
+            assert global_world_size is not None
+            self.ranks = global_ranks
+            self.global_rank = self.ranks[self.rank]
+            self.global_world_size = global_world_size
+            self.rank_in_group = self.rank
+        else:
+            self.rank = dist.get_rank(cpu_group)
+            self.world_size = dist.get_world_size(cpu_group)
+            self.ranks = dist.get_process_group_ranks(cpu_group)
+            self.global_rank = dist.get_rank()
+            self.global_world_size = dist.get_world_size()
+            self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank)
 
         use_ep = False
         all2all_backend = None
@@ -252,6 +277,13 @@ def recv(
             torch.distributed.recv(tensor, self.ranks[src], self.device_group)
         return tensor
 
+    def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
+        """Broadcast a tensor from source rank to all ranks."""
+        if self.world_size == 1:
+            return tensor
+        torch.distributed.broadcast(tensor, self.ranks[src], self.device_group)
+        return tensor
+
     def destroy(self):
         pass
 
@@ -295,3 +327,6 @@ def combine(
         This is a no-op in the base class.
         """
         return hidden_states
+
+    def batch_isend_irecv(self, p2p_ops: list):
+        raise NotImplementedError
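The stateless-group detection above relies on the fact that a group created outside torch.distributed's global state never appears in _world.pg_map, so dist.get_rank and dist.get_world_size cannot be used on it. A minimal sketch of the same check (illustrative only, not the vLLM class):

```python
import torch.distributed as dist
from torch.distributed.distributed_c10d import _world


def group_rank_and_size(cpu_group) -> tuple[int, int]:
    """Return (rank, size) for either a registered or a stateless group."""
    if _world.pg_map.get(cpu_group, None) is None:
        # Stateless group: not tracked by torch.distributed, so ask the
        # group object directly.
        return cpu_group.rank(), cpu_group.size()
    # Registered ProcessGroup: the global accessors work as usual.
    return dist.get_rank(cpu_group), dist.get_world_size(cpu_group)
```

For stateless groups the communicator also cannot derive global rank information from the default group, which is why the constructor now takes explicit global_ranks and global_world_size arguments and asserts they are provided.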
