
Add fp4 quantize before all-gather for Flashinfer cutlass MoE DP (max throughput)#7667

Merged
ch-wan merged 1 commit into sgl-project:main from trevor-m:moe-comm
Aug 16, 2025
Conversation

@trevor-m
Collaborator

@trevor-m trevor-m commented Jun 30, 2025

Motivation

The goal of this PR is to optimize communications for DP with FlashInfer Cutlass MoE.

Modifications

Improvements include:

  1. Add an Allgatherv collective: a pynccl implementation of TRT-LLM's allgather that supports varying sizes per rank and accepts a list of tensors as input.
  2. Add a Reducescatterv collective: a pynccl implementation of TRT-LLM's reducescatter that supports varying sizes per rank.
  3. For Flashinfer MoE with DP, use allgatherv to dispatch tokens. We also move the fp4 quantize before the allgather so the communication volume is smaller. Finally, we use reducescatterv to combine the results instead of all_reduce.
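As a rough illustration of the semantics (a single-process simulation, not the actual pynccl implementation), allgatherv concatenates per-rank contributions of differing lengths, while reducescatterv sums across ranks and hands each rank only its own slice:

```python
# Illustrative simulation of the two collectives' semantics on plain lists.
# The names mirror the PR; the real implementation runs over NCCL.

def allgatherv(per_rank_inputs):
    """Each rank contributes a chunk of varying length; every rank
    receives the concatenation of all chunks in rank order."""
    gathered = [x for chunk in per_rank_inputs for x in chunk]
    return [gathered[:] for _ in per_rank_inputs]  # one copy per rank

def reducescatterv(per_rank_partials, sizes):
    """Element-wise sum of full-length partial results across ranks,
    then scatter: rank r keeps sizes[r] elements starting at sum(sizes[:r])."""
    total = [sum(vals) for vals in zip(*per_rank_partials)]
    out, offset = [], 0
    for size in sizes:
        out.append(total[offset:offset + size])
        offset += size
    return out

# Two ranks with unequal token counts (sizes = [2, 3]).
gathered = allgatherv([[1, 2], [3, 4, 5]])
# every rank sees [1, 2, 3, 4, 5]

# For the combine step, every rank holds a full-length partial result.
scattered = reducescatterv([[1, 1, 1, 1, 1], [10, 10, 10, 10, 10]], [2, 3])
# rank 0 keeps [11, 11]; rank 1 keeps [11, 11, 11]
```

This is the dispatch/combine pair the PR uses in place of the previous pad+allreduce path.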

Usage

Enabled automatically when all of the following hold: --enable-flashinfer-cutlass-moe, --enable-dp-attention, and dp_size == ep_size.
Can be disabled with --disable-flashinfer-cutlass-moe-fp4-allgather.

python3 -m sglang.launch_server --model-path nvidia/DeepSeek-R1-0528-FP4 --trust-remote-code --quantization modelopt_fp4 --tp 8 --enable-flashinfer-cutlass-moe --ep-size 8 --dp 8 --enable-dp-attention

Results

Accuracy

python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1319 --parallel 1319 --port=30000
Accuracy: 0.958
Invalid: 0.000
Latency: 23.484 s
Output throughput: 6228.185 token/s

Benchmark

End to end speedup: 9.38%

python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 1024 --random-input 1024 --random-output 1024 --random-range-ratio 1 --max-concurrency 1024

BEFORE

============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 1024
Successful requests:                     1024
Benchmark duration (s):                  75.54
Total input tokens:                      1048576
Total generated tokens:                  1048576
Total generated tokens (retokenized):    1045749
Request throughput (req/s):              13.56
Input token throughput (tok/s):          13881.54
Output token throughput (tok/s):         13881.54
Total token throughput (tok/s):          27763.09
Concurrency:                             1021.04
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   75318.71
Median E2E Latency (ms):                 75293.46
---------------Time to First Token----------------
Mean TTFT (ms):                          11651.64
Median TTFT (ms):                        11573.21
P99 TTFT (ms):                           20818.61
---------------Inter-Token Latency----------------
Mean ITL (ms):                           62.24
Median ITL (ms):                         53.41
P95 ITL (ms):                            66.44
P99 ITL (ms):                            72.01
Max ITL (ms):                            18181.14
==================================================

AFTER

============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 1024
Successful requests:                     1024
Benchmark duration (s):                  69.06
Total input tokens:                      1048576
Total generated tokens:                  1048576
Total generated tokens (retokenized):    1045519
Request throughput (req/s):              14.83
Input token throughput (tok/s):          15183.59
Output token throughput (tok/s):         15183.59
Total token throughput (tok/s):          30367.17
Concurrency:                             1021.50
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   68891.28
Median E2E Latency (ms):                 68904.07
---------------Time to First Token----------------
Mean TTFT (ms):                          10103.61
Median TTFT (ms):                        10065.93
P99 TTFT (ms):                           17769.58
---------------Inter-Token Latency----------------
Mean ITL (ms):                           57.47
Median ITL (ms):                         49.91
P95 ITL (ms):                            61.59
P99 ITL (ms):                            68.18
Max ITL (ms):                            15189.56
==================================================

Checklist

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @trevor-m, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on optimizing communication patterns for Data Parallel (DP) operations within FlashInfer Mixture-of-Experts (MoE) models. My primary goal is to enhance efficiency by introducing and integrating new variable-sized collective communication primitives, allgatherv and reducescatterv, into the MoE forward pass. A key aspect of this work involves quantizing data prior to communication to minimize data transfer overhead.

Highlights

  • New NCCL Collectives: I've introduced allgatherv and reducescatterv to the PyNCCL communicator. These new primitives are designed to efficiently handle collective communication for tensors where the size of data contributed by each rank can vary, mimicking TRT-LLM's approach.
  • Optimized MoE Communication for DP: For FlashInfer Mixture-of-Experts (MoE) in a Data Parallel (DP) setup, I've refactored the communication path. Instead of the previous pad+allreduce for token dispatch, we now utilize the new allgatherv. Similarly, reducescatterv is used to combine results, replacing the prior all_reduce operation.
  • Quantization Before Communication: A significant optimization implemented is moving the FP4 quantization step to occur before the allgatherv communication. This reduces the size of the data being transferred across the network, aiming for improved communication efficiency.
  • Dynamic MoE Layer Handling: I've added logic to allow the MoE layer itself to manage dispatch and combine operations under specific configurations (e.g., when FlashInfer MoE is enabled with Data Parallelism), bypassing the general LayerCommunicator for more specialized handling.
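To see why quantizing before the gather shrinks traffic, here is a back-of-the-envelope byte count. It assumes an NVFP4-style layout of 4-bit values with one 8-bit scale per 16-element block; the exact on-wire layout in the PR may differ, so treat this as illustrative:

```python
# Rough per-token communication volume for hidden_size elements, comparing
# all-gathering bf16 activations vs. quantizing to fp4 first.
# Assumption: fp4 packs 4 bits per value plus one fp8 scale per
# 16-element block (NVFP4-style; illustrative only).

def bf16_bytes(hidden_size):
    return hidden_size * 2            # 16 bits per element

def fp4_bytes(hidden_size, block_size=16):
    values = hidden_size // 2         # 4 bits per element, packed two per byte
    scales = hidden_size // block_size  # one 8-bit scale per block
    return values + scales

h = 7168  # e.g. DeepSeek-R1's hidden size
print(bf16_bytes(h), fp4_bytes(h), bf16_bytes(h) / fp4_bytes(h))
# roughly a 3.5x reduction in bytes moved over the allgather
```

Under these assumptions the allgather moves about 3.5x fewer bytes per token, which is where the end-to-end gain comes from.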

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces allgatherv and reducescatterv collectives to optimize MoE communication with data parallelism, and moves quantization before communication to reduce data volume. The changes are well-structured, adding the new communication primitives and integrating them into the MoE layers. However, there's a potential bug in the quantization logic, an incorrect return type hint, and missing tests for the new functionalities.

@Alcanderian
Collaborator

Is the varying-sizes-per-rank allgather compatible with CUDA graphs?

@trevor-m
Collaborator Author

trevor-m commented Jul 1, 2025

Is the varying-sizes-per-rank allgather compatible with CUDA graphs?

@Alcanderian Yes, it works with CUDA graphs. Although I am not super familiar with CUDA graphs in torch, is it possible that certain operations are excluded from the graph if they aren't compatible?

@trevor-m trevor-m force-pushed the moe-comm branch 2 times, most recently from 3252d2f to d0e835f Compare July 3, 2025 22:14
@trevor-m trevor-m requested a review from xiezhq-hermann as a code owner July 3, 2025 23:21
@trevor-m trevor-m changed the title Draft: Allgatherv+reducescatterv for Flashinfer MoE DP, quantize before comms Add --enable-flashinfer-fp4-allgather for Flashinfer MoE DP Jul 3, 2025
@Alcanderian
Collaborator

Hi, please resolve conflicts and provide some performance report. Thanks!

@trevor-m
Collaborator Author

trevor-m commented Jul 7, 2025

Hi, please resolve conflicts and provide some performance report. Thanks!

Hi @Alcanderian I updated the PR description with performance results - are there any more benchmarks I should run? Thanks!

@Alcanderian
Collaborator

Hi, please resolve conflicts and provide some performance report. Thanks!

Hi @Alcanderian I updated the PR description with performance results - are there any more benchmarks I should run? Thanks!

Thanks! IMO we should disable it in the decode stage according to the benchmark results.

Contributor


What are the needed fields? Passing the whole forward_batch makes this function kind of opaque.
We want the function to be more explicit about its arguments.

Collaborator Author


Hi @merrymercy thank you for reviewing. In modelopt_quant.py, we use forward_batch.dp_padding_mode.is_max_len(), forward_batch.input_ids.shape[0], forward_batch.gathered_buffer, forward_batch.global_num_tokens_cpu.

Other backends like deepep, which also integrate communication, pass the whole forward_batch to the MoE layer as well:

forward_batch=forward_batch,

@trevor-m trevor-m force-pushed the moe-comm branch 2 times, most recently from 6a9b7ae to bd12687 Compare August 6, 2025 23:39
@trevor-m trevor-m changed the title Add --enable-flashinfer-fp4-allgather for Flashinfer cutlass MoE DP (max throughput) Add fp4 quantize before all-gather for Flashinfer cutlass MoE DP (max throughput) Aug 7, 2025
@kushanam
Collaborator

kushanam commented Aug 8, 2025

@ch-wan could you help review/merge this PR plz?

@trevor-m trevor-m requested a review from Edwardf0t1 as a code owner August 11, 2025 19:59
@trevor-m trevor-m force-pushed the moe-comm branch 2 times, most recently from fdaeec7 to e7980df Compare August 12, 2025 00:06
Collaborator


The code logic is getting increasingly messy. It seems that forward_batch is only needed for DP communication. We need to refactor dp_attention.py so that any part of the code can handle DP communication without using forward_batch. Let me try this later.

Collaborator Author


Thanks @ch-wan - anything I can do for this PR to help?

Collaborator


I will work on this: #9136

Contributor


we should explicitly list all the required arguments instead of passing a big forward_batch.
Passing a big forward_batch makes the input of this function very opaque.
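A toy contrast makes the reviewer's point concrete (all names here are illustrative stand-ins, not the actual sglang signatures):

```python
# Contrast between an opaque "bag of state" parameter and explicit
# arguments. ForwardBatch here is a tiny stand-in for the real,
# much larger object.

class ForwardBatch:
    def __init__(self, num_tokens, global_num_tokens):
        self.num_tokens = num_tokens
        self.global_num_tokens = global_num_tokens

def dispatch_opaque(batch):
    # Callers must read the body to discover which fields are used.
    return sum(batch.global_num_tokens) - batch.num_tokens

def dispatch_explicit(num_tokens, global_num_tokens):
    # The dependency surface is visible in the signature itself.
    return sum(global_num_tokens) - num_tokens

fb = ForwardBatch(4, [4, 6, 2])
assert dispatch_opaque(fb) == dispatch_explicit(4, [4, 6, 2]) == 8
```

The explicit form also makes it easier to call the function from paths that never build a full ForwardBatch.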

Collaborator Author


Thanks, done.

@ch-wan
Collaborator

ch-wan commented Aug 14, 2025

@trevor-m DP attn refactor is almost ready. I am doing extra local tests. Could you update this PR by using the util function developed in #9136? This should make the code much cleaner. Thanks.

UPD: I've just noticed that you did some refactor in this PR :) How about adding get_dp_global_buffer_len to DP utils? This can save args in your code.

@trevor-m
Collaborator Author

@trevor-m DP attn refactor is almost ready. I am doing extra local tests. Could you update this PR by using the util function developed in #9136? This should make the code much cleaner. Thanks.

UPD: I've just noticed that you did some refactor in this PR :) How about adding get_dp_global_buffer_len to DP utils? This can save args in your code.


Thanks @ch-wan, I added a helper so I don't need to pass forward_batch.global_num_tokens_cpu anymore. Please take a look.

Collaborator

@ch-wan ch-wan left a comment


LGTM. I only have some minor comments. Could you fix conflicts? Thanks.

Flashinfer MoE FP4 communication optimization for DP

Fix quantize before comm and only use for dp, use numel(), fix return type annotation

Use enable_flashinfer_fp4_allgather to toggle

lint

Fix forwardbatch

Automatically enable fp4 allgather

Enable automatically and improve server arg descriptions

Switch from server arg to should_use_flashinfer_cutlass_moe_fp4_allgather(). Add server arg to disable

formatting

Remove forward_batch arg

Use helper function for empty topk

Add get_dp_global_num_tokens() helper
@ch-wan ch-wan merged commit eff4eb3 into sgl-project:main Aug 16, 2025
120 of 132 checks passed
narutolhy pushed a commit to narutolhy/sglang that referenced this pull request Aug 17, 2025
MahmoudAshraf97 pushed a commit to MahmoudAshraf97/sglang that referenced this pull request Sep 8, 2025
self.nccl.ncclGroupStart()
for root, split_size in enumerate(sizes):
    dst_slice = output_tensor[split_offset : split_offset + split_size]
    self.nccl.ncclBroadcast(
Contributor


Is this broadcast or allgather?

Collaborator Author


It's equivalent to an allgather. Each rank does a broadcast, but we group them to avoid launch overheads. This is done to allow each rank to contribute a different size.
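A minimal single-process sketch of that equivalence (the function name is hypothetical; the real code issues ncclBroadcast calls between ncclGroupStart/ncclGroupEnd):

```python
# Simulates the grouped-broadcast trick: each rank broadcasts its own
# variably-sized chunk into the correct slice of a shared output buffer.
# Grouping the per-rank broadcasts fuses them into one collective,
# which is exactly an allgather with per-rank sizes (allgatherv).

def allgatherv_via_broadcasts(per_rank_chunks):
    sizes = [len(c) for c in per_rank_chunks]
    output = [None] * sum(sizes)
    offset = 0
    # In NCCL this loop sits between ncclGroupStart()/ncclGroupEnd(),
    # with rank `root` acting as the source of each broadcast.
    for root, size in enumerate(sizes):
        output[offset:offset + size] = per_rank_chunks[root]
        offset += size
    return output  # every rank ends up with the same concatenation

print(allgatherv_via_broadcasts([[1], [2, 3], [4, 5, 6]]))  # [1, 2, 3, 4, 5, 6]
```

A fixed-size ncclAllGather could not express the unequal chunk lengths, hence the grouped broadcasts.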
