
Support flashinfer_cutlass Fused MoE on SM90 for FP8 (#18388)

Open
b8zhong wants to merge 5 commits into sgl-project:main from bzhng-development:brayden/fix-sm90-flashinfer-cutlass

Conversation


b8zhong (Collaborator) commented Feb 7, 2026

Motivation

Currently we only use cutlass_fused_moe for NVFP4, but the kernel also supports FP8 on SM90 and SM100. On SM100 the FP8 path is only useful for large-scale prefill (for decode, flashinfer_trtllm is better), so that is not really the target use case; the goal here is to support SM90 and run some comparisons against the cutlass backend and the original triton backend created in #7278.

For TP MoE, triton appears to be quite efficient. I didn't try the EP scenario; that can be left for future work.

Modifications

Add a flashinfer_cutlass execution path for FP8 MoE:

  • Wire flashinfer.fused_moe.cutlass_fused_moe into the FP8 MoE runner, handling block-wise and per-tensor quantization scales and mapping activation types (silu → Swiglu, relu2 → Relu2).
  • Add a load_up_proj_weight_first property so weights are loaded in the [Up, Gate] order the kernel expects.
  • Refactor the existing cutlass FP8 logic into a private _apply_cutlass_fp8 method.
  • Update docs/advanced_features/expert_parallelism.md accordingly.
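The architecture gating for this path can be sketched as follows. `supports_flashinfer_cutlass_fp8` is a hypothetical helper name for illustration, not sglang's actual API; in practice the compute capability would come from torch.cuda.get_device_capability():

```python
def supports_flashinfer_cutlass_fp8(major: int, minor: int) -> bool:
    """Whether FlashInfer's cutlass_fused_moe FP8 path applies on this GPU.

    Per the PR description, the kernel supports FP8 on SM90 (Hopper)
    and SM100 (Blackwell-class); this PR wires up the SM90 case.
    """
    return (major, minor) in {(9, 0), (10, 0)}
```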

Benchmarking and Profiling

python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3.1 --trust-remote-code --tp 8 --moe-runner-backend=flashinfer_cutlass # and triton, cutlass
SGLANG_TORCH_PROFILER_DIR="./" \
python -m sglang.bench_one_batch_server \
  --model baseten-admin/glm-4.7-fp8-attn-fp4-mlp \
  --base-url http://localhost:30000 \
  --batch-size 16 \
  --input-len 1024 \
  --output-len 2048 \
  --profile \
  --profile-steps 10 \
  --show-report \
  --skip-warmup

For flashinfer_cutlass
[screenshot: profiling report for flashinfer_cutlass]

Associated GSM8K results:

python3 benchmark/gsm8k/bench_sglang.py --num-questions 1319 --parallel 400
100% | 1319/1319 [00:53<00:00, 24.58it/s]
Accuracy: 0.950
Invalid: 0.001
Latency: 53.804 s
Output throughput: 2398.086 token/s

cutlass
[screenshot: profiling report for cutlass]

triton
[screenshot: profiling report for triton]

Associated GSM8K results (baseline):

python3 benchmark/gsm8k/bench_sglang.py --num-questions 1319 --parallel 400
100% | 1319/1319 [00:44<00:00, 29.42it/s]
Accuracy: 0.952
Invalid: 0.000
Latency: 44.986 s
Output throughput: 2789.070 token/s

@github-actions github-actions bot added the documentation Improvements or additions to documentation label Feb 7, 2026
@gemini-code-assist
Contributor

Summary of Changes

Hello @b8zhong, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the Mixture-of-Experts (MoE) capabilities by integrating a new flashinfer_cutlass backend, specifically tailored for FP8 quantization on NVIDIA Hopper (SM90) GPUs. The primary goal is to leverage FlashInfer's optimized fused MoE kernels to achieve superior performance for FP8 models. This involves not only adding the new backend but also refining the existing MoE runner logic to support diverse quantization schemes and activation functions, while ensuring proper weight handling. The changes aim to provide a more performant and robust MoE execution environment, particularly for modern GPU architectures.

Highlights

  • New MoE Backend Integration: Introduced support for the flashinfer_cutlass backend for Fused Mixture-of-Experts (MoE) layers, specifically optimized for FP8 quantization on NVIDIA SM90 (Hopper) architectures.
  • Code Refactoring for Modularity: Refactored the existing cutlass FP8 MoE implementation into a dedicated private method (_apply_cutlass_fp8) within the FP8MoERunner class, improving code organization and separation of concerns.
  • Weight Projection Order Handling: Added a new property load_up_proj_weight_first to ensure correct weight loading order ([Up, Gate]) for the flashinfer_cutlass kernel when handling gated MoE configurations.
  • Dynamic Scaling and Activation Support: Implemented dynamic handling of quantization scales (block-wise vs. per-tensor) and activation types (e.g., Swiglu, Relu2) for the flashinfer_cutlass integration, enhancing flexibility and performance.
  • Documentation Updates: Updated the expert_parallelism.md documentation to reflect the enhanced performance characteristics of the triton backend on SM90 and to clarify the specific use cases for the flashinfer_cutlass backend (Block-wise FP8 on Hopper and NVFP4 on Blackwell).
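The [Up, Gate] weight-order handling noted above can be illustrated with a toy sketch. This is a plain-Python stand-in for the actual tensor chunk/concat; every name except load_up_proj_weight_first is hypothetical:

```python
def reorder_gate_up(w13_rows, up_first=True):
    """Toy stand-in for the weight-ordering fix.

    w13_rows: the fused gate/up projection rows for one expert, laid out
    as [gate..., up...] (each half of equal length). When the kernel
    expects the up projection first (load_up_proj_weight_first=True),
    return [up..., gate...]; otherwise keep the original [Gate, Up] order.
    """
    half = len(w13_rows) // 2
    gate, up = w13_rows[:half], w13_rows[half:]
    return up + gate if up_first else list(w13_rows)
```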


Changelog
  • docs/advanced_features/expert_parallelism.md
    • Updated the description for the triton backend to highlight its high performance in most SM90 (Hopper) scenarios.
    • Revised the use cases for the flashinfer_cutlass backend to specify 'Block-wise FP8 on Hopper and NVFP4 on Blackwell'.
  • python/sglang/srt/layers/quantization/fp8.py
    • Added import for next_power_of_2 utility function.
    • Introduced load_up_proj_weight_first property in FP8MoERunner to manage weight projection order for flashinfer_cutlass.
    • Implemented the flashinfer_cutlass execution path, integrating flashinfer.fused_moe.cutlass_fused_moe with support for block-wise and per-tensor quantization scales, and activation type mapping.
    • Refactored the original cutlass FP8 MoE logic into a new private method _apply_cutlass_fp8.
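The next_power_of_2 utility referenced in the fp8.py changelog entry is typically the standard bit-length trick. This is a generic sketch; sglang's actual implementation may differ in edge-case behavior:

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n (this sketch returns 1 for n <= 1)."""
    return 1 if n <= 1 else 1 << (n - 1).bit_length()
```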
Activity
  • The author, b8zhong, motivated this change by highlighting the existing cutlass_fused_moe's support for FP8 on SM90 and SM100, aiming to integrate this capability and compare it with other backends.
  • Detailed benchmarking and profiling results were provided, including commands for launching the server and running benchmarks, along with screenshots of profiling reports for flashinfer_cutlass, cutlass, and triton backends.
  • GSM8K accuracy and throughput metrics were included for flashinfer_cutlass and a baseline, demonstrating the performance impact of the new backend.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for the flashinfer_cutlass backend for Fused MoE on SM90 for FP8, which is a great performance enhancement. The implementation is well-structured, adding a new code path for this backend and refactoring the existing cutlass logic for better clarity. The documentation is also updated accordingly. My main feedback is to add more robust error handling for activation functions to prevent potential runtime errors.

Comment on lines +1427 to +1431
    activation = self.moe_runner_config.activation
    activation_type = {
        "silu": ActivationType.Swiglu,
        "relu2": ActivationType.Relu2,
    }[activation]

Severity: high

The direct dictionary access [activation] to determine activation_type is unsafe. If an unsupported activation function is provided (e.g., 'gelu'), this will raise a KeyError and crash the server. It's better to handle this case gracefully by checking if the activation is supported and raising a more informative error like NotImplementedError if it's not.

            activation = self.moe_runner_config.activation
            activation_map = {
                "silu": ActivationType.Swiglu,
                "relu2": ActivationType.Relu2,
            }
            activation_type = activation_map.get(activation)
            if activation_type is None:
                raise NotImplementedError(
                    f"Unsupported activation function for flashinfer_cutlass: {activation}"
                )
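The suggested guard behaves as follows when run standalone. ActivationType here is a local stub standing in for flashinfer's enum, and resolve_activation is an illustrative wrapper name:

```python
from enum import Enum


class ActivationType(Enum):
    # Stub standing in for flashinfer's ActivationType enum.
    Swiglu = 0
    Relu2 = 1


def resolve_activation(activation: str) -> ActivationType:
    """Map sglang's activation string to the kernel enum, failing loudly."""
    activation_map = {
        "silu": ActivationType.Swiglu,
        "relu2": ActivationType.Relu2,
    }
    activation_type = activation_map.get(activation)
    if activation_type is None:
        raise NotImplementedError(
            f"Unsupported activation function for flashinfer_cutlass: {activation}"
        )
    return activation_type
```

With this shape, an unsupported string such as "gelu" raises NotImplementedError with a clear message instead of an opaque KeyError.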

@b8zhong
Collaborator Author

b8zhong commented Feb 7, 2026

/tag-and-rerun-ci

@github-actions github-actions bot added the run-ci label Feb 7, 2026

Labels

documentation Improvements or additions to documentation run-ci
