
Implement per-token w4afp8 MoE GEMM, improve performance of w4afp8 MoE GEMM #21101

Open
Wangzheee wants to merge 1 commit into sgl-project:main from Wangzheee:w4a8_per-token_kernel

Conversation


@Wangzheee Wangzheee commented Mar 21, 2026

Motivation

[1/2] Enhance w4afp8 performance: this PR contains the kernel part of the complete functionality.

This PR refactors the functionality of PR #7762 and significantly improves performance:

  1. Resolved the limitation that w4afp8 with DeepEP could only dispatch bf16 tokens
  • Reimplemented the GEMM pipeline with per-token quantization granularity, enabling dispatch of fp8 tokens
  2. Improved the performance of the w4afp8 MoE GEMM
  • Added a pipeline stage for the A scale (per-token)

  • Optimized the pipeline

  • Replaced LDS with LDSM for the W4 weight

  • Estimated the shape of the actual activations in MoE to optimize the block tile shape
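To illustrate the per-token quantization granularity that enables fp8 dispatch, here is a minimal numpy sketch (not the actual CUDA kernel): each token row gets its own scale derived from its absolute maximum, so the activations can be represented in the fp8 value range. The constant and helper name are assumptions for illustration.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3


def per_token_quantize(x: np.ndarray):
    """Quantize activations with one scale per token (row).

    x: [num_tokens, hidden] activations.
    Returns (x_q, scales) where x ~ x_q * scales[:, None].
    """
    amax = np.abs(x).max(axis=-1, keepdims=True)     # one amax per token
    scales = np.maximum(amax, 1e-12) / FP8_E4M3_MAX  # one scale per token
    x_q = np.clip(x / scales, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return x_q, scales.squeeze(-1)
```

Dequantizing with the per-token scales recovers the original activations up to fp8 rounding, which the real kernel performs inside the GEMM epilogue.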

How to use w4afp8

  1. Merge this PR into sglang develop
  2. Use a w4afp8 model

Modifications

  • w4afp8 MoE GEMM kernel
  • MoE layer
  • tests

Accuracy Tests

Model: https://huggingface.co/deepseek-ai/DeepSeek-R1 (base), https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8 (W4AFP8 checkpoint)

| Dataset | fp8 | Wint4Afp8 (previous) | Wint4Afp8 (per-token optimized) |
|---|---|---|---|
| gsm8k (02f245, accuracy, gen) | 96.21 | 96.13 | 96.13 |
| ceval (naive_average, gen) | 92.50 | 92.15 | 92.31 |
| mmlu (naive_average, gen) | 91.49 | 91.61 | 91.70 |

Benchmarking and Profiling

Summary

w4afp8 (per-token optimized) vs. fp8; the results are about the same as w4afp8 (per-token optimized) vs. w4afp8 (previous).

  1. EP (ep-size=8)

     | batch-size | MoE GEMM performance improvement | TPOT performance improvement |
     |---|---|---|
     | 32 | 90% | 50% ~ 60% |
     | 64 | 80% | 50% |

  2. TP (tp-size=8)

     | batch-size | MoE GEMM performance improvement | TPOT performance improvement |
     |---|---|---|
     | 32 | 60% | 5% |
     | 64 | 50% | 5% |

Detailed data

  • Max request concurrency: 32
  • Successful requests: 300
  • Total input tokens: 1,228,500
  • Total input text tokens: 1,228,500
  • Total generated tokens: 307,200
  • ep-size: 8
Columns: fp8; Wint4Afp8 (sglang previous); Wint4Afp8 (per-token, pipeline optimized, tile shape optimized).

| Metric | fp8 (dp=1) | fp8 (dp=2) | Wint4Afp8 previous (dp=1) | Wint4Afp8 previous (dp=2) | Wint4Afp8 per-token (dp=1) | Wint4Afp8 per-token (dp=2) |
|---|---|---|---|---|---|---|
| Request throughput (req/s) | 0.42 | 0.47 | 0.43 | 0.47 | 0.60 | 0.69 |
| Input token throughput (tok/s) | 1737 | 1906 | 1747 | 1924 | 2453 | 2821 |
| Output token throughput (tok/s) | 434 | 476 | 436 | 481 | 613 | 706 |
| Total token throughput (tok/s) | 2172 | 2382 | 2184 | 2405 | 3066 | 3527 |
| Mean E2E Latency (ms) | 72335 | 65756 | 71608 | 64868 | 51170 | 44294 |
| Mean ITL (ms) | 64.94 | 58.82 | 63.68 | 57.45 | 43.02 | 36.44 |
| Mean TTFT (ms) | 5899 | 5588 | 6460 | 6096 | 7161 | 7020 |
| Mean TPOT (ms) | 64.94 | 58.81 | 63.68 | 57.45 | 43.02 | 36.44 |
  • Max request concurrency: 64
  • Successful requests: 300
  • Total input tokens: 1,228,500
  • Total input text tokens: 1,228,500
  • Total generated tokens: 307,200
  • ep-size: 8
| Metric | fp8 (dp=1) | fp8 (dp=2) | Wint4Afp8 previous (dp=1) | Wint4Afp8 previous (dp=2) | Wint4Afp8 per-token (dp=1) | Wint4Afp8 per-token (dp=2) |
|---|---|---|---|---|---|---|
| Request throughput (req/s) | 0.43 | 0.68 | 0.43 | 0.71 | 0.60 | 0.93 |
| Input token throughput (tok/s) | 1769 | 2785 | 1771 | 2894 | 2460 | 3825 |
| Output token throughput (tok/s) | 442 | 696 | 443 | 723 | 615 | 956 |
| Total token throughput (tok/s) | 2212 | 3481 | 2214 | 3618 | 3075 | 4781 |
| Mean E2E Latency (ms) | 136359 | 87896 | 134694 | 84463 | 97293 | 63847 |
| Mean ITL (ms) | 68.70 | 76.53 | 67.26 | 72.44 | 46.26 | 51.15 |
| Mean TTFT (ms) | 66077 | 9603 | 65886 | 10357 | 49970 | 11524 |
| Mean TPOT (ms) | 68.70 | 76.53 | 67.26 | 72.44 | 46.26 | 51.15 |
  • Max request concurrency: 32
  • Successful requests: 300
  • Total input tokens: 1,228,500
  • Total input text tokens: 1,228,500
  • Total generated tokens: 307,200
  • tp-size: 8
| Metric | fp8 (dp=1) | fp8 (dp=2) | Wint4Afp8 previous (dp=1) | Wint4Afp8 previous (dp=2) | Wint4Afp8 per-token (dp=1) | Wint4Afp8 per-token (dp=2) |
|---|---|---|---|---|---|---|
| Request throughput (req/s) | 0.64 | 0.68 | 0.52 | 0.54 | 0.67 | 0.71 |
| Input token throughput (tok/s) | 2600 | 2769 | 2118 | 2230 | 2745 | 2906 |
| Output token throughput (tok/s) | 650 | 692 | 530 | 557 | 686 | 727 |
| Total token throughput (tok/s) | 3251 | 3461 | 2648 | 2787 | 3432 | 3633 |
| Mean E2E Latency (ms) | 48530 | 45207 | 59607 | 56296 | 45851 | 42967 |
| Mean ITL (ms) | 42.42 | 39.00 | 52.59 | 49.35 | 39.03 | 36.14 |
| Mean TTFT (ms) | 5131 | 5312 | 5809 | 5809 | 5927 | 5993 |
| Mean TPOT (ms) | 42.42 | 39.00 | 52.59 | 49.35 | 39.03 | 36.14 |
  • Max request concurrency: 64
  • Successful requests: 300
  • Total input tokens: 1,228,500
  • Total input text tokens: 1,228,500
  • Total generated tokens: 307,200
  • tp-size: 8
| Metric | fp8 (dp=1) | fp8 (dp=2) | Wint4Afp8 previous (dp=1) | Wint4Afp8 previous (dp=2) | Wint4Afp8 per-token (dp=1) | Wint4Afp8 per-token (dp=2) |
|---|---|---|---|---|---|---|
| Request throughput (req/s) | 0.66 | 0.90 | 0.51 | 0.74 | 0.67 | 0.96 |
| Input token throughput (tok/s) | 2716 | 3673 | 2097 | 3038 | 2722 | 3921 |
| Output token throughput (tok/s) | 679 | 918 | 524 | 759 | 681 | 981 |
| Total token throughput (tok/s) | 3395 | 4591 | 2627 | 3797 | 3403 | 4902 |
| Mean E2E Latency (ms) | 90289 | 66558 | 115518 | 80557 | 88295 | 62168 |
| Mean ITL (ms) | 44.63 | 56.03 | 57.48 | 68.73 | 42.61 | 50.63 |
| Mean TTFT (ms) | 44628 | 9238 | 56717 | 10250 | 44707 | 10375 |
| Mean TPOT (ms) | 44.63 | 56.03 | 57.48 | 68.73 | 42.61 | 50.63 |

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on a substantial performance uplift for w4afp8 Mixture-of-Experts (MoE) GEMM operations. It refactors the quantization pipeline to support per-token granularity, introduces distinct scaling mechanisms for both input and weight matrices, and incorporates low-level optimizations to leverage GPU architecture more effectively. The changes aim to improve throughput and reduce latency for quantized models, particularly in scenarios involving MoE layers.

Highlights

  • Performance Improvement: Significantly enhanced w4afp8 performance in MoE GEMM operations, achieving up to 90% improvement for batch size 32 with EP (ep-size=8) and 60% for TP (tp-size=8).
  • Per-Token Quantization: Reimplemented the GEMM pipeline with per-token quantization granularity, enabling dispatch for FP8 tokens and resolving previous limitations with BF16.
  • Optimized Pipeline and Memory Usage: Added a pipeline for A scale (per-token), optimized the overall pipeline, replaced LDS with LDSM for W4 weight, and estimated actual activation shapes in MoE to optimize block tile shapes.
  • Expanded Scale Handling: Introduced separate scale tensors for both A and B operands (activations and weights), allowing for more granular control and improved accuracy in mixed-precision GEMM operations.
  • Int4 to FP8 Lookup Table: Implemented a lookup table for converting int4 to FP8, specifically for int4 * fp8 operations, further optimizing the conversion process.
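The lookup-table idea in the last highlight can be sketched as follows (a numpy illustration, not the kernel's actual table or layout, which are not shown in this PR summary): a 16-entry table indexed by the raw 4-bit code replaces per-element sign extension and conversion when unpacking int4 weights.

```python
import numpy as np

# 16-entry LUT indexed by the raw 4-bit code (two's complement: codes 8..15
# map to -8..-1). The packing order (low nibble first) is an assumption.
INT4_LUT = np.array(
    [0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1],
    dtype=np.float32,
)


def unpack_int4(packed: np.ndarray) -> np.ndarray:
    """Unpack uint8 bytes holding two int4 codes each via table lookup."""
    lo = packed & 0x0F          # low nibble -> even output positions
    hi = (packed >> 4) & 0x0F   # high nibble -> odd output positions
    out = np.empty(packed.size * 2, dtype=np.float32)
    out[0::2] = INT4_LUT[lo]
    out[1::2] = INT4_LUT[hi]
    return out
```

In the real kernel the table would hold fp8 bit patterns so the int4 * fp8 product needs no separate convert step; here float32 stands in for readability.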



@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new Triton kernel for interleaving int4 data and integrates it into the sglang framework. It also modifies the cutlass kernel to support scale B and adds a lookup table for int4 to fp8 conversion. The review comments suggest removing a redundant condition in the Triton kernel and adding a comment to explain the padding logic in the cutlass kernel.

```python
dst_id = dst_row_id * cols_div4 + dst_col_id

valid = (
    (col_id < cols_div4)
```

Severity: medium

The condition (col_id < cols_div4) appears to be redundant. Based on the loop condition mask_partition (which is partition_id < (cols // 64)), the maximum value of col_id is (cols // 64 - 1) * 16 + 15, which simplifies to cols / 4 - 1. Since cols_div4 is cols // 4, col_id will always be less than cols_div4 when mask_partition is true. The valid mask already includes mask_partition[:, None], so this check is unnecessary. Removing it could offer a minor performance improvement.
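The reviewer's bound can be checked numerically. The sketch below restates the arithmetic under the stated assumptions (partition_id < cols // 64, a 16-lane offset within each partition, and cols a multiple of 64); the helper name is hypothetical.

```python
def max_col_id(cols: int) -> int:
    """Largest col_id reachable when partition_id < cols // 64 and the
    in-partition lane offset is at most 15 (as described in the review)."""
    return (cols // 64 - 1) * 16 + 15


# For every cols that is a multiple of 64, the maximum col_id equals
# cols // 4 - 1, which is strictly less than cols_div4 = cols // 4.
for cols in (64, 128, 4096, 7168):
    assert max_col_id(cols) == cols // 4 - 1
```

This supports the claim that the `(col_id < cols_div4)` term is implied by the partition mask, at least when cols is a multiple of 64.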

```cpp
int64_t scale_k = k / 128;
b_scales_offsets[expert_id] = b_scales_base_as_int + (per_out_ch ? expert_id * n * scale_k : expert_id);
a_scales_offsets[expert_id] =
    a_scales_base_as_int + (per_act_token ? expert_offset * (scale_k % 4 == 0 ? scale_k : scale_k * 4) : 0);
```

Severity: medium

The padding logic (scale_k % 4 == 0 ? scale_k : scale_k * 4) is a bit unclear without context. Adding a comment to explain why this padding is necessary would improve code clarity. For example, explaining that it's for TMA alignment requirements for scale B.
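For reference, the offset computation above can be restated in Python with the kind of comment the reviewer asks for. The alignment rationale in the comment is an assumption (the PR does not explain it), and the function name is hypothetical.

```python
def a_scale_offset(expert_offset: int, scale_k: int, per_act_token: bool) -> int:
    """Mirror of the C++ a_scales offset expression, for illustration only."""
    if not per_act_token:
        return 0
    # When scale_k is not a multiple of 4, the stride is padded (x4 here),
    # presumably so each expert's scale block satisfies an alignment
    # requirement such as TMA's 16-byte rule -- this rationale is a guess
    # and should be confirmed by the kernel authors.
    stride = scale_k if scale_k % 4 == 0 else scale_k * 4
    return expert_offset * stride
```

Whatever the real reason, encoding it as a comment next to the ternary would make the C++ much easier to audit.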

@Wangzheee Wangzheee force-pushed the w4a8_per-token_kernel branch from 0b97239 to dd25c4d Compare March 22, 2026 03:14
@Wangzheee Wangzheee requested a review from HaiShaw as a code owner March 24, 2026 08:04
@Wangzheee Wangzheee force-pushed the w4a8_per-token_kernel branch from b84e186 to dd25c4d Compare March 24, 2026 08:05