13 changes: 13 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -1151,6 +1151,19 @@ steps:
  - pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py
  - VLLM_WORKER_MULTIPROC_METHOD=spawn pytest models/multimodal/generation/test_whisper.py -v -s -m 'distributed(num_gpus=2)'

- label: Elastic EP Scaling Test
  timeout_in_minutes: 20
  working_dir: "/vllm-workspace/tests"
  num_gpus: 4
  source_file_dependencies:
  - vllm/distributed/
  - vllm/engine/
  - vllm/executor/
  - vllm/compilation/
  - tests/distributed/
  commands:
  - pytest -v -s distributed/test_elastic_ep.py

- label: Plugin Tests (2 GPUs) # 40min
  timeout_in_minutes: 60
  mirror_hardwares: [amdexperimental]
80 changes: 80 additions & 0 deletions tests/distributed/test_elastic_ep.py
@@ -0,0 +1,80 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
import time

import requests

from vllm.transformers_utils.tokenizer import get_tokenizer

from ..utils import RemoteOpenAIServer, _test_completion, multi_gpu_test

MODEL_NAME = "Qwen/Qwen3-30B-A3B-Thinking-2507-FP8"


def _send_scale_command(server: RemoteOpenAIServer, new_dp_size: int) -> bool:
    url = server.url_for("scale_elastic_ep")
    payload = {"new_data_parallel_size": new_dp_size}
    headers = {"Content-Type": "application/json"}

    try:
        response = requests.post(url, json=payload, headers=headers, timeout=300)
        return response.status_code == 200
    except requests.exceptions.RequestException:
        return False


@multi_gpu_test(num_gpus=4)
def test_elastic_ep_scaling():
    vllm_serve_args = [
        "--trust-remote-code",
        "--disable-log-requests",
        "--tensor-parallel-size",
        "1",
        "--gpu-memory-utilization",
        "0.9",
        "--max-model-len",
        "16384",
        "--no-enable-prefix-caching",
        "--enable-expert-parallel",
        "--all2all-backend",
        "pplx",
        "--enable-elastic-ep",
        "--enable-eplb",
        "--eplb-config.num_redundant_experts",
        "128",
        "--data-parallel-backend",
        "ray",
        "--data-parallel-size",
        "2",
        "--data-parallel-size-local",
        "2",
        "--data-parallel-start-rank",
        "0",
    ]

    leader_address = os.environ.get("LEADER_ADDRESS")
    if leader_address:
        vllm_serve_args.extend(["--data-parallel-address", leader_address])

    tokenizer = get_tokenizer(MODEL_NAME, trust_remote_code=True)
    prompt = "Hello, my name is"
    token_ids = tokenizer(prompt).input_ids

    # timeout is 20 minutes
    with RemoteOpenAIServer(
        MODEL_NAME, vllm_serve_args, env_dict={}, max_wait_seconds=1200
    ) as server:
        client = server.get_client()
        _test_completion(client, MODEL_NAME, prompt, token_ids)

        # Scale up from 2 -> 4
        assert _send_scale_command(server, 4)
        time.sleep(10)
        _test_completion(client, MODEL_NAME, prompt, token_ids)

        # Scale down from 4 -> 2
        assert _send_scale_command(server, 2)
        time.sleep(5)
        _test_completion(client, MODEL_NAME, prompt, token_ids)
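
For reference, a minimal sketch of issuing the same scaling request by hand against an already-running server. The host and port below are illustrative, and the path assumes `url_for("scale_elastic_ep")` resolves to `<base_url>/scale_elastic_ep`; the payload key matches `_send_scale_command` above.

import requests

resp = requests.post(
    "http://localhost:8000/scale_elastic_ep",  # illustrative host/port
    json={"new_data_parallel_size": 4},  # target data-parallel size
    timeout=300,
)
assert resp.status_code == 200  # the test treats HTTP 200 as success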
9 changes: 8 additions & 1 deletion tools/ep_kernels/elastic_ep/install_eep_libraries.sh
@@ -52,6 +52,12 @@ if [ -z "$CUDA_HOME" ]; then
    exit 1
fi

# require TORCH_CUDA_ARCH_LIST to be set to the target GPU architecture(s)
if [ -z "$TORCH_CUDA_ARCH_LIST" ]; then
    echo "TORCH_CUDA_ARCH_LIST is not set, please set it to your desired architecture."
    exit 1
fi

# disable all features except IBGDA
export NVSHMEM_IBGDA_SUPPORT=1

@@ -82,5 +88,6 @@ git clone https://github.com/ppl-ai/pplx-kernels
cd pplx-kernels
# see https://github.com/pypa/pip/issues/9955#issuecomment-838065925
# PIP_NO_BUILD_ISOLATION=0 disables build isolation
PIP_NO_BUILD_ISOLATION=0 TORCH_CUDA_ARCH_LIST=9.0a+PTX pip install . --no-deps -v
git checkout 12cecfd
PIP_NO_BUILD_ISOLATION=0 pip install . --no-deps -v
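
# Example invocation of this script (illustrative; choose the arch list that
# matches your GPUs, e.g. 9.0a+PTX for Hopper):
#   TORCH_CUDA_ARCH_LIST=9.0a+PTX CUDA_HOME=/usr/local/cuda \
#       bash tools/ep_kernels/elastic_ep/install_eep_libraries.sh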

31 changes: 31 additions & 0 deletions vllm/compilation/wrapper.py
@@ -258,3 +258,34 @@ def _dispatch_to_compiled_code(self):
            yield
        finally:
            self.__class__.forward.__code__ = original


def reset_compile_wrapper(model: torch.nn.Module) -> None:
    """
    Clean up the compiled model and captured CUDA graphs for elastic EP.
    """
    if not isinstance(model, TorchCompileWithNoGuardsWrapper) and hasattr(
        model, "model"
    ):
        model = model.model
    if not isinstance(model, TorchCompileWithNoGuardsWrapper):
        return
    # model.do_not_compile is set by the @support_torch_compile decorator
    if hasattr(model, "do_not_compile") and model.do_not_compile:
        return
    from vllm.compilation.counter import compilation_counter

    # reset the compilation counter
    compilation_counter.num_models_seen = 0
    compilation_counter.num_graphs_seen = 0
    compilation_counter.num_piecewise_graphs_seen = 0
    compilation_counter.num_piecewise_capturable_graphs_seen = 0
    compilation_counter.num_backend_compilations = 0
    compilation_counter.num_gpu_runner_capture_triggers = 0
    compilation_counter.num_cudagraph_captured = 0
    compilation_counter.num_inductor_compiles = 0
    compilation_counter.num_eager_compiles = 0
    compilation_counter.num_cache_entries_updated = 0
    compilation_counter.num_compiled_artifacts_saved = 0
    compilation_counter.stock_torch_compile_count = 0
    TorchCompileWithNoGuardsWrapper.__init__(model)
82 changes: 80 additions & 2 deletions vllm/config/parallel.py
@@ -128,6 +128,7 @@ class ParallelConfig:
"pplx",
"deepep_high_throughput",
"deepep_low_latency",
"nixl_ep",
"allgather_reducescatter",
"flashinfer_all2allv",
]
@@ -150,6 +151,9 @@
    disable_custom_all_reduce: bool = False
    """Disable the custom all-reduce kernel and fall back to NCCL."""

    enable_elastic_ep: bool = False
    """Enable elastic expert parallelism with stateless NCCL groups for DP/EP."""

    enable_dbo: bool = False
    """Enable dual batch overlap for the model executor."""

@@ -223,6 +227,29 @@ class is dynamically inherited by the worker class. This is used to inject
    Set to be private as it's not intended to be configured by users.
    """

    _stateless_dp_group_port_list: list[list[int]] = Field(default_factory=list)
    """List of open ports for stateless DP groups when enable_elastic_ep is True.
    Set to be private as it's not intended to be configured by users.
    It is a list of list[int], where each inner list contains the 3 ports used
    to set up the stateless CPU group, device group, and TCPStore group in
    StatelessGroupCoordinator. The number of inner lists equals the number of
    DP groups, i.e.,
    len(self._stateless_dp_group_port_list) == world_size_across_dp // dp_size,
    and len(self._stateless_dp_group_port_list[i]) == 3 for all i.
    """

    _stateless_ep_group_port_list: list[list[int]] = Field(default_factory=list)
    """List of open ports for stateless EP groups when enable_elastic_ep is True.
    Set to be private as it's not intended to be configured by users.
    len(self._stateless_ep_group_port_list) == world_size_across_dp // ep_size.
    """

    _stateless_world_group_port_list: list[list[int]] = Field(default_factory=list)
    """List of open ports for the stateless world group when enable_elastic_ep
    is True. Set to be private as it's not intended to be configured by users.
    len(self._stateless_world_group_port_list) == 1.
    """

    decode_context_parallel_size: int = 1
    """Number of decode context parallel groups, because the world size does
    not change by dcp, it simply reuses the GPUs of the TP group, and tp_size
@@ -342,7 +369,16 @@ def get_next_dp_init_port(self) -> int:

        return answer

    def stateless_init_dp_group(self) -> ProcessGroup:
    def get_next_stateless_world_group_port(self) -> list[int]:
        return self._stateless_world_group_port_list.pop()

    def get_next_stateless_dp_group_port(self) -> list[int]:
        return self._stateless_dp_group_port_list.pop()

    def get_next_stateless_ep_group_port(self) -> list[int]:
        return self._stateless_ep_group_port_list.pop()

    def stateless_init_dp_group(self, return_store: bool = False) -> ProcessGroup:
        # NOTE: In high-concurrency scenarios multiple processes
        # can pick the same (currently free) port through a race
        # condition when calling `get_open_port()`. When the first
@@ -366,7 +402,8 @@ def stateless_init_dp_group(self) -> ProcessGroup:
                    self.get_next_dp_init_port(),
                    self.data_parallel_rank,
                    self.data_parallel_size,
                    backend=current_platform.dist_backend,
                    backend="gloo",
                    return_store=return_store,
                )
            except DistNetworkError as e:
                # We only want to retry when the root cause is EADDRINUSE.
@@ -398,6 +435,7 @@ def use_sequence_parallel_moe(self) -> bool:
"naive",
"deepep_high_throughput",
"deepep_low_latency",
"nixl_ep",
)
and self.enable_expert_parallel
and self.tensor_parallel_size > 1
@@ -511,6 +549,46 @@ def __post_init__(self) -> None:
            logger.info("Using external launcher for distributed inference.")
            self.world_size *= self.data_parallel_size

        # Initialize stateless group ports for elastic EP
        if self.enable_elastic_ep:
            if not self.enable_eplb:
                raise ValueError("Elastic EP is only supported with enable_eplb=True.")
            num_world_groups = 1
            dp_size = self.data_parallel_size
            ep_size = self.data_parallel_size * self.world_size_across_dp
            num_dp_groups = max(1, self.world_size_across_dp // dp_size)
            num_ep_groups = max(1, self.world_size_across_dp // ep_size)

            # NOTE(yongji):
            # we need 3 ports for each comm group in `StatelessGroupCoordinator`:
            # one for the stateless CPU group, one for the stateless device group,
            # and one for the stateless TCPStore group.
            total_ports_needed = (num_world_groups + num_dp_groups + num_ep_groups) * 3
            if not self._stateless_world_group_port_list:
                all_ports = get_open_ports_list(total_ports_needed + 5)
                # NOTE(yongji): allocate 5 ports for _data_parallel_master_port_list,
                # as in the case when elastic EP is not enabled
                # (the regular DP code path below this if: `get_open_ports_list(5)`).
                # We must set _data_parallel_master_port_list here instead of
                # letting the regular DP code path set it, since
                # get_open_ports_list() should be called only once
                # to ensure the allocated ports are distinct.
                self._data_parallel_master_port_list = all_ports[-5:]
                all_ports = all_ports[:-5]
                self._stateless_world_group_port_list = [
                    all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3)
                ]
                start_idx = num_world_groups * 3
                self._stateless_dp_group_port_list = [
                    all_ports[i : i + 3]
                    for i in range(start_idx, start_idx + num_dp_groups * 3, 3)
                ]
                start_idx += num_dp_groups * 3
                self._stateless_ep_group_port_list = [
                    all_ports[i : i + 3]
                    for i in range(start_idx, start_idx + num_ep_groups * 3, 3)
                ]
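                # Worked example (illustrative numbers): if num_world_groups == 1,
                # num_dp_groups == 2 and num_ep_groups == 1, then
                # total_ports_needed == 12 and 17 ports are requested; the last 5
                # become _data_parallel_master_port_list, ports [0:3] serve the
                # world group, [3:9] the two DP groups, and [9:12] the EP group.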

        if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
            # Data parallel was specified in the engine args.
            if self.distributed_executor_backend == "external_launcher":