Skip to content

Conversation

@libertyeagle
Copy link

@libertyeagle libertyeagle commented Oct 6, 2025

Purpose

PR #20775 introduces initial support of elastic expert parallelism. This PR adds further optimizations towards Milestone 2 in #20323. Key features include:

  • Breakdown the scale up/down logic into a state machine of multiple stages, with their execution controlled in vllm/distributed/elastic_ep/elastic_state.py and vllm/distributed/elastic_ep/elastic_execute.py.
  • Newly started workers receive all weights (non-MoE modules and expert weights) from peer GPUs.
  • We no longer need to drop traffic during scale up/down. During scale up, existing workers can continue to serve requests until new workers are ready (non-expert weights are already received and prepare to compile/warmup the model). Existing workers will progressively reconfigure to new EP size in DPEngineCoreProc. In run_busy_loop, elastic_scaling_state.progress() is called to progress reconfiguration by one step if ready. If reconfiguration cannot continue, existing workers continue to serve requests. Such progressive reconfiguration between forward steps also helps to quickly finish in-flight user requests, prevent requests from queuing up and improve SLO attainment.
  • If elastic EP is enabled (—enable-elastic-ep), all EP/DP communicators will be replaced by vllm/distributed/stateless_coordinator.py that is independent of torch.distributed’s global state. We can therefore create standby communicators while keeping the current ones, enabling the bootstrap of new workers to overlap with request serving on existing workers. We only need to do a single switch to use the new communicators after we are ready to switch to the new setup.
  • For scale-up, delay EPLB reshuffle until reconfiguration is finished. Newly joined workers can dispatch tokens to the original set of GPUs for expert computation, while experts can be progressively reshuffled to include the newly joined GPUs.
  • Support for enabling CUDA graphs, which is critical to performance especially in decode mode. In this PR, on existing workers, we will destroy compiled model and all captured CUDA graphs, followed by recompiling and recapturing all graphs. See switch_and_prepare() in vllm/distributed/elastic_ep/elastic_execute.py. We will introduce optimizations on CUDA graphs in follow-up PRs.

There are also some minor bug fixes including:

  • Fix ray resources discovery and engine zmq addr when scaling from intra-node to inter-node settings.
  • Fix the issue that throughput logging is not reported after scale up.

Test Plan

We test the performance before scale up and after scale on using Qwen/Qwen3-30B-A3B-Thinking-2507-FP8. The number of physical experts per GPU is set to 72. We note that the number of local physical experts remain the same during scale up and down, while the total number of redundant experts scales accordingly, which is the same assumption as in PR #20775. We use PPLX kernels (intra-node mode that does not require NVSHMEM) and enable CUDA graphs using default settings.

MODEL_NAME="Qwen/Qwen3-30B-A3B-Thinking-2507-FP8"
vllm serve $MODEL_NAME --trust-remote-code \
    --disable-log-requests \
    --host $HOST \
    --port $PORT \
    --tensor-parallel-size 1 \
    --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
    --max-model-len $MAX_MODEL_LEN \
    --no-enable-prefix-caching \
    --enable-expert-parallel \
    --enable-elastic-ep \
    --enable-eplb \
    --eplb-config.num_redundant_experts $NUM_REDUNDANT_EXPERTS \
    --eplb-config.window_size $EPLB_WINDOW_SIZE \
    --eplb-config.step_interval $EPLB_STEP_INTERVAL \
    --data-parallel-backend ray \
    --data-parallel-size $DATA_PARALLEL_SIZE \
    --data-parallel-size-local $DATA_PARALLEL_SIZE_LOCAL \
    --data-parallel-address $LEADER_ADDRESS \
    --data-parallel-rpc-port 9876 \
    --data-parallel-start-rank 0

To scale up we use:

python examples/online_serving/elastic_ep/scale.py --host $HOST --port $PORT --new-dp-size $NEW_DATA_PARALLEL_SIZE

Test Results

We use the following benchmark script.

vllm bench serve \
    --model $MODEL_NAME \
    --host $HOST \
    --port $PORT \
    --dataset-name random \
    --random-input-len 256 \
    --random-output-len 128 \
    --num-prompts 512

