[feat] Support tp mode for DeepSeek-R1-W4AFP8#8118
Conversation
Summary of Changes
Hello @chenxijun1029, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly enhances the DeepSeek-R1-W4AFP8 model's performance by introducing support for Tensor Parallelism (TP) mode. It involves a comprehensive update to the quantization and MoE layer implementations, including new weight handling, dynamic method selection, and optimized low-level kernel configurations, ultimately aiming for more efficient model serving.
Highlights
- **Tensor Parallelism for DeepSeek-R1-W4AFP8**: Introduced a new `W4AFp8TPMoEMethod` to enable Tensor Parallelism (TP) for DeepSeek-R1-W4AFP8 models, demonstrating improved performance compared to Expert Parallelism (EP) mode.
- **Dynamic MoE Quantization Method Selection**: Implemented a routing mechanism that dynamically selects between Expert Parallelism (EP) and Tensor Parallelism (TP) quantization methods for MoE layers based on the `enable_ep_moe` global server argument.
- **Optimized CUTLASS Kernel Configurations**: Added new tile and cluster shape configurations within the `cutlass_w4a8_moe` kernel and extended its dispatch logic to provide optimized performance for the specific matrix dimensions encountered in Tensor Parallelism mode.
- **Enhanced Weight Processing and Loading**: Updated the weight creation and processing logic for TP MoE, including refactoring scale interleaving into a shared utility function and adding support for special naming rules for input scales in mixed-precision models.
- **Shared Experts Fusion Disablement**: Explicitly disabled the shared experts fusion optimization for W4A8 TP MoE models, as their quantization methods differ between routed and shared experts.
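The dynamic method selection described above can be sketched as follows. Only `W4AFp8TPMoEMethod` and the `enable_ep_moe` argument come from this PR; the surrounding names (`W4AFp8MoEMethod`, `get_moe_quant_method`) are simplified placeholders for illustration, not the actual SGLang API:

```python
# Hypothetical sketch of EP/TP quantization-method routing for W4AFP8
# MoE layers. The class bodies are placeholders; in the real code these
# would implement weight creation, post-load processing, and apply.

class W4AFp8MoEMethod:
    """Placeholder for the existing EP-mode quantization method."""
    mode = "ep"

class W4AFp8TPMoEMethod:
    """Placeholder for the TP-mode method introduced by the PR."""
    mode = "tp"

def get_moe_quant_method(enable_ep_moe: bool):
    # Route to the EP or TP MoE quantization method based on the
    # `enable_ep_moe` global server argument.
    if enable_ep_moe:
        return W4AFp8MoEMethod()
    return W4AFp8TPMoEMethod()

print(get_moe_quant_method(False).mode)  # tp
```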
Code Review
This pull request adds support for tensor parallelism (TP) mode for the DeepSeek-R1-W4AFP8 model, which shows performance improvements over the existing expert parallelism (EP) mode. The changes include adding a new W4AFp8TPMoEMethod for quantization, updating the MoE kernel configurations, and adding logic to switch between TP and EP modes. The implementation looks solid and the benchmark results are promising. I've added a couple of comments to improve code maintainability by reducing duplication in both the Python and CUDA C++ code. These changes should not affect performance but will make the code easier to read and maintain.
```cpp
} else if (n == 512 && k == 7168) {
  // group gemm 1 for tp
  if (m <= 4) {
    using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME(64, 32, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
    cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
        d_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, d_strides, s_strides, chunk_size);
  } else if (m <= 16) {
    using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 16, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
    cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
        d_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, d_strides, s_strides, chunk_size);
  } else if (m <= 256) {
    using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 16, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
    cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
        d_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, d_strides, s_strides, chunk_size);
  } else if (m <= 1024) {
    using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
    cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
        d_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, d_strides, s_strides, chunk_size);
  } else {
    using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 64, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
    cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
        d_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, d_strides, s_strides, chunk_size);
  }
} else if (n == 7168 && k == 256) {
  // group gemm 2 for tp
  if (m <= 8) {
    using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME(64, 16, 256, 1, 1, 1)::Cutlass3xW4A8Gemm;
    cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
        d_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, d_strides, s_strides, chunk_size);
  } else if (m <= 512) {
    using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME(128, 32, 256, 2, 1, 1)::Cutlass3xW4A8Gemm;
    cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
        d_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, d_strides, s_strides, chunk_size);
  } else {
    using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME(128, 64, 256, 1, 1, 1)::Cutlass3xW4A8Gemm;
    cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
        d_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, d_strides, s_strides, chunk_size);
  }
}
```
There's significant code duplication in the new `else if` blocks for `n == 512 && k == 7168` and `n == 7168 && k == 256`. The only difference between the `if (m <= ...)` blocks is the `Cutlass3xW4A8GemmSelected` type.
To improve maintainability, you could refactor this using a helper template function or a macro to reduce the repeated calls to cutlass_w4a8_group_gemm_caller.
For example:

```cpp
template <typename Gemm>
void invoke_cutlass_caller(...) {
  cutlass_w4a8_group_gemm_caller<Gemm>(...);
}

// In dispatch_w4a8_moe_mm_sm90:
if (m <= 4) {
  using Gemm = typename JOIN_STRUCT_NAME(64, 32, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
  invoke_cutlass_caller<Gemm>(...);
} else if (m <= 16) {
  using Gemm = typename JOIN_STRUCT_NAME_CO(128, 16, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
  invoke_cutlass_caller<Gemm>(...);
}
// ...
```

While I understand this pattern is common for performance-critical CUDA code to help the compiler generate specialized code, this refactoring would make the code much cleaner without a performance penalty.
Thanks for your great work on this! To help us evaluate the impact of this PR, could you please provide the accuracy results (like GSM8K, MMLU, and Hellaswag)?
```python
layer.register_parameter("w2_weight_scale_inv", w2_scales)
set_weight_attrs(w2_scales, extra_weight_attrs)

# The input scale for w1 and w3 should be the same
```
Just want to confirm: have you checked the contents of the `act_scales.safetensors` file? Are the input scales for w1 and w3 all consistent?
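One way to verify this is to compare the stored w1 and w3 input scales per expert. A minimal sketch — the key-naming scheme below is an assumption for illustration, and the dict stands in for tensors loaded from `act_scales.safetensors` (e.g. via `safetensors.torch.load_file`):

```python
# Hypothetical consistency check between w1 and w3 input scales.
# Key names are illustrative; real checkpoints may use a different scheme.

def check_w1_w3_scales(scales: dict, atol: float = 1e-6) -> list:
    """Return the keys of w1 input scales whose matching w3 scale disagrees."""
    mismatched = []
    for key, w1_scale in scales.items():
        if ".w1.input_scale" not in key:
            continue
        w3_key = key.replace(".w1.", ".w3.")
        if abs(w1_scale - scales[w3_key]) > atol:
            mismatched.append(key)
    return mismatched

# Synthetic example: expert 0 agrees, expert 1 does not.
scales = {
    "experts.0.w1.input_scale": 0.02, "experts.0.w3.input_scale": 0.02,
    "experts.1.w1.input_scale": 0.05, "experts.1.w3.input_scale": 0.07,
}
print(check_w1_w3_scales(scales))  # ['experts.1.w1.input_scale']
```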
@chenxijun1029 Nice work! Just wondering if there will be any further updates?

Nice work!
Co-authored-by: yuhyao <827623970@qq.com>
Fix bugs caused by rebasing.
Hi, I have updated the accuracy results in the PR, please check them.
LGTM, do you have any more suggestions? @yangsijia-serena
LGTM!
```python
"compressed" in self.quant_method.__class__.__name__.lower()
and param.data[expert_id] != 1
and (param.data[expert_id] - loaded_weight).abs() > 1e-5
or "w4afp8" in self.quant_config.get_name()
```
This broke compressed-tensors format MoE models like neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8; fixed in #10299.
Hello, thanks for the PR, but when I use w4afp8 with DeepSeek-R1 (https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8) and test it on AIME24, it only gets a 10% score.
I tested it with evalscope, and you should set `max_num_tokens` larger, e.g. 20000, for better performance.
Thanks, I have reached 80% on R1 & R1-0528 (quantized using ModelOpt).
Can we share more information via WeChat?
@chenxijun1029 Hi, are the TP and EP results evaluated with W4AFP8 and FP8 respectively? Did you compare W4AFP8 against the original FP8, both in TP mode? Why does W4AFP8 significantly reduce the weight size but seem not to improve ITL? Thanks.
@chenxijun1029 Do we need an extra launch parameter `--quantization w4afp8`?
I guess the dequant process has non-negligible overhead. There should be room for improvement in the kernels.
Motivation
Support TP mode for the DeepSeek W4A8 model, which has better performance than EP mode.
Modifications
- Add a new `W4AFp8TPMoEMethod` quantization method, implementing the `create_weights`, `process_weights_after_loading`, and `apply` functions. In the apply function, we use the same `cutlass_w4a8_moe` kernel as EP MoE uses.
- Add new configurations to the `cutlass_w4a8_moe` kernel.

Co-author: @yuhyao 827623970@qq.com
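The shape of the new method can be sketched as below. This is an illustrative skeleton under stated assumptions, not the PR's implementation: only the class name and the three hook names come from the PR; the parameter names, shapes, and `SimpleNamespace` layer stand-in are hypothetical.

```python
# Illustrative skeleton of a TP MoE quantization method mirroring the
# three hooks named above. Everything besides W4AFp8TPMoEMethod,
# create_weights, process_weights_after_loading, and apply is assumed.
from types import SimpleNamespace

class W4AFp8TPMoEMethod:
    def create_weights(self, layer, num_experts, hidden_size, intermediate_size):
        # Record int4-packed weight shapes (two int4 values per byte along
        # the last dim); in the real method these would be allocated tensors
        # sharded over TP ranks.
        layer.weights = {
            "w13_weight": (num_experts, 2 * intermediate_size, hidden_size // 2),
            "w2_weight": (num_experts, hidden_size, intermediate_size // 2),
        }
        layer.processed = False

    def process_weights_after_loading(self, layer):
        # Placeholder for scale interleaving / reordering weights into the
        # layout the kernel expects.
        layer.processed = True

    def apply(self, layer, hidden_states):
        # Would dispatch to the same cutlass_w4a8_moe kernel as EP MoE.
        assert layer.processed
        return hidden_states

# Demo with per-rank sizes matching the TP group-gemm dims in this PR.
layer = SimpleNamespace()
method = W4AFp8TPMoEMethod()
method.create_weights(layer, num_experts=8, hidden_size=7168, intermediate_size=256)
method.process_weights_after_loading(layer)
print(layer.weights["w13_weight"])  # (8, 512, 3584)
```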
Benchmark
We run DeepSeek-R1-W4AFP8 on 8×H20 with tp8, compared to running DeepSeek-R1 on 8×H20 with ep8.
Test configuration:

ISL 1000, OSL 1000: input/output len = 1000/1000, qps=128, max_concurrency=128, num_prompt=256.
The results are shown below, TP first, then EP.
*(benchmark result screenshots)*

ISL 6000, OSL 1000: input/output len = 6000/1000, qps=128, max_concurrency=128, num_prompt=256.
The results are shown below, TP first, then EP.
*(benchmark result screenshots)*
The results are summarized in the sheet below:
*(summary sheet screenshot)*
Accuracy
mmlu: 86.9 (by @yuhyao)
aime24: 80.0
gpqa: 71.2
math500: 95.6
Checklist
Tests for the `cutlass_w4a8_moe` kernel already existed (https://docs.sglang.ai/references/contribution_guide.html#running-unit-tests-adding-to-ci).