
[feat] Support tp mode for DeepSeek-R1-W4AFP8#8118

Merged
zhyncs merged 11 commits into sgl-project:main from chenxijun1029:feat/w4afp8-tp
Sep 2, 2025
Conversation

Contributor

@chenxijun1029 chenxijun1029 commented Jul 17, 2025

Motivation

Support TP mode for the DeepSeek W4A8 model, which performs better than EP mode.

Modifications

  1. Add W4AFp8MoEMethod with its associated create_weights, process_weights_after_loading, and apply functions. The apply function reuses the same cutlass_w4a8_moe kernel as EP MoE.
  2. Add tile-shape and cluster-shape configs for TP MoE in the cutlass_w4a8_moe kernel.
  3. Add routing logic in the W4AFP8 quant config and method: when "enable_ep_moe" is found in global_server_args_dict, we use EP mode; otherwise TP.
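The routing in item 3 can be sketched as follows. This is a minimal illustration; `get_moe_method` and the string return values are hypothetical names, not the actual classes in the PR (the real code selects between MoE method objects):

```python
# Minimal sketch of the EP/TP routing added in the W4AFP8 quant config.
# Function name and return values are illustrative only.
def get_moe_method(global_server_args_dict: dict) -> str:
    """Pick the MoE parallelism mode for W4AFP8 layers."""
    if global_server_args_dict.get("enable_ep_moe"):
        return "ep"  # existing expert-parallel path
    return "tp"      # new tensor-parallel path added by this PR
```

The default is TP, so existing deployments only take the EP path when the server flag is explicitly set.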

Co-authored-by: @yuhyao (827623970@qq.com)

Benchmark

We run DeepSeek-R1-W4AFP8 on 8×H20 with TP8, compared to running DeepSeek-R1 on 8×H20 with EP8.
Test configuration:

ISL1000, OSL1000

input/output len = 1000/1000, qps=128, max_concurrency=128, num_prompt=256.
The results are shown below:

TP

============ Serving Benchmark Result ============
Backend:                                 sglang    
Traffic request rate:                    128.0     
Max request concurrency:                 128       
Successful requests:                     256       
Benchmark duration (s):                  159.00    
Total input tokens:                      256000    
Total generated tokens:                  256000    
Total generated tokens (retokenized):    254696    
Request throughput (req/s):              1.61      
Input token throughput (tok/s):          1610.09   
Output token throughput (tok/s):         1610.09   
Total token throughput (tok/s):          3220.18   
Concurrency:                             127.52    
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   79201.47  
Median E2E Latency (ms):                 78956.21  
---------------Time to First Token----------------
Mean TTFT (ms):                          6547.98   
Median TTFT (ms):                        6612.11   
P99 TTFT (ms):                           11687.82  
---------------Inter-Token Latency----------------
Mean ITL (ms):                           72.73     
Median ITL (ms):                         68.05     
P95 ITL (ms):                            72.42     
P99 ITL (ms):                            73.05     
Max ITL (ms):                            11148.65  
==================================================

While EP:

============ Serving Benchmark Result ============
Backend:                                 sglang    
Traffic request rate:                    128.0     
Max request concurrency:                 128       
Successful requests:                     256       
Benchmark duration (s):                  161.37    
Total input tokens:                      256000    
Total generated tokens:                  256000    
Total generated tokens (retokenized):    255343    
Request throughput (req/s):              1.59      
Input token throughput (tok/s):          1586.41   
Output token throughput (tok/s):         1586.41   
Total token throughput (tok/s):          3172.81   
Concurrency:                             127.53    
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   80390.85  
Median E2E Latency (ms):                 80644.43  
---------------Time to First Token----------------
Mean TTFT (ms):                          8143.41   
Median TTFT (ms):                        8144.56   
P99 TTFT (ms):                           14833.05  
---------------Inter-Token Latency----------------
Mean ITL (ms):                           72.32     
Median ITL (ms):                         66.38     
P95 ITL (ms):                            71.43     
P99 ITL (ms):                            72.08     
Max ITL (ms):                            14113.65  
==================================================

ISL6000, OSL1000

input/output len = 6000/1000, qps=128, max_concurrency=128, num_prompt=256.
The results are shown below:

TP

============ Serving Benchmark Result ============
Backend:                                 sglang    
Traffic request rate:                    128.0     
Max request concurrency:                 128       
Successful requests:                     256       
Benchmark duration (s):                  525.30    
Total input tokens:                      1536000   
Total generated tokens:                  256000    
Total generated tokens (retokenized):    254578    
Request throughput (req/s):              0.49      
Input token throughput (tok/s):          2924.03   
Output token throughput (tok/s):         487.34    
Total token throughput (tok/s):          3411.37   
Concurrency:                             109.09    
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   223849.73 
Median E2E Latency (ms):                 251328.94 
---------------Time to First Token----------------
Mean TTFT (ms):                          118601.26 
Median TTFT (ms):                        149225.11 
P99 TTFT (ms):                           183862.13 
---------------Inter-Token Latency----------------
Mean ITL (ms):                           105.36    
Median ITL (ms):                         83.10     
P95 ITL (ms):                            85.79     
P99 ITL (ms):                            86.27     
Max ITL (ms):                            52034.42  
==================================================

While EP:

============ Serving Benchmark Result ============
Backend:                                 sglang    
Traffic request rate:                    128.0     
Max request concurrency:                 128       
Successful requests:                     256       
Benchmark duration (s):                  592.58    
Total input tokens:                      1536000   
Total generated tokens:                  256000    
Total generated tokens (retokenized):    255781    
Request throughput (req/s):              0.43      
Input token throughput (tok/s):          2592.07   
Output token throughput (tok/s):         432.01    
Total token throughput (tok/s):          3024.08   
Concurrency:                             108.41    
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   250942.46 
Median E2E Latency (ms):                 282080.60 
---------------Time to First Token----------------
Mean TTFT (ms):                          137965.32 
Median TTFT (ms):                        171540.14 
P99 TTFT (ms):                           217855.50 
---------------Inter-Token Latency----------------
Mean ITL (ms):                           113.09    
Median ITL (ms):                         81.36     
P95 ITL (ms):                            84.83     
P99 ITL (ms):                            85.55     
Max ITL (ms):                            71910.62  
==================================================

The results are summarized in the table below:

| Scenario | Version | ITL (ms) | TTFT (ms) | Request throughput (req/s) |
| --- | --- | --- | --- | --- |
| 1K input, 1K output, 128 request rate | EP | 72.32 | 8143.41 | 1.59 |
| 1K input, 1K output, 128 request rate | TP | 72.73 (+0.6%) | 6547.98 (−19.60%) | 1.61 (+1.2%) |
| 6K input, 1K output, 128 request rate | EP | 113.09 | 137965.32 | 0.43 |
| 6K input, 1K output, 128 request rate | TP | 105.36 (−6.83%) | 118601.26 (−14.03%) | 0.49 (+14.0%) |
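As a quick sanity check, the relative deltas in the summary can be recomputed from the raw benchmark logs above (the helper function is ours; the numbers are from the logs):

```python
def delta_pct(tp: float, ep: float) -> float:
    """Relative change of the TP result versus the EP baseline, in percent."""
    return (tp - ep) / ep * 100

# 1K input / 1K output
print(f"ITL   {delta_pct(72.73, 72.32):+.2f}%")        # ~ +0.6%
print(f"TTFT  {delta_pct(6547.98, 8143.41):+.2f}%")    # ~ -19.6%
print(f"req/s {delta_pct(1.61, 1.59):+.2f}%")          # ~ +1.3%

# 6K input / 1K output
print(f"ITL   {delta_pct(105.36, 113.09):+.2f}%")      # ~ -6.8%
print(f"TTFT  {delta_pct(118601.26, 137965.32):+.2f}%")# ~ -14.0%
print(f"req/s {delta_pct(0.49, 0.43):+.2f}%")          # ~ +14.0%
```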

Accuracy

mmlu: 86.9 (by @yuhyao)
aime24: 80.0
gpqa: 71.2
math500: 95.6

Checklist


@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @chenxijun1029, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the DeepSeek-R1-W4AFP8 model's performance by introducing support for Tensor Parallelism (TP) mode. It involves a comprehensive update to the quantization and MoE layer implementations, including new weight handling, dynamic method selection, and optimized low-level kernel configurations, ultimately aiming for more efficient model serving.

Highlights

  • Tensor Parallelism for DeepSeek-R1-W4AFP8: Introduced a new W4AFp8TPMoEMethod to enable Tensor Parallelism (TP) for DeepSeek-R1-W4AFP8 models, demonstrating improved performance compared to Expert Parallelism (EP) mode.
  • Dynamic MoE Quantization Method Selection: Implemented a routing mechanism that dynamically selects between Expert Parallelism (EP) and Tensor Parallelism (TP) quantization methods for MoE layers based on the enable_ep_moe global server argument.
  • Optimized CUTLASS Kernel Configurations: Added new tile and cluster shape configurations within the cutlass_w4a8_moe kernel and extended its dispatch logic to provide optimized performance for the specific matrix dimensions encountered in Tensor Parallelism mode.
  • Enhanced Weight Processing and Loading: Updated the weight creation and processing logic for TP MoE, including refactoring scale interleaving into a shared utility function and adding support for special naming rules for input scales in mixed-precision models.
  • Shared Experts Fusion Disablement: Explicitly disabled shared experts fusion optimization for W4A8 TP MoE models, as their quantization methods differ between routed and shared experts.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for tensor parallelism (TP) mode for the DeepSeek-R1-W4AFP8 model, which shows performance improvements over the existing expert parallelism (EP) mode. The changes include adding a new W4AFp8TPMoEMethod for quantization, updating the MoE kernel configurations, and adding logic to switch between TP and EP modes. The implementation looks solid and the benchmark results are promising. I've added a couple of comments to improve code maintainability by reducing duplication in both the Python and CUDA C++ code. These changes should not affect performance but will make the code easier to read and maintain.

Comment on lines +195 to +320
} else if (n == 512 && k == 7168) {
  // group gemm 1 for tp
  if (m <= 4) {
    using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME(64, 32, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
    cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
        d_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, problem_sizes,
        a_strides, b_strides, d_strides, s_strides, chunk_size);
  } else if (m <= 16) {
    using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 16, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
    cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
        d_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, problem_sizes,
        a_strides, b_strides, d_strides, s_strides, chunk_size);
  } else if (m <= 256) {
    using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 16, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
    cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
        d_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, problem_sizes,
        a_strides, b_strides, d_strides, s_strides, chunk_size);
  } else if (m <= 1024) {
    using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
    cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
        d_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, problem_sizes,
        a_strides, b_strides, d_strides, s_strides, chunk_size);
  } else {
    using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 64, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
    cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
        d_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, problem_sizes,
        a_strides, b_strides, d_strides, s_strides, chunk_size);
  }
} else if (n == 7168 && k == 256) {
  // group gemm 2 for tp
  if (m <= 8) {
    using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME(64, 16, 256, 1, 1, 1)::Cutlass3xW4A8Gemm;
    cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
        d_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, problem_sizes,
        a_strides, b_strides, d_strides, s_strides, chunk_size);
  } else if (m <= 512) {
    using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME(128, 32, 256, 2, 1, 1)::Cutlass3xW4A8Gemm;
    cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
        d_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, problem_sizes,
        a_strides, b_strides, d_strides, s_strides, chunk_size);
  } else {
    using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME(128, 64, 256, 1, 1, 1)::Cutlass3xW4A8Gemm;
    cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
        d_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, problem_sizes,
        a_strides, b_strides, d_strides, s_strides, chunk_size);
  }

medium

There's significant code duplication in the new else if blocks for n == 512 && k == 7168 and n == 7168 && k == 256. The only difference between the if (m <= ...) blocks is the Cutlass3xW4A8GemmSelected type.

To improve maintainability, you could refactor this using a helper template function or a macro to reduce the repeated calls to cutlass_w4a8_group_gemm_caller.

For example:

template <typename Gemm>
void invoke_cutlass_caller(...) {
    cutlass_w4a8_group_gemm_caller<Gemm>(...);
}

// In dispatch_w4a8_moe_mm_sm90:
if (m <= 4) {
    using Gemm = typename JOIN_STRUCT_NAME(64, 32, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
    invoke_cutlass_caller<Gemm>(...);
} else if (m <= 16) {
    using Gemm = typename JOIN_STRUCT_NAME_CO(128, 16, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
    invoke_cutlass_caller<Gemm>(...);
}
// ...

While I understand this pattern is common for performance-critical CUDA code to help the compiler generate specialized code, this refactoring would make the code much cleaner without a performance penalty.
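The table-driven shape of this refactor can also be sketched in Python: map each (n, k) problem shape to a list of m-thresholds and tile/cluster tuples, and pick the first threshold that covers m. The tuples below mirror the diff above, but this is an illustrative sketch, not code from the PR; the JOIN_STRUCT_NAME vs JOIN_STRUCT_NAME_CO (cooperative schedule) distinction is elided.

```python
# Threshold table: (n, k) -> list of (max_m, tile/cluster config).
# float("inf") serves as the catch-all final branch.
TILE_CONFIGS = {
    (512, 7168): [  # group gemm 1 for tp
        (4, (64, 32, 512, 2, 1, 1)),
        (16, (128, 16, 512, 2, 1, 1)),
        (256, (128, 16, 512, 2, 1, 1)),
        (1024, (128, 32, 512, 2, 1, 1)),
        (float("inf"), (128, 64, 512, 1, 1, 1)),
    ],
    (7168, 256): [  # group gemm 2 for tp
        (8, (64, 16, 256, 1, 1, 1)),
        (512, (128, 32, 256, 2, 1, 1)),
        (float("inf"), (128, 64, 256, 1, 1, 1)),
    ],
}

def select_config(m: int, n: int, k: int) -> tuple:
    """Return the first tile/cluster config whose m-threshold covers m."""
    for max_m, cfg in TILE_CONFIGS[(n, k)]:
        if m <= max_m:
            return cfg
    raise AssertionError("unreachable: the last threshold is inf")
```

In C++ the same idea needs templates or a macro because the config is a compile-time type, but the dispatch structure is identical.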

@AniZpZ
Collaborator

AniZpZ commented Jul 17, 2025

Thanks for your great work on this! To help us evaluate the impact of this PR, could you please provide the evaluation results (like GSM8K, MMLU, and HellaSwag)?

layer.register_parameter("w2_weight_scale_inv", w2_scales)
set_weight_attrs(w2_scales, extra_weight_attrs)

# The input scale for w1 and w3 should be the same

Just want to confirm: have you checked the contents of the act_scales.safetensors file? Are the input scales for w1 and w3 all consistent?

@chenxijun1029 chenxijun1029 requested a review from kushanam as a code owner July 31, 2025 02:41
@yuhyao
Contributor

yuhyao commented Aug 5, 2025

@chenxijun1029 Nice work! Just wondering if there will be any further updates?
Also, should the file sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py be updated as well?

@zhilingjiang

Nice work!

Co-authored-by: yuhyao <827623970@qq.com>
@chenxijun1029
Contributor Author

Thanks for your great work on this! To help us evaluate the impact of this PR, could you please provide the performance results (like GSM8K, MMLU, and Hellaswag)

Hi, I have updated the accuracy results in the PR, please check.

@AniZpZ
Collaborator

AniZpZ commented Aug 20, 2025

LGTM. Do you have any more suggestions? @yangsijia-serena

@yangsijia-serena
Collaborator

LGTM. Do you have any more suggestions? @yangsijia-serena

LGTM!

@yuhyao yuhyao mentioned this pull request Aug 20, 2025
@zhyncs zhyncs merged commit d4a9384 into sgl-project:main Sep 2, 2025
105 of 113 checks passed
MahmoudAshraf97 pushed a commit to MahmoudAshraf97/sglang that referenced this pull request Sep 8, 2025
"compressed" in self.quant_method.__class__.__name__.lower()
and param.data[expert_id] != 1
and (param.data[expert_id] - loaded_weight).abs() > 1e-5
or "w4afp8" in self.quant_config.get_name()

It broke compressed-tensors format MoE models like neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8; fixed in #10299.
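As background on how a mixed `and`/`or` condition like the one quoted above can misfire: in Python, `and` binds tighter than `or`, so appending `or <new check>` to an existing `and` chain makes the new check fire even when the original guards are false. A made-up illustration of the pitfall (all names are hypothetical; this is not a claim about the exact fix in #10299):

```python
# `A and B or C` parses as `(A and B) or C`, so the appended `or` clause
# triggers regardless of the earlier guards.
def should_take_branch(is_compressed: bool, scale_mismatch: bool, is_w4afp8: bool) -> bool:
    return is_compressed and scale_mismatch or is_w4afp8

print(should_take_branch(False, False, True))  # True: the `or` clause alone decides
```

Parenthesizing the condition explicitly is the usual way to make the intended grouping unambiguous.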

@Bruce-x-1997
Contributor

Hello, thanks for the PR. When I use W4AFP8 with DeepSeek-R1 (https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8) and test it on AIME24, it only gets a 10% score. What's your test model that reaches 80% on AIME24? @chenxijun1029

@chenxijun1029
Contributor Author

hello, thanks for the pr, but when I use w4afp8 in deepseek-r1(https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8) and test it in aime24, it could only get 10% score so what's your test model, which reach 80% in aime24 @chenxijun1029

I tested it with evalscope, and you should set "max_num_tokens" larger, e.g. 20000, for better results.

@Bruce-x-1997
Contributor

Bruce-x-1997 commented Sep 17, 2025

hello, thanks for the pr, but when I use w4afp8 in deepseek-r1(https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8) and test it in aime24, it could only get 10% score so what's your test model, which reach 80% in aime24 @chenxijun1029

I tested it with evalscope, and you should set "max_num_tokens" larger, i.e. 20000, for better performance.

Thanks, I have reached 80% on R1 & R1-0528 (quantized using ModelOpt).
But when I apply the same method to V3.1, the accuracy is unacceptable (10% on AIME24); even for a single AIME24 question tested on the quantized V3.1 at batch size 1, the answer is wrong.
Did you find similar problems? Is there anything we can try? @chenxijun1029

@chenxijun1029
Contributor Author

hello, thanks for the pr, but when I use w4afp8 in deepseek-r1(https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8) and test it in aime24, it could only get 10% score so what's your test model, which reach 80% in aime24 @chenxijun1029

I tested it with evalscope, and you should set "max_num_tokens" larger, i.e. 20000, for better performance.

thanks, I have reached 80% in r1 & r1-0528(quant use modelopt) but I found if I use the same method to v3.1 , v3.1 accuracy could not be accepted, in v3.1(10% in aime24), and if we use a question in aime24 to test v3.1(after quant, and batch is 1), its answer is wrong. do you find similar problems?is there anything we can try? @chenxijun1029

Can we share more information via WeChat?

@llc-kc
Contributor

llc-kc commented Nov 13, 2025

@chenxijun1029 Hi, are the TP and EP results evaluated with W4AFP8 and FP8 respectively, or do you compare W4AFP8 against the original FP8 with both in TP mode? Why does W4AFP8 significantly reduce the weight size but not seem to improve ITL? Thanks.

@llc-kc
Contributor

llc-kc commented Nov 13, 2025

@chenxijun1029 do we need an extra launch parameter --quantization w4afp8?

@junliu-mde
Contributor

@chenxijun1029 Hi, are the TP and EP results evaluated by W4AFP8 and FP8 respectively? Do you compare W4AFP8 with the original FP8 both in TP mode? Why does W4AFP8 significantly reduce the weight size but seem not to improve ITL? Thanks.

I guess the dequant process has non-negligible overhead. There should be room for improvement in the kernels.