Serving on 2 GPUs (EP=2, TP=1) before scaling up:

============ Serving Benchmark Result ============
Successful requests:                     512       
Benchmark duration (s):                  15.85     
Total input tokens:                      130815    
Total generated tokens:                  65478     
Request throughput (req/s):              32.30     
Output token throughput (tok/s):         4131.03   
Peak output token throughput (tok/s):    17408.00  
Peak concurrent requests:                512.00    
Total Token throughput (tok/s):          12384.18  
---------------Time to First Token----------------
Mean TTFT (ms):                          6870.52   
Median TTFT (ms):                        7559.63   
P99 TTFT (ms):                           12107.77  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          69.94     
Median TPOT (ms):                        64.56     
P99 TPOT (ms):                           109.25    
---------------Inter-token Latency----------------
Mean ITL (ms):                           69.90     
Median ITL (ms):                         29.54     
P99 ITL (ms):                            1443.20   
==================================================

Serving on 4 GPUs (EP=4, TP=1) after scaling up:

============ Serving Benchmark Result ============
Successful requests:                     512       
Benchmark duration (s):                  9.89      
Total input tokens:                      130815    
Total generated tokens:                  65415     
Request throughput (req/s):              51.75     
Output token throughput (tok/s):         6612.23   
Peak output token throughput (tok/s):    18802.00  
Peak concurrent requests:                512.00    
Total Token throughput (tok/s):          19835.17  
---------------Time to First Token----------------
Mean TTFT (ms):                          4089.23   
Median TTFT (ms):                        4812.20   
P99 TTFT (ms):                           6322.47   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          44.82     
Median TPOT (ms):                        44.26     
P99 TPOT (ms):                           62.10     
---------------Inter-token Latency----------------
Mean ITL (ms):                           44.91     
Median ITL (ms):                         27.23     
P99 ITL (ms):                            1481.01   
==================================================

Next Steps

  • PR 2/N: Support elastic EP kernels and weight communicators (e.g., P2P transfer engines like Mooncake and NIXL).
  • PR 3/N: CUDA graph capture cost optimization: enabling incremental CUDA graph updates while serving traffic, enabling CUDA graph memory pool optimizations to minimize new memory allocation during CUDA graph updates.
  • PR N/N: Further cost optimization (e.g., torch.compile cache management, incremental EPLB and incremental non-expert weight transfer); support more kernels (e.g., regular DeepEP), scheduler optimization to migrate dispatched requests to newly started workers for load balancing; …

CC List

@abmfy @ruisearch42 @simon-mo @tlrmchlsmth @njhill @kouroshHakha

@github-actions
Copy link

github-actions bot commented Oct 6, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors.

You ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify bot added the v1 label Oct 6, 2025
@mergify
Copy link

mergify bot commented Oct 6, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @libertyeagle.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 6, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces significant optimizations for elastic expert parallelism, building upon initial support. The key changes include a new state machine for scaling up/down, peer-to-peer weight transfer for new workers, and progressive reconfiguration to avoid dropping traffic during scaling operations. The introduction of stateless communicators independent of torch.distributed's global state is a major architectural shift enabling these features. My review has identified a critical bug in the state machine logic and several high-severity issues related to fragile implementation details that could lead to future breakages. Overall, this is a substantial and well-structured contribution, but the identified issues should be addressed to ensure robustness and correctness.

Comment on lines 256 to 338
def get_next_stateless_world_group_port(self) -> list[int]:
return self._stateless_world_group_port_list.pop(0)

def get_next_stateless_dp_group_port(self) -> list[int]:
return self._stateless_dp_group_port_list.pop(0)

