Skip to content

Conversation

@david6666666
Copy link
Contributor

@david6666666 david6666666 commented Aug 4, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

Fix the bug of #21883.

When using naive dispatch, there will be no problem with the statistics of expert workload. However, when using the pplx-kernel or DeepEP, which follow a different code path, the dispatch might not have been executed yet at that point. The load of experts routed to other nodes is not counted.

Previously, the logic for calculating expert load was to only count the load of activated experts by the current rank and aggregate the global load in eplc_state.py. The logic for modifying expert statistics in this PR is to directly count the load of all physical experts activated by each token.

After this modification, the time point of dispatch is no longer important, and load statistics can work normally when activating the pplx-kernel or DeepEP. When using naive dispatch, due to the modification of logic, each token's activated expert will be counted multiple times. However, since the load between experts is still proportional, this modification will not affect the results of the EPLB algorithm.

Test Plan

Output the sum of expert load in the function rearrange() in eplb_state.py after expert load calculation:

if not is_profile:
    print(f"The sum of global_expert_load is: {global_expert_load_window.sum().item()}")

Test with/wo pplx:

VLLM_ALL2ALL_BACKEND=pplx CUDA_VISIBLE_DEVICES=0,1 python examples/offline_inference/data_parallel.py \
    --model="/workspace/models/DeepSeek-V2-Lite" \
    --trust-remote-code \
    --dp-size=2 \
    --tp-size=1
CUDA_VISIBLE_DEVICES=0,1 python examples/offline_inference/data_parallel.py \
    --model="/workspace/models/DeepSeek-V2-Lite" \
    --trust-remote-code \
    --dp-size=2 \
    --tp-size=1

Test Result

with pplx:

The sum of global_expert_load is: 3618264

