
[Feature] Add tile-based method of supporting large VPT in moe_fused_gate kernel#9579

Open
ltaodream wants to merge 1 commit into sgl-project:main from ltaodream:ltaodream_dev

Conversation


@ltaodream ltaodream commented Aug 25, 2025

Motivation

The Kimi-VL model (supported in #5383) and the Kimi-K2 model cannot use the fused MoE gate operator. The operator currently does not support more than 32 experts per group: Kimi-VL has 64 experts per group and Kimi-K2 has 384, so neither model can use it at present.

Modifications

@ttaohe @Misaka9468 The three of us completed the PR "Add tile-based method of supporting large VPT in moe_fused_gate kernel" together.

Since rebasing rewrote the commit history, GitHub could not properly update the original PR, so #6946 was closed and this fresh PR opened with the cleaned-up single commit for easier review. Original PR for reference: #6946 ([Feature] Add tile-based method of supporting large VPT in moe_fused_gate kernel).

Checklist

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @ltaodream, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the moe_fused_gate kernel by introducing a tile-based method to support larger values-per-thread (VPT) configurations. This change is crucial for enabling the use of the fused MoE gate operator with models such as Kimi-VL and Kimi-K2, which feature a higher number of group experts (64 and 384 respectively) than previously supported. The update ensures broader compatibility and improved performance for large-scale Mixture of Experts models.

Highlights

  • Expanded MoE Support: The moe_fused_gate kernel now supports larger numbers of experts per group (VPT) beyond the previous limit of 32, up to 512, enabling compatibility with models like Kimi-VL (64 experts) and Kimi-K2 (384 experts).
  • New Tiled CUDA Kernel: A new CUDA kernel, moe_fused_gate_tile_more_experts.cu, has been implemented to handle these larger VPT configurations efficiently using a tile-based approach.
  • Dynamic Dispatch: The existing moe_fused_gate function now intelligently dispatches to the new tiled kernel when the computed VPT exceeds the original 32-expert limit.
  • Flexible Expert Count: The constraint requiring num_experts to be a power of 2 has been relaxed to specifically allow 384 experts, accommodating the Kimi-K2 model.
  • Improved Shared Expert Handling: The logic for assigning and weighting fused shared experts in topk.py has been refined for correctness.
  • Benchmarking and Testing: New benchmark and test cases have been added to validate the functionality and performance of the tiled kernel for larger expert counts.
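
The tile-based idea in the highlights can be sketched in NumPy (a simplified single-threaded reference, not the CUDA kernel; `TILE` and the function names here are illustrative): the expert axis is scanned in fixed-size tiles of 32, so the working set per pass stays bounded no matter how many experts a group has.

```python
import numpy as np

TILE = 32  # per-pass tile width, mirroring the kernel's original 32-expert limit

def tiled_argmax(scores):
    """Find the max score and its index by scanning the expert axis tile by tile.

    Each pass only ever looks at TILE values, so the working set stays
    bounded regardless of the total number of experts.
    """
    best_val, best_idx = -np.inf, -1
    for start in range(0, scores.shape[0], TILE):
        tile = scores[start:start + TILE]
        local = int(np.argmax(tile))
        if tile[local] > best_val:
            best_val = float(tile[local])
            best_idx = start + local
    return best_idx, best_val

def tiled_topk(scores, k):
    """Top-k via repeated tiled argmax (the straightforward O(k * n/TILE) form)."""
    scores = scores.astype(np.float64).copy()
    idxs, vals = [], []
    for _ in range(k):
        i, v = tiled_argmax(scores)
        idxs.append(i)
        vals.append(v)
        scores[i] = -np.inf  # exclude the chosen expert from later passes
    return idxs, vals
```

On a 384-expert score vector (the Kimi-K2 case) this produces the same selection as a direct sort-based top-k.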

@FlamingoPg FlamingoPg self-assigned this Aug 25, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a tile-based method to support a larger number of experts in the moe_fused_gate kernel, enabling support for models like Kimi-vl and Kimi-k2. The changes are comprehensive, including new CUDA kernels, updates to Python logic, and new tests and benchmarks. The implementation is solid, but I have identified a few areas for improvement. My feedback focuses on enhancing numerical stability by preventing potential division-by-zero errors, improving maintainability by reducing code duplication, and addressing a silent limitation in the new static kernel. Overall, this is a valuable feature addition.

-        else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
+        else topk_weights[:, :-num_fused_shared_experts].sum(dim=-1, keepdim=True)
     )
     topk_weights = topk_weights / topk_weights_sum
Contributor


high

The division topk_weights / topk_weights_sum could result in NaN if topk_weights_sum is zero. This can happen if all selected expert weights are zero. To prevent this, it's safer to add a small epsilon to the denominator for numerical stability, as is done in the new benchmark file (bench_moe_fused_gate_tiled.py).

Suggested change
-    topk_weights = topk_weights / topk_weights_sum
+    topk_weights = topk_weights / (topk_weights_sum + 1e-9)
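
A tiny NumPy demo of why the reviewer's epsilon matters (the `1e-9` constant is taken from the suggestion above; the array shapes are illustrative): a degenerate row whose selected weights all sum to zero turns the naive division into 0/0.

```python
import numpy as np

topk_weights = np.zeros((2, 8), dtype=np.float32)  # degenerate rows: all weights zero
topk_weights_sum = topk_weights.sum(axis=-1, keepdims=True)

with np.errstate(invalid="ignore"):
    naive = topk_weights / topk_weights_sum      # 0/0 -> NaN in every element
safe = topk_weights / (topk_weights_sum + 1e-9)  # stays finite (all zeros)

assert np.isnan(naive).all()
assert np.isfinite(safe).all()
```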

-        else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
+        else topk_weights[:, :-num_fused_shared_experts].sum(dim=-1, keepdim=True)
     )
     topk_weights = topk_weights / topk_weights_sum
Contributor


high

Similar to the grouped_topk_gpu function, the division topk_weights / topk_weights_sum here can lead to NaN values if topk_weights_sum is zero. Please add a small epsilon to the denominator for numerical stability.

Suggested change
-    topk_weights = topk_weights / topk_weights_sum
+    topk_weights = topk_weights / (topk_weights_sum + 1e-9)

Comment on lines +255 to +261
  if (thread_group_idx == 0) {
    float denom = output_sum;
    for (int i = 0; i < topk; ++i) {
      int64_t idx = topk * thread_row + i;
      output_ptr[idx] = output_ptr[idx] / denom;
    }
  }
Contributor


high

The renormalization loop can perform a division by zero if output_sum (aliased as denom) is zero, which would result in NaN or inf in the output weights. The static version of this kernel (moe_fused_gate_kernel_tiled_static) correctly handles this by checking if real_sum > 0.0f. A similar check should be added here for robustness.

  if (thread_group_idx == 0) {
    float denom = output_sum;
    if (denom > 0.0f) {
      for (int i = 0; i < topk; ++i) {
        int64_t idx = topk * thread_row + i;
        output_ptr[idx] = output_ptr[idx] / denom;
      }
    }
  }

Comment on lines 573 to +603
 if num_fused_shared_experts:
-    topk_ids[:, -1] = torch.randint(
-        low=num_experts,
-        high=num_experts + num_fused_shared_experts,
-        size=(topk_ids.size(0),),
-        dtype=topk_ids.dtype,
+    assert (
+        topk >= num_fused_shared_experts
+    ), "topk must be >= num_fused_shared_experts"
+    # Assign the last N ids to all shared expert ids [num_experts, num_experts+N)
+    shared_ids = torch.arange(
+        num_experts,
+        num_experts + num_fused_shared_experts,
+        device=topk_ids.device,
+        dtype=topk_ids.dtype,
     )
+    topk_ids[:, -num_fused_shared_experts:] = shared_ids.unsqueeze(0).expand(
+        topk_ids.size(0), -1
+    )
-    topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
+    # Set each shared expert's weight to sum(real_experts)/routed_scaling_factor
+    real_sum = topk_weights[:, :-num_fused_shared_experts].sum(dim=-1)
+    shared_weight = real_sum / routed_scaling_factor
+    topk_weights[:, -num_fused_shared_experts:] = shared_weight.unsqueeze(
+        -1
+    ).expand(-1, num_fused_shared_experts)
Contributor


medium

There is significant code duplication between this function (biased_grouped_topk_impl) and grouped_topk_gpu for handling num_fused_shared_experts. The logic from lines 584-603 is nearly identical to lines 471-490. Consider refactoring this shared logic into a helper function to improve maintainability and reduce redundancy.
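
A possible shape for that refactor, sketched in NumPy for illustration (the helper name `apply_fused_shared_experts` is hypothetical, not part of the PR): both call sites would delegate the shared-expert assignment and weighting to one function.

```python
import numpy as np

def apply_fused_shared_experts(topk_weights, topk_ids, num_experts,
                               num_fused_shared_experts, routed_scaling_factor):
    """Hypothetical helper consolidating the duplicated shared-expert logic.

    Overwrites the last N columns with the shared expert ids
    [num_experts, num_experts + N) and gives each shared expert the weight
    sum(real_experts) / routed_scaling_factor.
    """
    n = num_fused_shared_experts
    assert topk_ids.shape[1] >= n, "topk must be >= num_fused_shared_experts"
    topk_ids[:, -n:] = np.arange(num_experts, num_experts + n)
    real_sum = topk_weights[:, :-n].sum(axis=-1, keepdims=True)
    topk_weights[:, -n:] = real_sum / routed_scaling_factor
    return topk_weights, topk_ids
```

Both `grouped_topk_gpu` and `biased_grouped_topk_impl` could then call this single helper instead of repeating the logic inline.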

Comment on lines +302 to +303
const int MAX_TOPK = 32;
topk_excl_shared = min(topk_excl_shared, MAX_TOPK);
Contributor


medium

The static kernel uses a fixed-size array best_choice[MAX_TOPK] with MAX_TOPK = 32. The code then silently truncates topk_excl_shared to this limit. This could lead to incorrect behavior if a model requires topk > 32. It would be safer to add a static_assert or a TORCH_CHECK to ensure topk does not exceed MAX_TOPK.
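
The same guard could also live on the host side before dispatch; a minimal Python sketch of the idea (function name and message are illustrative, mirroring the suggested TORCH_CHECK rather than quoting the PR):

```python
MAX_TOPK = 32  # fixed size of the best_choice[] array in the static kernel

def check_static_kernel_limits(topk, num_fused_shared_experts):
    """Fail loudly instead of silently truncating topk_excl_shared."""
    topk_excl_shared = topk - num_fused_shared_experts
    if topk_excl_shared > MAX_TOPK:
        raise ValueError(
            f"topk - num_fused_shared_experts = {topk_excl_shared} exceeds "
            f"the static kernel's MAX_TOPK = {MAX_TOPK}"
        )
```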

Comment on lines +378 to +505
  // Currently: specialize THREADS_PER_ROW=1 for selected NUM_EXPERTS with TILE=32
  if (num_experts == 384 && num_expert_group == 1) {
    constexpr int THREADS_PER_ROW = 1;
    constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
    constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;
    int64_t rows_per_warp = ROWS_PER_WARP;
    int64_t num_warps = (num_rows + rows_per_warp - 1) / rows_per_warp;
    int64_t num_blocks = (num_warps + WARPS_PER_CTA - 1) / WARPS_PER_CTA;

    if (input.scalar_type() == at::kBFloat16) {
      moe_fused_gate_kernel_tiled_static<
          bfloat16_t,
          384,
          THREADS_PER_ROW,
          ROWS_PER_WARP,
          ROWS_PER_CTA,
          WARPS_PER_CTA,
          TILE_VPT><<<num_blocks, block_dim, 0, stream>>>(
          input.data_ptr(),
          bias.data_ptr(),
          output.data_ptr<float>(),
          indices.data_ptr<int32_t>(),
          num_rows,
          topk,
          num_fused_shared_experts,
          routed_scaling_factor);
    } else if (input.scalar_type() == at::kHalf) {
      moe_fused_gate_kernel_tiled_static<
          float16_t,
          384,
          THREADS_PER_ROW,
          ROWS_PER_WARP,
          ROWS_PER_CTA,
          WARPS_PER_CTA,
          TILE_VPT><<<num_blocks, block_dim, 0, stream>>>(
          input.data_ptr(),
          bias.data_ptr(),
          output.data_ptr<float>(),
          indices.data_ptr<int32_t>(),
          num_rows,
          topk,
          num_fused_shared_experts,
          routed_scaling_factor);
    } else if (input.scalar_type() == at::kFloat) {
      moe_fused_gate_kernel_tiled_static<
          float32_t,
          384,
          THREADS_PER_ROW,
          ROWS_PER_WARP,
          ROWS_PER_CTA,
          WARPS_PER_CTA,
          TILE_VPT><<<num_blocks, block_dim, 0, stream>>>(
          input.data_ptr(),
          bias.data_ptr(),
          output.data_ptr<float>(),
          indices.data_ptr<int32_t>(),
          num_rows,
          topk,
          num_fused_shared_experts,
          routed_scaling_factor);
    } else {
      TORCH_CHECK(false, "Unsupported dtype for moe_fused_gate_tiled_static");
    }
  } else if (num_experts == 64 && num_expert_group == 1) {
    constexpr int THREADS_PER_ROW = 1;
    constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
    constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;
    int64_t rows_per_warp = ROWS_PER_WARP;
    int64_t num_warps = (num_rows + rows_per_warp - 1) / rows_per_warp;
    int64_t num_blocks = (num_warps + WARPS_PER_CTA - 1) / WARPS_PER_CTA;

    if (input.scalar_type() == at::kBFloat16) {
      moe_fused_gate_kernel_tiled_static<
          bfloat16_t,
          64,
          THREADS_PER_ROW,
          ROWS_PER_WARP,
          ROWS_PER_CTA,
          WARPS_PER_CTA,
          TILE_VPT><<<num_blocks, block_dim, 0, stream>>>(
          input.data_ptr(),
          bias.data_ptr(),
          output.data_ptr<float>(),
          indices.data_ptr<int32_t>(),
          num_rows,
          topk,
          num_fused_shared_experts,
          routed_scaling_factor);
    } else if (input.scalar_type() == at::kHalf) {
      moe_fused_gate_kernel_tiled_static<
          float16_t,
          64,
          THREADS_PER_ROW,
          ROWS_PER_WARP,
          ROWS_PER_CTA,
          WARPS_PER_CTA,
          TILE_VPT><<<num_blocks, block_dim, 0, stream>>>(
          input.data_ptr(),
          bias.data_ptr(),
          output.data_ptr<float>(),
          indices.data_ptr<int32_t>(),
          num_rows,
          topk,
          num_fused_shared_experts,
          routed_scaling_factor);
    } else if (input.scalar_type() == at::kFloat) {
      moe_fused_gate_kernel_tiled_static<
          float32_t,
          64,
          THREADS_PER_ROW,
          ROWS_PER_WARP,
          ROWS_PER_CTA,
          WARPS_PER_CTA,
          TILE_VPT><<<num_blocks, block_dim, 0, stream>>>(
          input.data_ptr(),
          bias.data_ptr(),
          output.data_ptr<float>(),
          indices.data_ptr<int32_t>(),
          num_rows,
          topk,
          num_fused_shared_experts,
          routed_scaling_factor);
    } else {
      TORCH_CHECK(false, "Unsupported dtype for moe_fused_gate_tiled_static");
    }
  } else {
    TORCH_CHECK(false, "moe_fused_gate_tiled_static: unsupported combination");
  }
Contributor


medium

There is a large amount of duplicated code in the host launcher moe_fused_gate_tiled_static. The logic for dispatching based on input.scalar_type() is repeated for num_experts=384 and num_experts=64. This could be significantly simplified by using a helper function or a macro to handle the dtype-based dispatch, reducing boilerplate and improving maintainability. A similar pattern of duplication exists in the moe_fused_gate_tiled host launcher.
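
One way to picture the suggested deduplication is table-driven dispatch. The sketch below is pure Python purely for illustration (the real fix would be a C++ helper or macro; all names here are hypothetical): the launch boilerplate exists once, and the supported (num_experts, dtype) combinations become data.

```python
# Hypothetical table-driven dispatch: one entry per supported configuration,
# so the kernel-launch boilerplate exists in exactly one place.

SUPPORTED_EXPERTS = (64, 384)                 # the two static specializations
SUPPORTED_DTYPES = ("bf16", "fp16", "fp32")   # the three dtype branches

def _launch(num_experts, dtype):
    # Stand-in for the single templated kernel launch site.
    return f"tiled_static<experts={num_experts}, dtype={dtype}>"

def dispatch_tiled_static(num_experts, dtype):
    if num_experts not in SUPPORTED_EXPERTS:
        raise ValueError("moe_fused_gate_tiled_static: unsupported combination")
    if dtype not in SUPPORTED_DTYPES:
        raise ValueError("Unsupported dtype for moe_fused_gate_tiled_static")
    # One launch site replaces the six near-identical branches.
    return _launch(num_experts, dtype)
```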

@ltaodream
Author

ltaodream commented Aug 25, 2025

MoE Fused Gate: add tiled path and static specializations for large VPT (64/384), unify switch-case dispatch, and provide multi-dtype benchmarking.

  • Added a tiled implementation for large VPT and two static specializations for THREADS_PER_ROW=1: (num_experts=64, group=1) and (num_experts=384, group=1). These are exposed via consistent switch-case dispatch using LAUNCH_MOE_GATE_TILED_CONFIG.
  • Kept existing template fast paths for small VPT (e.g., 128/256), and route all other large-VPT cases to the generic tiled kernel.
  • Refactored tiled declarations/macros into moe_fused_gate_tiled.h to keep moe_fused_gate.cu clean.
  • Added a benchmark (bf16/fp16/fp32) comparing Original (eager, compile-static, compile-dynamic) vs SGL Kernel on representative large-VPT configs.

Results of kimi-vl and kimi-k2
(image)

MMMU benchmark result of kimi-vl
(image)

Bench_hf result of kimi-vl

python benchmark/mmmu/bench_hf.py --model-path moonshotai/Kimi-VL-A3B-Instruct

answers saved to: moonshotai/Kimi-VL-A3B-Instruct_val_hf.json
Evaluating...
{'Accounting': {'acc': 0.4, 'num': 30},
 'Agriculture': {'acc': 0.562, 'num': 16},
 'Architecture_and_Engineering': {'acc': 0.367, 'num': 30},
 'Art': {'acc': 0.5, 'num': 30},
 'Art_Theory': {'acc': 0.3, 'num': 30},
 'Basic_Medical_Science': {'acc': 0.667, 'num': 30},
 'Biology': {'acc': 0.333, 'num': 30},
 'Chemistry': {'acc': 0.2, 'num': 30},
 'Clinical_Medicine': {'acc': 0.4, 'num': 30},
 'Computer_Science': {'acc': 0.6, 'num': 30},
 'Design': {'acc': 0.533, 'num': 30},
 'Diagnostics_and_Laboratory_Medicine': {'acc': 0.367, 'num': 30},
 'Economics': {'acc': 0.533, 'num': 30},
 'Electronics': {'acc': 0.333, 'num': 30},
 'Energy_and_Power': {'acc': 0.4, 'num': 30},
 'Finance': {'acc': 0.4, 'num': 30},
 'Geography': {'acc': 0.533, 'num': 30},
 'History': {'acc': 0.5, 'num': 30},
 'Literature': {'acc': 0.828, 'num': 29},
 'Manage': {'acc': 0.333, 'num': 30},
 'Marketing': {'acc': 0.6, 'num': 30},
 'Materials': {'acc': 0.367, 'num': 30},
 'Math': {'acc': 0.4, 'num': 30},
 'Mechanical_Engineering': {'acc': 0.467, 'num': 30},
 'Music': {'acc': 0.267, 'num': 30},
 'Overall': {'acc': 0.44, 'num': 885},
 'Overall-Art and Design': {'acc': 0.4, 'num': 120},
 'Overall-Business': {'acc': 0.453, 'num': 150},
 'Overall-Health and Medicine': {'acc': 0.453, 'num': 150},
 'Overall-Humanities and Social Science': {'acc': 0.563, 'num': 119},
 'Overall-Science': {'acc': 0.353, 'num': 150},
 'Overall-Tech and Engineering': {'acc': 0.434, 'num': 196},
 'Pharmacy': {'acc': 0.467, 'num': 30},
 'Physics': {'acc': 0.3, 'num': 30},
 'Psychology': {'acc': 0.367, 'num': 30},
 'Public_Health': {'acc': 0.367, 'num': 30},
 'Sociology': {'acc': 0.567, 'num': 30}}
eval out saved to moonshotai/Kimi-VL-A3B-Instruct_val_hf.json
Overall accuracy: 0.44

Summary MMMU Benchmark on Kimi-VL-A3B-Instruct Model.
Results as follows:
[bench_sglang.py]: 0.512
[bench_hf.py]: 0.44

Test Results: test_vlm_models.py (#4491) (Overall Acc: 0.5244)

Summary Results

The overall accuracy on the MMMU validation set is 0.5244.

Tasks     Version  Filter  n-shot  Metric    Value   Stderr
mmmu_val  0        none    0       mmmu_acc  0.5244  ± N/A

Click to view detailed scores by category
{
    "Overall-Art and Design": { "num": 120, "acc": 0.7 },
    "Art": { "num": 30, "acc": 0.7 },
    "Art_Theory": { "num": 30, "acc": 0.83333 },
    "Design": { "num": 30, "acc": 0.8 },
    "Music": { "num": 30, "acc": 0.46667 },
    "Overall-Business": { "num": 150, "acc": 0.47333 },
    "Accounting": { "num": 30, "acc": 0.5 },
    "Economics": { "num": 30, "acc": 0.5 },
    "Finance": { "num": 30, "acc": 0.36667 },
    "Manage": { "num": 30, "acc": 0.4 },
    "Marketing": { "num": 30, "acc": 0.6 },
    "Overall-Science": { "num": 150, "acc": 0.47333 },
    "Biology": { "num": 30, "acc": 0.56667 },
    "Chemistry": { "num": 30, "acc": 0.36667 },
    "Geography": { "num": 30, "acc": 0.63333 },
    "Math": { "num": 30, "acc": 0.36667 },
    "Physics": { "num": 30, "acc": 0.43333 },
    "Overall-Health and Medicine": { "num": 150, "acc": 0.47333 },
    "Basic_Medical_Science": { "num": 30, "acc": 0.4 },
    "Clinical_Medicine": { "num": 30, "acc": 0.5 },
    "Diagnostics_and_Laboratory_Medicine": { "num": 30, "acc": 0.36667 },
    "Pharmacy": { "num": 30, "acc": 0.53333 },
    "Public_Health": { "num": 30, "acc": 0.56667 },
    "Overall-Humanities and Social Science": { "num": 120, "acc": 0.65833 },
    "History": { "num": 30, "acc": 0.6 },
    "Literature": { "num": 30, "acc": 0.83333 },
    "Sociology": { "num": 30, "acc": 0.56667 },
    "Psychology": { "num": 30, "acc": 0.63333 },
    "Overall-Tech and Engineering": { "num": 210, "acc": 0.45714 },
    "Agriculture": { "num": 30, "acc": 0.56667 },
    "Architecture_and_Engineering": { "num": 30, "acc": 0.46667 },
    "Computer_Science": { "num": 30, "acc": 0.53333 },
    "Electronics": { "num": 30, "acc": 0.36667 },
    "Energy_and_Power": { "num": 30, "acc": 0.46667 },
    "Materials": { "num": 30, "acc": 0.53333 },
    "Mechanical_Engineering": { "num": 30, "acc": 0.26667 },
    "Overall": { "num": 900, "acc": 0.52444 }
}

@ltaodream
Author

@FlamingoPg I've completed rebasing the branch onto the latest main. The changes are now ready for review.

@FlamingoPg
Collaborator

>       assert output_check, (
            f"Output mismatch at seq_length {seq_length}, dtype {dtype}, "
            f"params {params}, num_fused_shared_experts {num_fused_shared_experts}"
        )
E       AssertionError: Output mismatch at seq_length 1024, dtype torch.float32, params (384, 1, 1, 8), num_fused_shared_experts 2
E       assert False

tests/test_moe_fused_gate.py:102: AssertionError

Some CI is failing, you need to take a look. @ltaodream

@ttaohe

ttaohe commented Aug 27, 2025

>       assert output_check, (
            f"Output mismatch at seq_length {seq_length}, dtype {dtype}, "
            f"params {params}, num_fused_shared_experts {num_fused_shared_experts}"
        )
E       AssertionError: Output mismatch at seq_length 1024, dtype torch.float32, params (384, 1, 1, 8), num_fused_shared_experts 2
E       assert False

tests/test_moe_fused_gate.py:102: AssertionError

Some CI is failing, you need to take a look. @ltaodream

Hi, I encountered this issue when the seq_length is very long (e.g., 32768 or 65536).
The absolute difference in the Top‑K experts’ weights is shown below:
(image)
Could you please let me know whether such differences are expected or indicate a potential problem? Also, I couldn't reproduce the difference reported in CI with the short seq_length (1024); at least on my machine (L20), I haven't encountered this issue. Thank you. @FlamingoPg
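
For context on the question above: small weight differences at long sequence lengths can arise purely from floating-point accumulation order (a tiled reduction sums in a different order than an eager one). A minimal NumPy demo of the effect, and of why a tolerance check rather than bitwise equality is the right comparison (whether the CI mismatch is only this is for the authors to confirm):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=65536).astype(np.float32)

# The same mathematical sum, accumulated in two different orders.
seq_sum = np.float32(0.0)
for v in x:
    seq_sum += v          # strict left-to-right accumulation
np_sum = x.sum()          # NumPy uses pairwise summation internally

# The results may differ in the last bits but agree within tolerance.
assert np.isclose(seq_sum, np_sum, rtol=1e-4, atol=1e-5)
```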