def get_next_stateless_ep_group_port(self) -> list[int]:
return self._stateless_ep_group_port_list.pop(0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

These methods use pop(0) to retrieve a port from a list without checking if the list is empty. If the port lists (_stateless_world_group_port_list, _stateless_dp_group_port_list, _stateless_ep_group_port_list) are exhausted for any reason, this will raise an IndexError and crash the process. While the logic in __post_init__ seems to pre-allocate the necessary ports, this design is fragile. A more robust implementation would be to check if the list is empty before popping and raise a more informative error message.

Comment on lines 98 to 118
# Check if this is a stateless process group
from torch.distributed.distributed_c10d import _world
is_stateless = _world.pg_map.get(cpu_group, None) is None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The check _world.pg_map.get(cpu_group, None) is None relies on an internal, undocumented implementation detail of torch.distributed to determine if a process group is stateless. This is a brittle approach that could break with future PyTorch updates. It would be more robust to use an explicit mechanism to identify stateless groups, such as a custom process group class that carries this information, or passing a flag during initialization.

Comment on lines 307 to 416
if op.op.__name__ == "isend":
self.send(op.tensor, op.group_peer, stream)
elif op.op.__name__ == "irecv":
self.recv(op.tensor, op.group_peer, stream)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Checking op.op.__name__ to determine the operation type is fragile. The name of a function can change, or it could be wrapped by a decorator, which would break this logic. It's more robust to check for function identity directly.

Suggested change
if op.op.__name__ == "isend":
self.send(op.tensor, op.group_peer, stream)
elif op.op.__name__ == "irecv":
self.recv(op.tensor, op.group_peer, stream)
if op.op is torch.distributed.isend:
self.send(op.tensor, op.group_peer, stream)
elif op.op is torch.distributed.irecv:
self.recv(op.tensor, op.group_peer, stream)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think Gemini's suggestion is a good one, if valid

Comment on lines +143 to +149
if ep_group not in _world.pg_map:
ep_group = get_ep_group()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The check if ep_group not in _world.pg_map: relies on an internal implementation detail of PyTorch's distributed library (_world.pg_map) to detect stateless process groups. This is not a public API and is subject to change without notice, which makes this code brittle. A more robust approach, such as using a custom process group class or an explicit flag, should be used to differentiate between stateful and stateless groups.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I generally agree with the bot - could we find a better way to detect this?

Copy link

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Copy link
Collaborator

@ruisearch42 ruisearch42 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First pass review

self.available_gpu_memory_for_kv_cache = -1

if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
self._elastic_scale_up_post_init()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This happens as part of init, rather than after init, maybe rename?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

renamed to _eep_scale_up_before_kv_init

timeout=timeout,
)

if isinstance(group_name, str):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if group_name: ?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed to if group_name is not None

Comment on lines 69 to 78
self.new_dp_group = (
self.engine_core.dp_group if worker_type == "new" else new_parallel_config
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is new_parallel_config assigned to self.new_dp_group?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed to self.new_dp_group_or_config. ParallelConfig is passed in only for existing worker when standby group is to be created.

notification_type == "NEW_WORKERS_INIT_READY"
and self.state == ScaleUpExistingWorkerState.WAIT_NEW_WORKERS_INIT
):
self.waiting_for_notification = False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need this? Can we make it a property of self.state to simplify the logic here?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed self.waiting_for_notification.

TRANSFER_EXPERT_MAPPING = 2
WAIT_NEW_WORKERS_WEIGHTS_INIT = 3
TRANSFER_WEIGHTS = 4
SYNC_KV_CACHE_MEMORY = 5
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:SYNC_KV_CACHE_SIZE?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed to SYNC_KV_CACHE_MEMORY_SIZE.

@mergify
Copy link

mergify bot commented Oct 14, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @libertyeagle.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@libertyeagle
Copy link
Author

Updated to fix scale-down bugs and synchronization issues when serving requests during scaling. cc: @ruisearch42