wo pplx, twice as much as using pplx when dp=2 and tp=1 (because each token's activated expert count twice) , which is expected:

The sum of global_expert_load is: 7236528

(Optional) Documentation Update

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

The pull request addresses a bug in the expert load balancing (EPLB) mechanism for MoE kernels. The change to count the load for all physical experts activated by each token aims to make the statistics independent of dispatch timing. The shape adjustments in vllm/distributed/eplb/eplb_state.py and changes in vllm/model_executor/layers/fused_moe/layer.py seem correct. However, there's a critical concern regarding the normalization of the expert load in vllm/distributed/eplb/eplb_state.py that might affect the "naive dispatch" code path.

@github-actions
Copy link

github-actions bot commented Aug 4, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@david6666666
Copy link
Contributor Author

@abmfy please review, thanks

Copy link
Member

@abmfy abmfy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall LGTM.

Could you please confirm if my understanding aligns with yours here:

  1. Under naive all-to-all dispatch settings (i.e., performing DP without using dedicated kernels that handle DP dispatch, such as DeepEP high-throughput mode), each DP rank will contribute the same token set to the expert load. As a result, the recorded expert load will be multiplied by dp_size (this also applies in TP+DP settings).
    If this is correct, could you please add it to the comments of expert_load_window as a NOTE:? Since we plan to expose expert load metrics through interfaces, this clarification will help us divide the metrics by dp_size later, ensuring that the reported figures have a clear meaning.
  2. When using communication kernels such as DeepEP or pplx-kernels, which handle dispatch and combine within the prepare and finalize methods of the modular kernel, the case should be the same as above because the expert load is collected before the modular kernel is invoked.

Also, could you please add some additional tests (manual tests are fine) to verify correctness under different settings, such as TP and DP+TP? It would also be great to confirm that under these settings, the balancedness after rearrangement remains as good as in the previous implementation.

Thanks so much for the contribution!

@CarrotShoo
Copy link
Contributor

Overall LGTM.

Could you please confirm if my understanding aligns with yours here:

  1. Under naive all-to-all dispatch settings (i.e., performing DP without using dedicated kernels that handle DP dispatch, such as DeepEP high-throughput mode), each DP rank will contribute the same token set to the expert load. As a result, the recorded expert load will be multiplied by dp_size (this also applies in TP+DP settings).
    If this is correct, could you please add it to the comments of expert_load_window as a NOTE:? Since we plan to expose expert load metrics through interfaces, this clarification will help us divide the metrics by dp_size later, ensuring that the reported figures have a clear meaning.
  2. When using communication kernels such as DeepEP or pplx-kernels, which handle dispatch and combine within the prepare and finalize methods of the modular kernel, the case should be the same as above because the expert load is collected before the modular kernel is invoked.

Also, could you please add some additional tests (manual tests are fine) to verify correctness under different settings, such as TP and DP+TP? It would also be great to confirm that under these settings, the balancedness after rearrangement remains as good as in the previous implementation.

Thanks so much for the contribution!

Thanks for the review!

Yes, we have the same understanding. I will add the test and note soon.

@CarrotShoo
Copy link
Contributor

CarrotShoo commented Aug 5, 2025

I conducted several manual test locally, the result looks correct. Just divide the metrics by dp_size later, we can get the actual experts load.

Here are brief parameters and results:
Overall parameters in data_parallel.py:

sampling_params = SamplingParams(
    temperature=0.8, top_p=0.95, max_tokens=[16, 20][global_dp_rank % 2]
)

llm = LLM(
    model=model,
    tensor_parallel_size=GPUs_per_dp_rank,
    enforce_eager=enforce_eager,
    enable_expert_parallel=True,
    enable_eplb=True,
    eplb_window_size=1000,
    eplb_step_interval=100,
    trust_remote_code=trust_remote_code,
    max_num_seqs=max_num_seqs,
    gpu_memory_utilization=gpu_memory_utilization,
)

dp=1, tp=2:

with pplx: The sum of global_expert_load is: 3607656
wo pplx: The sum of global_expert_load is: 3621696

dp=2, tp=2:

with pplx: The sum of global_expert_load is: 7458048
wo pplx: The sum of global_expert_load is: 14954784

Also output the proportion between expert loads (dp=2, tp=2):

if not is_profile:
    layer_idx = 0
    loads = global_expert_load_window[layer_idx]
    loads = loads.float()
    ratios = loads / loads[0]
    print(f"Expert load ratios (relative to expert 0) for layer {layer_idx}:")
    print([f"expert_{i}: {ratio:.3f}" for i, ratio in enumerate(ratios.tolist())])

The output of expert load ratios, which looks consistent:

previous implementation (wo pplx):
Expert load ratios (relative to expert 0) for layer 0:
['expert_0: 1.000', 'expert_1: 0.835', 'expert_2: 0.909', 'expert_3: 1.048', 'expert_4: 6.834', 'expert_5: 1.068', 'expert_6: 0.217', 'expert_7: 0.189', 'expert_8: 1.247', 'expert_9: 0.248', 'expert_10: 0.067', 'expert_11: 0.408', 'expert_12: 0.051', 'expert_13: 0.044', 'expert_14: 0.088', 'expert_15: 0.114', 'expert_16: 0.104', 'expert_17: 0.061', 'expert_18: 0.040', 'expert_19: 0.164', 'expert_20: 0.196', 'expert_21: 0.192', 'expert_22: 0.215', 'expert_23: 0.064', 'expert_24: 0.255', 'expert_25: 6.027', 'expert_26: 6.340', 'expert_27: 0.205', 'expert_28: 0.065', 'expert_29: 0.172', 'expert_30: 0.228', 'expert_31: 0.115', 'expert_32: 0.059', 'expert_33: 0.128', 'expert_34: 0.097', 'expert_35: 0.121', 'expert_36: 0.081', 'expert_37: 0.094', 'expert_38: 0.303', 'expert_39: 0.141', 'expert_40: 0.213', 'expert_41: 0.054', 'expert_42: 0.080', 'expert_43: 0.278', 'expert_44: 0.042', 'expert_45: 0.070', 'expert_46: 0.422', 'expert_47: 0.320', 'expert_48: 0.105', 'expert_49: 0.162', 'expert_50: 0.068', 'expert_51: 0.152', 'expert_52: 0.229', 'expert_53: 0.061', 'expert_54: 0.067', 'expert_55: 5.994', 'expert_56: 6.258', 'expert_57: 6.137', 'expert_58: 0.037', 'expert_59: 0.062', 'expert_60: 0.234', 'expert_61: 0.146', 'expert_62: 0.117', 'expert_63: 0.065']

with pplx:
Expert load ratios (relative to expert 0) for layer 0:
['expert_0: 1.000', 'expert_1: 0.849', 'expert_2: 0.891', 'expert_3: 1.060', 'expert_4: 6.852', 'expert_5: 1.076', 'expert_6: 0.214', 'expert_7: 0.192', 'expert_8: 1.229', 'expert_9: 0.270', 'expert_10: 0.065', 'expert_11: 0.393', 'expert_12: 0.046', 'expert_13: 0.054', 'expert_14: 0.077', 'expert_15: 0.111', 'expert_16: 0.116', 'expert_17: 0.063', 'expert_18: 0.052', 'expert_19: 0.161', 'expert_20: 0.198', 'expert_21: 0.194', 'expert_22: 0.219', 'expert_23: 0.069', 'expert_24: 0.268', 'expert_25: 6.067', 'expert_26: 6.386', 'expert_27: 0.211', 'expert_28: 0.069', 'expert_29: 0.169', 'expert_30: 0.231', 'expert_31: 0.106', 'expert_32: 0.050', 'expert_33: 0.118', 'expert_34: 0.121', 'expert_35: 0.128', 'expert_36: 0.102', 'expert_37: 0.087', 'expert_38: 0.304', 'expert_39: 0.135', 'expert_40: 0.227', 'expert_41: 0.047', 'expert_42: 0.085', 'expert_43: 0.267', 'expert_44: 0.046', 'expert_45: 0.069', 'expert_46: 0.417', 'expert_47: 0.328', 'expert_48: 0.098', 'expert_49: 0.147', 'expert_50: 0.071', 'expert_51: 0.151', 'expert_52: 0.246', 'expert_53: 0.069', 'expert_54: 0.066', 'expert_55: 6.012', 'expert_56: 6.280', 'expert_57: 6.141', 'expert_58: 0.038', 'expert_59: 0.053', 'expert_60: 0.247', 'expert_61: 0.133', 'expert_62: 0.112', 'expert_63: 0.070']

wo pplx:
Expert load ratios (relative to expert 0) for layer 0:
['expert_0: 1.000', 'expert_1: 0.834', 'expert_2: 0.906', 'expert_3: 1.054', 'expert_4: 6.849', 'expert_5: 1.074', 'expert_6: 0.213', 'expert_7: 0.184', 'expert_8: 1.242', 'expert_9: 0.255', 'expert_10: 0.064', 'expert_11: 0.403', 'expert_12: 0.055', 'expert_13: 0.044', 'expert_14: 0.095', 'expert_15: 0.113', 'expert_16: 0.106', 'expert_17: 0.063', 'expert_18: 0.046', 'expert_19: 0.167', 'expert_20: 0.195', 'expert_21: 0.191', 'expert_22: 0.220', 'expert_23: 0.069', 'expert_24: 0.261', 'expert_25: 6.044', 'expert_26: 6.367', 'expert_27: 0.206', 'expert_28: 0.059', 'expert_29: 0.168', 'expert_30: 0.233', 'expert_31: 0.117', 'expert_32: 0.056', 'expert_33: 0.127', 'expert_34: 0.095', 'expert_35: 0.124', 'expert_36: 0.084', 'expert_37: 0.094', 'expert_38: 0.308', 'expert_39: 0.137', 'expert_40: 0.219', 'expert_41: 0.055', 'expert_42: 0.083', 'expert_43: 0.277', 'expert_44: 0.042', 'expert_45: 0.069', 'expert_46: 0.428', 'expert_47: 0.323', 'expert_48: 0.099', 'expert_49: 0.167', 'expert_50: 0.070', 'expert_51: 0.154', 'expert_52: 0.226', 'expert_53: 0.061', 'expert_54: 0.072', 'expert_55: 6.004', 'expert_56: 6.272', 'expert_57: 6.153', 'expert_58: 0.043', 'expert_59: 0.060', 'expert_60: 0.232', 'expert_61: 0.145', 'expert_62: 0.114', 'expert_63: 0.066']

@CarrotShoo
Copy link
Contributor

@abmfy Notes and tests have been added, please review, thanks!

@david6666666
Copy link
Contributor Author

@DarkLight1337 @hmellor PTAL, thank you.

@mergify
Copy link

mergify bot commented Aug 6, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @david6666666.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

CarrotShoo and others added 2 commits August 6, 2025 12:55
Signed-off-by: ycyaw66 <[email protected]>
Signed-off-by: David Chen <[email protected]>
@tlrmchlsmth
Copy link
Member

Thanks for tracking this down 🙏

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) August 7, 2025 01:41
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Aug 7, 2025
@DarkLight1337 DarkLight1337 merged commit 4be02a3 into vllm-project:main Aug 7, 2025
61 checks passed
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
Signed-off-by: ycyaw66 <[email protected]>
Signed-off-by: David Chen <[email protected]>
Co-authored-by: ycyaw66 <[email protected]>
Signed-off-by: Jinzhen Lin <[email protected]>
noamgat pushed a commit to noamgat/vllm that referenced this pull request Aug 9, 2025
Signed-off-by: ycyaw66 <[email protected]>
Signed-off-by: David Chen <[email protected]>
Co-authored-by: ycyaw66 <[email protected]>
Signed-off-by: Noam Gat <[email protected]>
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
Signed-off-by: ycyaw66 <[email protected]>
Signed-off-by: David Chen <[email protected]>
Co-authored-by: ycyaw66 <[email protected]>
Signed-off-by: Paul Pak <[email protected]>
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
Signed-off-by: ycyaw66 <[email protected]>
Signed-off-by: David Chen <[email protected]>
Co-authored-by: ycyaw66 <[email protected]>
Signed-off-by: Diego-Castan <[email protected]>
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
Signed-off-by: ycyaw66 <[email protected]>
Signed-off-by: David Chen <[email protected]>
Co-authored-by: ycyaw66 <[email protected]>
xiao-llm pushed a commit to xiao-llm/vllm that referenced this pull request Aug 28, 2025
Signed-off-by: ycyaw66 <[email protected]>
Signed-off-by: David Chen <[email protected]>
Co-authored-by: ycyaw66 <[email protected]>
Signed-off-by: Xiao Yu <[email protected]>
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
Signed-off-by: ycyaw66 <[email protected]>
Signed-off-by: David Chen <[email protected]>
Co-authored-by: ycyaw66 <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci/build documentation Improvements or additions to documentation frontend multi-modality Related to multi-modality (#4194) performance Performance-related issues qwen Related to Qwen models ready ONLY add when PR is ready to merge/full CI is needed speculative-decoding structured-output tool-calling v1

Projects

Status: Done
Status: Done

Development

Successfully merging this pull request may close these issues.

6 participants