Commit 452eeb9

Elastic EP phase: stateless group + CUDA graph support
support request serving during scaling up/down

Squashed fixup commits: misc fixes; minor fix (x2); scaling test: 2->4->2; tiny fix; rebase fix (x5); small fix (x2)

Signed-off-by: Yongji Wu <[email protected]>
1 parent c17610e commit 452eeb9

34 files changed: +2785 -682 lines changed

.buildkite/test-pipeline.yaml

Lines changed: 13 additions & 0 deletions
@@ -1147,6 +1147,19 @@ steps:
   - pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py
   - VLLM_WORKER_MULTIPROC_METHOD=spawn pytest models/multimodal/generation/test_whisper.py -v -s -m 'distributed(num_gpus=2)'

+- label: Elastic EP Scaling Test
+  timeout_in_minutes: 20
+  working_dir: "/vllm-workspace/tests"
+  num_gpus: 4
+  source_file_dependencies:
+  - vllm/distributed/
+  - vllm/engine/
+  - vllm/executor/
+  - vllm/compilation/
+  - tests/distributed/
+  commands:
+  - pytest -v -s distributed/test_elastic_ep.py
+
 - label: Plugin Tests (2 GPUs) # 40min
   timeout_in_minutes: 60
   mirror_hardwares: [amdexperimental]
tests/distributed/test_elastic_ep.py

Lines changed: 80 additions & 0 deletions

@@ -0,0 +1,80 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import os
+import time
+
+import requests
+
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
+from ..utils import RemoteOpenAIServer, _test_completion, multi_gpu_test
+
+MODEL_NAME = "Qwen/Qwen3-30B-A3B-Thinking-2507-FP8"
+
+
+def _send_scale_command(server: RemoteOpenAIServer, new_dp_size: int) -> bool:
+    url = server.url_for("scale_elastic_ep")
+    payload = {"new_data_parallel_size": new_dp_size}
+    headers = {"Content-Type": "application/json"}
+
+    try:
+        response = requests.post(url, json=payload, headers=headers, timeout=300)
+        return response.status_code == 200
+    except requests.exceptions.RequestException:
+        return False
+
+
+@multi_gpu_test(num_gpus=4)
+def test_elastic_ep_scaling():
+    vllm_serve_args = [
+        "--trust-remote-code",
+        "--disable-log-requests",
+        "--tensor-parallel-size",
+        "1",
+        "--gpu-memory-utilization",
+        "0.9",
+        "--max-model-len",
+        "16384",
+        "--no-enable-prefix-caching",
+        "--enable-expert-parallel",
+        "--all2all-backend",
+        "pplx",
+        "--enable-elastic-ep",
+        "--enable-eplb",
+        "--eplb-config.num_redundant_experts",
+        "128",
+        "--data-parallel-backend",
+        "ray",
+        "--data-parallel-size",
+        "2",
+        "--data-parallel-size-local",
+        "2",
+        "--data-parallel-start-rank",
+        "0",
+    ]
+
+    leader_address = os.environ.get("LEADER_ADDRESS")
+    if leader_address:
+        vllm_serve_args.extend(["--data-parallel-address", leader_address])
+
+    tokenizer = get_tokenizer(MODEL_NAME, trust_remote_code=True)
+    prompt = "Hello, my name is"
+    token_ids = tokenizer(prompt).input_ids
+
+    # timeout is 20 minutes
+    with RemoteOpenAIServer(
+        MODEL_NAME, vllm_serve_args, env_dict={}, max_wait_seconds=1200
+    ) as server:
+        client = server.get_client()
+        _test_completion(client, MODEL_NAME, prompt, token_ids)
+
+        # Scale up from 2->4
+        assert _send_scale_command(server, 4)
+        time.sleep(10)
+        _test_completion(client, MODEL_NAME, prompt, token_ids)
+
+        # Scale down from 4->2
+        assert _send_scale_command(server, 2)
+        time.sleep(5)
+        _test_completion(client, MODEL_NAME, prompt, token_ids)
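For manual experiments outside pytest, the same endpoint the test exercises can be hit directly. A minimal sketch, assuming an already running vLLM OpenAI-compatible server started with --enable-elastic-ep; the host and port below are placeholders:

```python
import requests

BASE_URL = "http://localhost:8000"  # placeholder; point at your running server


def scale_elastic_ep(new_dp_size: int, timeout: float = 300.0) -> None:
    # Endpoint name and payload field mirror _send_scale_command() above.
    resp = requests.post(
        f"{BASE_URL}/scale_elastic_ep",
        json={"new_data_parallel_size": new_dp_size},
        timeout=timeout,
    )
    resp.raise_for_status()


if __name__ == "__main__":
    scale_elastic_ep(4)  # scale up to 4 data-parallel ranks
    scale_elastic_ep(2)  # scale back down to 2
```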

tools/ep_kernels/elastic_ep/install_eep_libraries.sh

Lines changed: 8 additions & 1 deletion
@@ -52,6 +52,12 @@ if [ -z "$CUDA_HOME" ]; then
     exit 1
 fi

+# assume TORCH_CUDA_ARCH_LIST is set correctly
+if [ -z "$TORCH_CUDA_ARCH_LIST" ]; then
+    echo "TORCH_CUDA_ARCH_LIST is not set, please set it to your desired architecture."
+    exit 1
+fi
+
 # disable all features except IBGDA
 export NVSHMEM_IBGDA_SUPPORT=1

@@ -82,5 +88,6 @@ git clone https://github.com/ppl-ai/pplx-kernels
 cd pplx-kernels
 # see https://github.com/pypa/pip/issues/9955#issuecomment-838065925
 # PIP_NO_BUILD_ISOLATION=0 disables build isolation
-PIP_NO_BUILD_ISOLATION=0 TORCH_CUDA_ARCH_LIST=9.0a+PTX pip install . --no-deps -v
+git checkout 12cecfd
+PIP_NO_BUILD_ISOLATION=0 pip install . --no-deps -v

vllm/compilation/wrapper.py

Lines changed: 31 additions & 0 deletions
@@ -258,3 +258,34 @@ def _dispatch_to_compiled_code(self):
             yield
         finally:
             self.__class__.forward.__code__ = original
+
+
+def reset_compile_wrapper(model: torch.nn.Module) -> None:
+    """
+    Clean up compiled model and captured CUDA graphs for elastic EP.
+    """
+    if not isinstance(model, TorchCompileWithNoGuardsWrapper) and hasattr(
+        model, "model"
+    ):
+        model = model.model
+    if not isinstance(model, TorchCompileWithNoGuardsWrapper):
+        return
+    # model.do_not_compile is set by the @support_torch_compile decorator
+    if hasattr(model, "do_not_compile") and model.do_not_compile:
+        return
+    from vllm.compilation.counter import compilation_counter
+
+    # reset the compilation counter
+    compilation_counter.num_models_seen = 0
+    compilation_counter.num_graphs_seen = 0
+    compilation_counter.num_piecewise_graphs_seen = 0
+    compilation_counter.num_piecewise_capturable_graphs_seen = 0
+    compilation_counter.num_backend_compilations = 0
+    compilation_counter.num_gpu_runner_capture_triggers = 0
+    compilation_counter.num_cudagraph_captured = 0
+    compilation_counter.num_inductor_compiles = 0
+    compilation_counter.num_eager_compiles = 0
+    compilation_counter.num_cache_entries_updated = 0
+    compilation_counter.num_compiled_artifacts_saved = 0
+    compilation_counter.stock_torch_compile_count = 0
+    TorchCompileWithNoGuardsWrapper.__init__(model)
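A rough sketch of the intended call pattern, not the worker code from this commit: after the EP groups have been rebuilt for the new world size, the compiled wrapper is reset so the next warm-up pass recompiles and recaptures CUDA graphs. The recapture_fn callback below is hypothetical.

```python
import torch

from vllm.compilation.wrapper import reset_compile_wrapper


def rebuild_compiled_state(model: torch.nn.Module, recapture_fn) -> None:
    # Drop compilation counters and re-init the TorchCompileWithNoGuardsWrapper
    # so compiled artifacts tied to the old parallel layout are not reused.
    reset_compile_wrapper(model)
    # Hypothetical hook: the worker's normal warm-up / CUDA graph capture path.
    recapture_fn(model)
```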

vllm/config/parallel.py

Lines changed: 78 additions & 2 deletions
@@ -166,6 +166,9 @@ class ParallelConfig:
     disable_custom_all_reduce: bool = False
     """Disable the custom all-reduce kernel and fall back to NCCL."""

+    enable_elastic_ep: bool = False
+    """Enable elastic expert parallelism with stateless NCCL groups for DP/EP."""
+
     enable_dbo: bool = False
     """Enable dual batch overlap for the model executor."""

@@ -239,6 +242,29 @@ class is dynamically inherited by the worker class. This is used to inject
     Set to be private as it's not intended to be configured by users.
     """

+    _stateless_dp_group_port_list: list[list[int]] = Field(default_factory=list)
+    """List of open ports for stateless DP groups when enable_elastic_ep is True.
+    Set to be private as it's not intended to be configured by users.
+    It is a list of list[int], with each inner list containing a set of 3 ports
+    to be used for setting up the stateless CPU/device/TCPStore groups
+    in StatelessGroupCoordinator. The number of inner lists is equal to
+    the number of DP groups,
+    i.e., len(self._stateless_dp_group_port_list) == world_size_across_dp // dp_size,
+    and len(self._stateless_dp_group_port_list[i]) == 3 for all i.
+    """
+
+    _stateless_ep_group_port_list: list[list[int]] = Field(default_factory=list)
+    """List of open ports for stateless EP groups when enable_elastic_ep is True.
+    Set to be private as it's not intended to be configured by users.
+    len(self._stateless_ep_group_port_list) == world_size_across_dp // ep_size.
+    """
+
+    _stateless_world_group_port_list: list[list[int]] = Field(default_factory=list)
+    """List of open ports for the stateless world group when enable_elastic_ep is True.
+    Set to be private as it's not intended to be configured by users.
+    len(self._stateless_world_group_port_list) == 1.
+    """
+
     decode_context_parallel_size: int = 1
     """Number of decode context parallel groups, because the world size does
     not change by dcp, it simply reuse the GPUs of TP group, and tp_size

@@ -358,7 +384,16 @@ def get_next_dp_init_port(self) -> int:

         return answer

-    def stateless_init_dp_group(self) -> ProcessGroup:
+    def get_next_stateless_world_group_port(self) -> list[int]:
+        return self._stateless_world_group_port_list.pop()
+
+    def get_next_stateless_dp_group_port(self) -> list[int]:
+        return self._stateless_dp_group_port_list.pop()
+
+    def get_next_stateless_ep_group_port(self) -> list[int]:
+        return self._stateless_ep_group_port_list.pop()
+
+    def stateless_init_dp_group(self, return_store: bool = False) -> ProcessGroup:
         # NOTE: In high-concurrency scenarios multiple processes
         # can pick the same (currently free) port through a race
         # condition when calling `get_open_port()`. When the first

@@ -382,7 +417,8 @@ def stateless_init_dp_group(self) -> ProcessGroup:
                 self.get_next_dp_init_port(),
                 self.data_parallel_rank,
                 self.data_parallel_size,
-                backend=current_platform.dist_backend,
+                backend="gloo",
+                return_store=return_store,
             )
         except DistNetworkError as e:
             # We only want to retry when the root cause is EADDRINUSE.

@@ -561,6 +597,46 @@ def __post_init__(self) -> None:
             logger.info("Using external launcher for distributed inference.")
             self.world_size *= self.data_parallel_size

+        # Initialize stateless group ports for elastic EP
+        if self.enable_elastic_ep:
+            if not self.enable_eplb:
+                raise ValueError("Elastic EP is only supported with enable_eplb=True.")
+            num_world_groups = 1
+            dp_size = self.data_parallel_size
+            ep_size = self.data_parallel_size * self.world_size_across_dp
+            num_dp_groups = max(1, self.world_size_across_dp // dp_size)
+            num_ep_groups = max(1, self.world_size_across_dp // ep_size)
+
+            # NOTE(yongji):
+            # we need 3 ports for each comm group in `StatelessGroupCoordinator`:
+            # one for the stateless CPU group, one for the stateless device group,
+            # one for the stateless TCPStore group.
+            total_ports_needed = (num_world_groups + num_dp_groups + num_ep_groups) * 3
+            if not self._stateless_world_group_port_list:
+                all_ports = get_open_ports_list(total_ports_needed + 5)
+                # NOTE(yongji): allocate 5 ports for _data_parallel_master_port_list,
+                # as in the case when elastic EP is not enabled
+                # (the regular DP code path below this if: `get_open_ports_list(5)`).
+                # We must set _data_parallel_master_port_list here instead of
+                # letting the regular DP code path set it, since
+                # we should call get_open_ports_list() only once
+                # to ensure the allocated ports are distinct.
+                self._data_parallel_master_port_list = all_ports[-5:]
+                all_ports = all_ports[:-5]
+                self._stateless_world_group_port_list = [
+                    all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3)
+                ]
+                start_idx = num_world_groups * 3
+                self._stateless_dp_group_port_list = [
+                    all_ports[i : i + 3]
+                    for i in range(start_idx, start_idx + num_dp_groups * 3, 3)
+                ]
+                start_idx += num_dp_groups * 3
+                self._stateless_ep_group_port_list = [
+                    all_ports[i : i + 3]
+                    for i in range(start_idx, start_idx + num_ep_groups * 3, 3)
+                ]
+
         if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
             # Data parallel was specified in the engine args.
             if self.distributed_executor_backend == "external_launcher":
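To make the port bookkeeping above concrete, here is a small standalone sketch of the slicing. It is illustrative only: the group counts and port numbers are made up, and get_open_ports_list is replaced by a stub.

```python
# Illustration of the port slicing above, not vLLM code.
# Assumed counts for the example: 1 world group, 2 DP groups, 1 EP group.

def fake_get_open_ports_list(n: int) -> list[int]:
    # Stand-in for the real helper, which asks the OS for free ports.
    return list(range(20000, 20000 + n))


num_world_groups, num_dp_groups, num_ep_groups = 1, 2, 1
total_ports_needed = (num_world_groups + num_dp_groups + num_ep_groups) * 3

all_ports = fake_get_open_ports_list(total_ports_needed + 5)
data_parallel_master_ports = all_ports[-5:]  # the extra 5 ports
all_ports = all_ports[:-5]

world_ports = [all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3)]
start = num_world_groups * 3
dp_ports = [all_ports[i : i + 3] for i in range(start, start + num_dp_groups * 3, 3)]
start += num_dp_groups * 3
ep_ports = [all_ports[i : i + 3] for i in range(start, start + num_ep_groups * 3, 3)]

print(world_ports)  # [[20000, 20001, 20002]]
print(dp_ports)     # [[20003, 20004, 20005], [20006, 20007, 20008]]
print(ep_ports)     # [[20009, 20010, 20011]]
```

Each inner list of three ports then backs one coordinator's stateless CPU group, device group, and TCPStore, as described in the field docstrings above.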

vllm/distributed/device_communicators/all2all.py

Lines changed: 26 additions & 22 deletions
@@ -32,8 +32,8 @@ class NaiveAll2AllManager(All2AllManagerBase):
     debugging.
     """

-    def __init__(self, cpu_group):
-        super().__init__(cpu_group)
+    def __init__(self, cpu_group, tcp_store_group=None):
+        super().__init__(cpu_group, tcp_store_group)

     def naive_multicast(
         self,

@@ -105,8 +105,8 @@ class AgRsAll2AllManager(All2AllManagerBase):
     all-gather (dispatch) and reduce-scatter (combine).
     """

-    def __init__(self, cpu_group):
-        super().__init__(cpu_group)
+    def __init__(self, cpu_group, tcp_store_group=None):
+        super().__init__(cpu_group, tcp_store_group)

     def dispatch(
         self,

@@ -155,13 +155,16 @@ class PPLXAll2AllManager(All2AllManagerBase):
     All2All communication based on PPLX kernels.
     """

-    def __init__(self, cpu_group):
+    def __init__(self, cpu_group, tcp_store_group=None):
         assert has_pplx(), (
             "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
             " to install pplx_kernels."
         )
-        super().__init__(cpu_group)
+        super().__init__(cpu_group, tcp_store_group)
+        self.nvshmem_initialized = False
+        self.handle_cache = Cache()

+    def get_handle(self, kwargs):
         if self.internode:
             # inter-node communication needs nvshmem,
             # intra-node communication uses p2p mapping directly

@@ -181,17 +184,18 @@ def __init__(self, cpu_group):
                 if self.rank == 0
                 else nvshmem_alloc_empty_unique_id()
             )
-            dist.broadcast(
-                uid,
-                src=dist.get_process_group_ranks(self.cpu_group)[0],
-                group=self.cpu_group,
-            )
+            if self.tcp_store_group is not None:
+                uid = self.tcp_store_group.broadcast_obj(uid, src=0)
+            else:
+                dist.broadcast(
+                    uid,
+                    src=dist.get_process_group_ranks(self.cpu_group)[0],
+                    group=self.cpu_group,
+                )
             logger.debug("PPLX NVSHMEM UID = %s", uid)
             nvshmem_init(uid, self.rank, self.world_size)
+            self.nvshmem_initialized = True

-        self.handle_cache = Cache()
-
-    def get_handle(self, kwargs):
         import pplx_kernels as pplx  # type: ignore[import-not-found]

         return self.handle_cache.get_or_create(

@@ -231,12 +235,12 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
     All2All communication based on DeepEP High-Throughput kernels.
     """

-    def __init__(self, cpu_group):
+    def __init__(self, cpu_group, tcp_store_group=None):
         assert has_deep_ep(), (
             "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
             " to install DeepEP kernels."
         )  # noqa
-        super().__init__(cpu_group)
+        super().__init__(cpu_group, tcp_store_group)
         self.handle_cache = Cache()

         # This is the DeepEP default. Stick to it till we can establish

@@ -268,8 +272,8 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
     All2All communication based on DeepEP High-Throughput kernels.
     """

-    def __init__(self, cpu_group):
-        super().__init__(cpu_group)
+    def __init__(self, cpu_group, tcp_store_group=None):
+        super().__init__(cpu_group, tcp_store_group)

     def _make_all2all_kwargs(self) -> dict[Any, Any]:
         # Defaults for internode and intranode are taken from DeepEP tests.

@@ -325,8 +329,8 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
     All2All communication based on DeepEP Low-Latency kernels.
     """

-    def __init__(self, cpu_group):
-        super().__init__(cpu_group)
+    def __init__(self, cpu_group, tcp_store_group=None):
+        super().__init__(cpu_group, tcp_store_group)

     def _make_all2all_kwargs(
         self,

@@ -396,11 +400,11 @@ class FlashInferAllToAllManager(All2AllManagerBase):
     rank: int
     world_size: int

-    def __init__(self, cpu_group):
+    def __init__(self, cpu_group, tcp_store_group=None):
         assert has_flashinfer_all2all(), (
             "flashinfer all2all module not found. Please install/check flashinfer"
         )  # noqa
-        super().__init__(cpu_group)
+        super().__init__(cpu_group, tcp_store_group)
         logger.debug(
             "Initialize for flashinfer All2All rank=%d, world size=%d",
             self.rank,
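The tcp_store_group.broadcast_obj(uid, src=0) call above lets the NVSHMEM unique ID be shared without going through a torch.distributed process group. As intuition for that pattern, here is a rough standalone sketch of broadcasting a picklable object over a bare torch.distributed.TCPStore; it illustrates the idea only and is not the StatelessGroupCoordinator implementation from this commit:

```python
import pickle

from torch.distributed import TCPStore


def broadcast_obj_via_store(store: TCPStore, obj, rank: int, key: str = "bcast_obj"):
    # Rank 0 publishes the pickled object under a well-known key; other ranks
    # block in store.get() until the key appears, then unpickle it.
    if rank == 0:
        store.set(key, pickle.dumps(obj).hex())
        return obj
    return pickle.loads(bytes.fromhex(store.get(key).decode()))


# Usage sketch (host/port/world_size are placeholders): every rank connects to
# the same store, rank 0 acting as the server, and ends up with the same object.
# store = TCPStore("10.0.0.1", 29999, world_size=4, is_master=(rank == 0))
# uid = broadcast_obj_via_store(store, uid, rank)
```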