@mergify mergify bot removed the needs-rebase label Nov 24, 2025
@libertyeagle libertyeagle force-pushed the eep-m2 branch 2 times, most recently from 949997f to 651ebdb Compare November 24, 2025 23:25
# in ci, usually when we test custom ops/modules directly,
# we don't set the vllm config. In that case, we set a default
# config.
raise RuntimeError("Current vLLM config is not set.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this debug cruft? Either delete this line or the following two.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Deleted.

Comment on lines 596 to 588
# Initialize stateless group ports for elastic EP
if self.enable_elastic_ep:
if not self.enable_eplb:
raise ValueError("Elastic EP is only supported with enable_eplb=True.")
num_world_groups = 1
num_dp_groups = max(1, self.world_size_across_dp // self.data_parallel_size)
num_ep_groups = max(
1,
self.world_size_across_dp
// (self.data_parallel_size * self.tensor_parallel_size),
)

total_ports_needed = (num_world_groups + num_dp_groups + num_ep_groups) * 3

if not self._stateless_world_group_port_list:
all_ports = get_open_ports_list(total_ports_needed + 5)
# NOTE(yongji): allocate 5 ports for _data_parallel_master_port_list
# as in the case when elastic EP is not enabled
# (the regular DP code path below this if).
# We must set _data_parallel_master_port_list here instead of
# letting the regular DP code path to set it, since
# we should call get_open_ports_list() only once
# to ensure the allocated ports are distinct.
self._data_parallel_master_port_list = all_ports[-5:]
all_ports = all_ports[:-5]
self._stateless_world_group_port_list = [
all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3)
]
start_idx = num_world_groups * 3
self._stateless_dp_group_port_list = [
all_ports[i : i + 3]
for i in range(start_idx, start_idx + num_dp_groups * 3, 3)
]
start_idx += num_dp_groups * 3
self._stateless_ep_group_port_list = [
all_ports[i : i + 3]
for i in range(start_idx, start_idx + num_ep_groups * 3, 3)
]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I said this on an earlier version of this PR but I think this code could explained better, and hopefully simplified.

Please add a comment enumerating

  • What the 5 ports at the end of self._data_parallel_master_port_list are used for.
  • What the 3 ports for each world, dp, and ep group are used for.

Instead of this

            num_dp_groups = max(1, self.world_size_across_dp // self.data_parallel_size)
            num_ep_groups = max(
                1,
                self.world_size_across_dp
                // (self.data_parallel_size * self.tensor_parallel_size),
            )

Could we do the following?

            num_dp_groups = get_dp_group().world_size
            num_ep_groups = get_ep_group().world_size

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added the comments for the meaning of the ports in docstring. Do you think this is enough?

    _stateless_dp_group_port_list: list[list[int]] = Field(default_factory=list)
    """List of open ports for stateless DP groups when enable_elastic_ep is True.
    Set to be private as it's not intended to be configured by users.
    It is a list of list[int], with each inner list contains a set of 3 ports
    to be used for setting up the stateless CPU/device/TCPStore groups
    in StatelessGroupCoordinator. The number of inner lists is equal to
    the number of DP groups, 
    i.e., len(self._stateless_dp_group_port_list) == world_size_across_dp // dp_size,
    and len(self._stateless_dp_group_port_list[i]) == 3 for all i.
    """

    _stateless_ep_group_port_list: list[list[int]] = Field(default_factory=list)
    """List of open ports for stateless EP groups when enable_elastic_ep is True.
    Set to be private as it's not intended to be configured by users.
    len(self._stateless_ep_group_port_list) == world_size_across_dp // ep_size,
    """

    _stateless_world_group_port_list: list[list[int]] = Field(default_factory=list)
    """List of open ports for stateless world group when enable_elastic_ep is True.
    Set to be private as it's not intended to be configured by users.
    len(self._stateless_world_group_port_list) == 1,
    """

We cannot use get_dp_group().world_size here because DP group is not created at this point. These ports are allocated for creating DP/EP groups.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We cannot use get_dp_group().world_size here because DP group is not created at this point. These ports are allocated for creating DP/EP groups.

OK - in that case, could we at seat name them ep_group_world_size and dp_group_world_size? I think that's much more clear.

On the docstring comments on the port lists: that's a fine place to document them, but it doesn't explain what the three ports per member of the dp_group are used for (why three?).
Also it seems to be inconsistent with this line

total_ports_needed = (num_world_groups + num_dp_groups + num_ep_groups) * 3

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I agree it would be good to also point out here what the 3 ports are used for (i.e., why *3).

# NOTE(yongji):
# we need 3 ports for each comm group in `StatelessGroupCoordinator`.
# one for stateless CPU group, one for stateless device group,
# one for stateless TCPStore group.
total_ports_needed = (num_world_groups + num_dp_groups + num_ep_groups) * 3
.
I explained it in the docstring a bit that the ports are used respectively for CPU/device/TCPStore group but it may not be quite clear in the docstring.
"""List of open ports for stateless DP groups when enable_elastic_ep is True.
Set to be private as it's not intended to be configured by users.
It is a list of list[int], with each inner list contains a set of 3 ports
to be used for setting up the stateless CPU/device/TCPStore groups
in StatelessGroupCoordinator. The number of inner lists is equal to
the number of DP groups,
i.e., len(self._stateless_dp_group_port_list) == world_size_across_dp // dp_size,
and len(self._stateless_dp_group_port_list[i]) == 3 for all i.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As for the last 5 ports for DP master ports,

# NOTE(yongji): allocate 5 ports for _data_parallel_master_port_list
# as in the case when elastic EP is not enabled
# (the regular DP code path below this if: `get_open_ports_list(5)`).
# We must set _data_parallel_master_port_list here instead of
# letting the regular DP code path to set it, since
# we should call get_open_ports_list() only once
# to ensure the allocated ports are distinct.
self._data_parallel_master_port_list = all_ports[-5:]

It is needed as in the regular code path without elastic EP.

if not self._data_parallel_master_port_list:
self._data_parallel_master_port_list = get_open_ports_list(5)
self.data_parallel_master_port = self._data_parallel_master_port_list.pop()

Comment on lines 307 to 416
if op.op.__name__ == "isend":
self.send(op.tensor, op.group_peer, stream)
elif op.op.__name__ == "irecv":
self.recv(op.tensor, op.group_peer, stream)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think Gemini's suggestion is a good one, if valid

@mergify
Copy link

mergify bot commented Nov 25, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @libertyeagle.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 25, 2025
Copy link
Collaborator

@ruisearch42 ruisearch42 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the improvements in the new iterations!


@dataclass
class ElasticScalingCache:
existing_workers: list[EngineIdentity]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

existing_core_engines to be precise

@dataclass
class ElasticScalingCache:
existing_workers: list[EngineIdentity]
num_new_workers: int
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

num_new_core_engines

class ElasticScalingCache:
existing_workers: list[EngineIdentity]
num_new_workers: int
pending_notifications: dict[str, set[int]]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comment on what's the key and value.

This is also not "pending" notifications, but received notifications, right?

self.vllm_config.parallel_config.data_parallel_master_port = get_open_port()
self.eep_scaling_cache = ElasticScalingCache(
existing_workers=self.core_engines.copy(),
num_new_workers=new_data_parallel_size - cur_data_parallel_size,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So num_new_workers can be negative, let's add a comment.
Or maybe we should call it num_core_engines_delta



@dataclass
class ElasticScalingCache:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would PendingElasticScaling be a better name?

world_size = parallel_config.world_size
new_world_size_across_dp = world_size * new_data_parallel_size
num_world_groups = 1
num_dp_groups = max(1, new_world_size_across_dp // new_data_parallel_size)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, is this just max(1, world_size)?

Comment on lines +234 to +236
in StatelessGroupCoordinator. The number of inner lists is equal to
the number of DP groups,
i.e., len(self._stateless_dp_group_port_list) == world_size_across_dp // dp_size,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The number of DP groups should be dp_size?

Comment on lines +554 to +564
num_world_groups = 1
dp_size = self.data_parallel_size
ep_size = self.data_parallel_size * self.world_size_across_dp
num_dp_groups = max(1, self.world_size_across_dp // dp_size)
num_ep_groups = max(1, self.world_size_across_dp // ep_size)

# NOTE(yongji):
# we need 3 ports for each comm group in `StatelessGroupCoordinator`.
# one for stateless CPU group, one for stateless device group,
# one for stateless TCPStore group.
total_ports_needed = (num_world_groups + num_dp_groups + num_ep_groups) * 3
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part has duplicate logic with allocate_stateless_group_ports()? Is it possible to extract the common functionality?

assert isinstance(self.resources.engine_manager, CoreEngineActorManager)
self.resources.engine_manager.scale_up_elastic_ep(
self.vllm_config, new_data_parallel_size
parallel_config.eplb_config.num_redundant_experts = 0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a comment why

support request serving during scaling up/down

Signed-off-by: Yongji Wu <[email protected]>

misc fixes

Signed-off-by: Yongji Wu <[email protected]>

minor fix

Signed-off-by: Yongji Wu <[email protected]>

minor fix

Signed-off-by: Yongji Wu <[email protected]>

scaling test: 2->4->2

Signed-off-by: Yongji Wu <[email protected]>

tiny fix

Signed-off-by: Yongji Wu <[email protected]>

rebase fix

Signed-off-by: Yongji Wu <[email protected]>

rebase fix

Signed-off-by: Yongji Wu <[email protected]>

rebase fix

Signed-off-by: Yongji Wu <[email protected]>

rebase fix

Signed-off-by: Yongji Wu <[email protected]>

rebase fix

Signed-off-by: Yongji Wu <[email protected]>

small fix

Signed-off-by: Yongji Wu <[email protected]>

small fix

Signed-off-by: Yongji Wu <[email protected]>

small fix

Signed-off-by: Yongji Wu <[email protected]>

rebase fix

Signed-off-by: Yongji Wu <[email protected]>
@xeonliu
Copy link

xeonliu commented Nov 28, 2025

@libertyeagle
Execuse me, I wanna ask a question.
Is it true that during elastic scaling scale-up transitions, when the expert mapping is updated to assign experts to new (not-yet-initialized) ranks, there is a window where:

  • Existing engines cannot access experts assigned exclusively to the new ranks
  • Requests requiring these experts cannot be properly served
  • The system does not handle this "missing experts" scenario gracefully

I read the code and I see during scale up
Phase 1 will do EPLB
Phase 2 the new enginecore is created

I wonder what will happen when eplb mapping is changed but new engines are not yet started.

Thanks for any clearification!

@libertyeagle
Copy link
Author

@libertyeagle Execuse me, I wanna ask a question. Is it true that during elastic scaling scale-up transitions, when the expert mapping is updated to assign experts to new (not-yet-initialized) ranks, there is a window where:

  • Existing engines cannot access experts assigned exclusively to the new ranks
  • Requests requiring these experts cannot be properly served
  • The system does not handle this "missing experts" scenario gracefully

I read the code and I see during scale up Phase 1 will do EPLB Phase 2 the new enginecore is created

I wonder what will happen when eplb mapping is changed but new engines are not yet started.

Thanks for any clearification!

That wouldn't be an issue. When a new set of $M$ engines are created, they initially have no expert weights, and the EP dispatch will dispatch tokens to the original set of $N$ engines. Only after the experts are reshuffled with another EPLB, tokens are then dispatched to the new set of engines.

@mergify
Copy link

mergify bot commented Dec 1, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @libertyeagle.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 1, 2025
Copy link
Member

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In our conversation last Wednesday, we realized that this ElasticEP is incompatible with
--data-parallel-hybrid-lb and --data-parallel-external-lb. (This is because we are relying on the single API server and core client to coordinate scale_up/scale_down)

Could you raise a NotImplementedError in arg_utils.py when this happens?

Comment on lines +65 to +80
# timeout is 20 minutes
with RemoteOpenAIServer(
MODEL_NAME, vllm_serve_args, env_dict={}, max_wait_seconds=1200
) as server:
client = server.get_client()
_test_completion(client, MODEL_NAME, prompt, token_ids)

# Scale up from 2->4
assert _send_scale_command(server, 4)
time.sleep(10)
_test_completion(client, MODEL_NAME, prompt, token_ids)

# Scale down from 4->2
assert _send_scale_command(server, 2)
time.sleep(5)
_test_completion(client, MODEL_NAME, prompt, token_ids)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of using _test_completion, it would be better to use evaluate_gsm8k, which will test correctness (without flakiness)

def evaluate_gsm8k(
num_questions: int = 1319,
num_shots: int = 5,
max_tokens: int = 256,
host: str = "http://127.0.0.1",
port: int = 8000,
temperature: float = 0.0,
seed: int | None = 42,
) -> dict[str, float | int]:

Copy link
Member

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left a round of comments - feel good landing this once these + @ruisearch42's are addressed

Comment on lines +143 to +149
if ep_group not in _world.pg_map:
ep_group = get_ep_group()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I generally agree with the bot - could we find a better way to detect this?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW has this been tested on ROCm?

poller.register(socket, zmq.POLLIN)
poller.register(first_req_rcv_socket, zmq.POLLIN)

nonlocal count_slice
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of maintaining count_slice as mutable closure state with nonlocal, consider computing it on the fly from self.engine_ranks_managed. Then we have engine_ranks_managed as a single source of truth, and the slice should be cheap so no performance concerns

Comment on lines +1407 to +1417
if enable_elastic_ep:
tp_pp_pcp_size = (
tensor_model_parallel_size
* pipeline_model_parallel_size
* prefill_context_model_parallel_size
)
local_all_ranks = torch.arange(tp_pp_pcp_size).reshape(
pipeline_model_parallel_size,
prefill_context_model_parallel_size,
tensor_model_parallel_size,
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to handle prefill_context_parallel here? And what about decode context parallel?

Copy link
Collaborator

@ruisearch42 ruisearch42 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for the PR! LGTM after addressing the outstanding comments.

WorkerType = Literal["existing", "new", "removing"]


class ScaleUpExistingEningeState(enum.IntEnum):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo: Engine

torch.distributed.all_reduce(
tensor,
op=torch.distributed.ReduceOp.MAX,
group=self.new_dp_group_or_config,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we really need new_dp_group_or_config that could represent either? feel it is less type safe. Is it too inconvenient to use two variables?

sched_yield()

def _staged_barrier(self, use_new_group: bool) -> bool:
# NOTE(yongji): currently we use a two-staged
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

deserves more explanation here why two stages

outputs.utility_output.call_id == -1
and notification_callback_handler is not None
):
# NOTE(yongji): call_id -1 in utility_output is
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's use a constant? This magic number could be used in multiple places and we don't have the comment everywhere

Comment on lines +1347 to +1348
dummy_output = UtilityOutput(call_id=-1, result=UtilityResult(None))
_process_utility_output(dummy_output, self.utility_results)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment why we need to process the dummy_output?

self.data_parallel_rank,
self.data_parallel_size,
backend=current_platform.dist_backend,
backend="gloo",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this does not look safe? not all platforms use "gloo"

):
cache = self.eep_scaling_cache
notification_type, dp_rank = notification_data
if notification_type == "RECONFIGURE_FINISHED":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we create an enum for notification type and have more checks/assertions for unexpected types? Otherwise it may be hard to debug errors

self.reqs_in_flight.pop(req_id, None)

@staticmethod
async def eep_process_worker_notification(
Copy link
Collaborator

@ruisearch42 ruisearch42 Dec 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Along the same line as some of the earlier comments: we should call it core_engine as opposed to workers, which may be confused with workers of an executor.

new_data_parallel_size,
)
return
logger.info(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

keep the draining option in case it is preferred as opposed to increased/unbound TPOT?

@github-project-automation github-project-automation bot moved this to In review in NVIDIA Dec 2, 2025
Copy link
Member

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should return a 503 Service Unavailable when capturing CUDA graphs during scaling since that will pause execution.

See:

class ScalingMiddleware:
"""
Middleware that checks if the model is currently scaling and
returns a 503 Service Unavailable response if it is.
This middleware applies to all HTTP requests and prevents
processing when the model is in a scaling state.
"""
def __init__(self, app: ASGIApp) -> None:
self.app = app
def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
if scope["type"] != "http":
return self.app(scope, receive, send)
# Check global scaling state
global _scaling_elastic_ep
if _scaling_elastic_ep:
# Return 503 Service Unavailable response
response = JSONResponse(
content={
"error": "The model is currently scaling. Please try again later."
},
status_code=503,
)
return response(scope, receive, send)
return self.app(scope, receive, send)

Copy link
Member

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens to currently-running requests on DP ranks that are removed during scale-down?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

Status: In review

Development

Successfully merging this pull request may close these issues.

5 participants