[Feature] Add tile-based method of supporting large VPT in moe_fused_gate kernel #9579
ltaodream wants to merge 1 commit into sgl-project:main from
Conversation
Summary of Changes
Hello @ltaodream, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly enhances the moe_fused_gate kernel by introducing a tile-based method to support larger VPT (values-per-thread) configurations. This change is crucial for enabling the use of the fused MoE gate operator with models such as Kimi-VL and Kimi-K2, which have more experts per group (64 and 384, respectively) than previously supported. The update ensures broader compatibility and improved performance for large-scale Mixture of Experts models.
Highlights
- Expanded MoE Support: The moe_fused_gate kernel now supports larger numbers of experts per group (VPT) beyond the previous limit of 32, up to 512, enabling compatibility with models like Kimi-VL (64 experts) and Kimi-K2 (384 experts).
- New Tiled CUDA Kernel: A new CUDA kernel, moe_fused_gate_tile_more_experts.cu, has been implemented to handle these larger VPT configurations efficiently using a tile-based approach.
- Dynamic Dispatch: The existing moe_fused_gate function now intelligently dispatches to the new tiled kernel when the computed VPT exceeds the original 32-expert limit.
- Flexible Expert Count: The constraint requiring num_experts to be a power of 2 has been relaxed to specifically allow 384 experts, accommodating the Kimi-K2 model.
- Improved Shared Expert Handling: The logic for assigning and weighting fused shared experts in topk.py has been refined for correctness.
- Benchmarking and Testing: New benchmark and test cases have been added to validate the functionality and performance of the tiled kernel for larger expert counts.
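The dispatch rule described in the highlights can be sketched in Python (a hypothetical illustration only: the function name and kernel labels are placeholders, and computing VPT as experts per thread group is an assumption based on the summary; the real dispatch happens in the C++ launcher):

```python
# Hypothetical sketch of the VPT-based dispatch described above.
# Kernel names are placeholders, not the actual C++ symbols.

MAX_VPT_ORIGINAL = 32  # previous per-thread limit of the kernel
MAX_VPT_TILED = 512    # new upper bound with the tile-based path


def select_kernel(num_experts: int, num_expert_group: int) -> str:
    """Return which kernel variant would handle this configuration."""
    # Assumption: VPT is the number of expert values each thread group handles.
    vpt = num_experts // num_expert_group
    if vpt <= MAX_VPT_ORIGINAL:
        return "moe_fused_gate"        # original single-pass kernel
    if vpt <= MAX_VPT_TILED:
        return "moe_fused_gate_tiled"  # new tile-based kernel
    raise ValueError(f"VPT {vpt} exceeds supported limit {MAX_VPT_TILED}")
```

Under this sketch, Kimi-VL (64 experts, 1 group) and Kimi-K2 (384 experts, 1 group) both fall into the tiled path.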
Code Review
This pull request introduces a tile-based method to support a larger number of experts in the moe_fused_gate kernel, enabling support for models like Kimi-VL and Kimi-K2. The changes are comprehensive, including new CUDA kernels, updates to Python logic, and new tests and benchmarks. The implementation is solid, but I have identified a few areas for improvement. My feedback focuses on enhancing numerical stability by preventing potential division-by-zero errors, improving maintainability by reducing code duplication, and addressing a silent limitation in the new static kernel. Overall, this is a valuable feature addition.
-        else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
+        else topk_weights[:, :-num_fused_shared_experts].sum(dim=-1, keepdim=True)
     )
     topk_weights = topk_weights / topk_weights_sum
The division topk_weights / topk_weights_sum could result in NaN if topk_weights_sum is zero. This can happen if all selected expert weights are zero. To prevent this, it's safer to add a small epsilon to the denominator for numerical stability, as is done in the new benchmark file (bench_moe_fused_gate_tiled.py).
Suggested change:
-    topk_weights = topk_weights / topk_weights_sum
+    topk_weights = topk_weights / (topk_weights_sum + 1e-9)
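The effect of the suggested epsilon can be seen in a minimal standalone torch sketch (an illustration of the failure mode and the fix, not the actual topk.py code):

```python
import torch


def renormalize(topk_weights: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    """Renormalize weights per row; eps keeps an all-zero row from producing NaN."""
    topk_weights_sum = topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights / (topk_weights_sum + eps)


w = torch.tensor([[0.2, 0.3], [0.0, 0.0]])

# Without eps, the all-zero second row divides 0/0 and yields NaN.
unsafe = w / w.sum(dim=-1, keepdim=True)

# With eps, the zero row stays all-zero instead of NaN.
safe = renormalize(w)
```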
-        else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
+        else topk_weights[:, :-num_fused_shared_experts].sum(dim=-1, keepdim=True)
     )
     topk_weights = topk_weights / topk_weights_sum
Similar to the grouped_topk_gpu function, the division topk_weights / topk_weights_sum here can lead to NaN values if topk_weights_sum is zero. Please add a small epsilon to the denominator for numerical stability.
Suggested change:
-    topk_weights = topk_weights / topk_weights_sum
+    topk_weights = topk_weights / (topk_weights_sum + 1e-9)
if (thread_group_idx == 0) {
  float denom = output_sum;
  for (int i = 0; i < topk; ++i) {
    int64_t idx = topk * thread_row + i;
    output_ptr[idx] = output_ptr[idx] / denom;
  }
}
The renormalization loop can perform a division by zero if output_sum (aliased as denom) is zero, which would result in NaN or inf in the output weights. The static version of this kernel (moe_fused_gate_kernel_tiled_static) correctly handles this by checking if real_sum > 0.0f. A similar check should be added here for robustness.
if (thread_group_idx == 0) {
float denom = output_sum;
if (denom > 0.0f) {
for (int i = 0; i < topk; ++i) {
int64_t idx = topk * thread_row + i;
output_ptr[idx] = output_ptr[idx] / denom;
}
}
}
 if num_fused_shared_experts:
-    topk_ids[:, -1] = torch.randint(
-        low=num_experts,
-        high=num_experts + num_fused_shared_experts,
-        size=(topk_ids.size(0),),
-        dtype=topk_ids.dtype,
+    assert (
+        topk >= num_fused_shared_experts
+    ), "topk must be >= num_fused_shared_experts"
+    # Assign the last N ids to all shared expert ids [num_experts, num_experts+N)
+    shared_ids = torch.arange(
+        num_experts,
+        num_experts + num_fused_shared_experts,
+        device=topk_ids.device,
+        dtype=topk_ids.dtype,
+    )
+    topk_ids[:, -num_fused_shared_experts:] = shared_ids.unsqueeze(0).expand(
+        topk_ids.size(0), -1
+    )
-    topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
+    # Set each shared expert's weight to sum(real_experts)/routed_scaling_factor
+    real_sum = topk_weights[:, :-num_fused_shared_experts].sum(dim=-1)
+    shared_weight = real_sum / routed_scaling_factor
+    topk_weights[:, -num_fused_shared_experts:] = shared_weight.unsqueeze(
+        -1
+    ).expand(-1, num_fused_shared_experts)
There is significant code duplication between this function (biased_grouped_topk_impl) and grouped_topk_gpu for handling num_fused_shared_experts. The logic from lines 584-603 is nearly identical to lines 471-490. Consider refactoring this shared logic into a helper function to improve maintainability and reduce redundancy.
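One way the duplicated block could be factored out, assuming the tensors shown in the diff above (a sketch under those assumptions, not a drop-in patch for topk.py):

```python
import torch


def apply_fused_shared_experts(
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    num_experts: int,
    num_fused_shared_experts: int,
    routed_scaling_factor: float,
):
    """Hypothetical shared helper: assign shared-expert ids and weights in place.

    Mirrors the duplicated logic in grouped_topk_gpu and biased_grouped_topk_impl.
    """
    if num_fused_shared_experts == 0:
        return topk_ids, topk_weights
    assert (
        topk_ids.size(1) >= num_fused_shared_experts
    ), "topk must be >= num_fused_shared_experts"
    # Last N columns get the shared expert ids [num_experts, num_experts + N).
    shared_ids = torch.arange(
        num_experts,
        num_experts + num_fused_shared_experts,
        device=topk_ids.device,
        dtype=topk_ids.dtype,
    )
    topk_ids[:, -num_fused_shared_experts:] = shared_ids.unsqueeze(0).expand(
        topk_ids.size(0), -1
    )
    # Each shared expert's weight is sum(real_experts) / routed_scaling_factor.
    real_sum = topk_weights[:, :-num_fused_shared_experts].sum(dim=-1)
    shared_weight = real_sum / routed_scaling_factor
    topk_weights[:, -num_fused_shared_experts:] = shared_weight.unsqueeze(-1).expand(
        -1, num_fused_shared_experts
    )
    return topk_ids, topk_weights
```

Both call sites would then reduce to a single call to this helper.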
const int MAX_TOPK = 32;
topk_excl_shared = min(topk_excl_shared, MAX_TOPK);
The static kernel uses a fixed-size array best_choice[MAX_TOPK] with MAX_TOPK = 32. The code then silently truncates topk_excl_shared to this limit. This could lead to incorrect behavior if a model requires topk > 32. It would be safer to add a static_assert or a TORCH_CHECK to ensure topk does not exceed MAX_TOPK.
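The failure mode and the suggested guard can be illustrated from the Python side (a hypothetical wrapper-level check; in the kernel itself a TORCH_CHECK or static_assert would serve the same purpose):

```python
MAX_TOPK = 32  # size of the fixed best_choice array in the static kernel


def check_static_kernel_topk(topk: int, num_fused_shared_experts: int) -> None:
    """Fail loudly instead of silently truncating when topk exceeds the kernel limit.

    Hypothetical guard: the real fix would live in the C++ launcher or kernel.
    """
    topk_excl_shared = topk - num_fused_shared_experts
    if topk_excl_shared > MAX_TOPK:
        raise ValueError(
            f"topk={topk} (excluding {num_fused_shared_experts} shared experts) "
            f"exceeds the static kernel limit MAX_TOPK={MAX_TOPK}"
        )
```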
// Currently: specialize THREADS_PER_ROW=1 for selected NUM_EXPERTS with TILE=32
if (num_experts == 384 && num_expert_group == 1) {
  constexpr int THREADS_PER_ROW = 1;
  constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
  constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;
  int64_t rows_per_warp = ROWS_PER_WARP;
  int64_t num_warps = (num_rows + rows_per_warp - 1) / rows_per_warp;
  int64_t num_blocks = (num_warps + WARPS_PER_CTA - 1) / WARPS_PER_CTA;

  if (input.scalar_type() == at::kBFloat16) {
    moe_fused_gate_kernel_tiled_static<
        bfloat16_t,
        384,
        THREADS_PER_ROW,
        ROWS_PER_WARP,
        ROWS_PER_CTA,
        WARPS_PER_CTA,
        TILE_VPT><<<num_blocks, block_dim, 0, stream>>>(
        input.data_ptr(),
        bias.data_ptr(),
        output.data_ptr<float>(),
        indices.data_ptr<int32_t>(),
        num_rows,
        topk,
        num_fused_shared_experts,
        routed_scaling_factor);
  } else if (input.scalar_type() == at::kHalf) {
    moe_fused_gate_kernel_tiled_static<
        float16_t,
        384,
        THREADS_PER_ROW,
        ROWS_PER_WARP,
        ROWS_PER_CTA,
        WARPS_PER_CTA,
        TILE_VPT><<<num_blocks, block_dim, 0, stream>>>(
        input.data_ptr(),
        bias.data_ptr(),
        output.data_ptr<float>(),
        indices.data_ptr<int32_t>(),
        num_rows,
        topk,
        num_fused_shared_experts,
        routed_scaling_factor);
  } else if (input.scalar_type() == at::kFloat) {
    moe_fused_gate_kernel_tiled_static<
        float32_t,
        384,
        THREADS_PER_ROW,
        ROWS_PER_WARP,
        ROWS_PER_CTA,
        WARPS_PER_CTA,
        TILE_VPT><<<num_blocks, block_dim, 0, stream>>>(
        input.data_ptr(),
        bias.data_ptr(),
        output.data_ptr<float>(),
        indices.data_ptr<int32_t>(),
        num_rows,
        topk,
        num_fused_shared_experts,
        routed_scaling_factor);
  } else {
    TORCH_CHECK(false, "Unsupported dtype for moe_fused_gate_tiled_static");
  }
} else if (num_experts == 64 && num_expert_group == 1) {
  constexpr int THREADS_PER_ROW = 1;
  constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
  constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;
  int64_t rows_per_warp = ROWS_PER_WARP;
  int64_t num_warps = (num_rows + rows_per_warp - 1) / rows_per_warp;
  int64_t num_blocks = (num_warps + WARPS_PER_CTA - 1) / WARPS_PER_CTA;

  if (input.scalar_type() == at::kBFloat16) {
    moe_fused_gate_kernel_tiled_static<
        bfloat16_t,
        64,
        THREADS_PER_ROW,
        ROWS_PER_WARP,
        ROWS_PER_CTA,
        WARPS_PER_CTA,
        TILE_VPT><<<num_blocks, block_dim, 0, stream>>>(
        input.data_ptr(),
        bias.data_ptr(),
        output.data_ptr<float>(),
        indices.data_ptr<int32_t>(),
        num_rows,
        topk,
        num_fused_shared_experts,
        routed_scaling_factor);
  } else if (input.scalar_type() == at::kHalf) {
    moe_fused_gate_kernel_tiled_static<
        float16_t,
        64,
        THREADS_PER_ROW,
        ROWS_PER_WARP,
        ROWS_PER_CTA,
        WARPS_PER_CTA,
        TILE_VPT><<<num_blocks, block_dim, 0, stream>>>(
        input.data_ptr(),
        bias.data_ptr(),
        output.data_ptr<float>(),
        indices.data_ptr<int32_t>(),
        num_rows,
        topk,
        num_fused_shared_experts,
        routed_scaling_factor);
  } else if (input.scalar_type() == at::kFloat) {
    moe_fused_gate_kernel_tiled_static<
        float32_t,
        64,
        THREADS_PER_ROW,
        ROWS_PER_WARP,
        ROWS_PER_CTA,
        WARPS_PER_CTA,
        TILE_VPT><<<num_blocks, block_dim, 0, stream>>>(
        input.data_ptr(),
        bias.data_ptr(),
        output.data_ptr<float>(),
        indices.data_ptr<int32_t>(),
        num_rows,
        topk,
        num_fused_shared_experts,
        routed_scaling_factor);
  } else {
    TORCH_CHECK(false, "Unsupported dtype for moe_fused_gate_tiled_static");
  }
} else {
  TORCH_CHECK(false, "moe_fused_gate_tiled_static: unsupported combination");
}
There is a large amount of duplicated code in the host launcher moe_fused_gate_tiled_static. The logic for dispatching based on input.scalar_type() is repeated for num_experts=384 and num_experts=64. This could be significantly simplified by using a helper function or a macro to handle the dtype-based dispatch, reducing boilerplate and improving maintainability. A similar pattern of duplication exists in the moe_fused_gate_tiled host launcher.
MoE Fused Gate: add tiled path and static specializations for large VPT (64/384), unify switch-case dispatch, and provide multi-dtype benchmarking.
Results of Kimi-VL and Kimi-K2
bench_hf result for Kimi-VL: MMMU benchmark on the Kimi-VL-A3B-Instruct model, tested via test_vlm_models.py (#4491). Summary: the overall accuracy on the MMMU validation set is 0.5244.
Detailed scores by category:
{
"Overall-Art and Design": { "num": 120, "acc": 0.7 },
"Art": { "num": 30, "acc": 0.7 },
"Art_Theory": { "num": 30, "acc": 0.83333 },
"Design": { "num": 30, "acc": 0.8 },
"Music": { "num": 30, "acc": 0.46667 },
"Overall-Business": { "num": 150, "acc": 0.47333 },
"Accounting": { "num": 30, "acc": 0.5 },
"Economics": { "num": 30, "acc": 0.5 },
"Finance": { "num": 30, "acc": 0.36667 },
"Manage": { "num": 30, "acc": 0.4 },
"Marketing": { "num": 30, "acc": 0.6 },
"Overall-Science": { "num": 150, "acc": 0.47333 },
"Biology": { "num": 30, "acc": 0.56667 },
"Chemistry": { "num": 30, "acc": 0.36667 },
"Geography": { "num": 30, "acc": 0.63333 },
"Math": { "num": 30, "acc": 0.36667 },
"Physics": { "num": 30, "acc": 0.43333 },
"Overall-Health and Medicine": { "num": 150, "acc": 0.47333 },
"Basic_Medical_Science": { "num": 30, "acc": 0.4 },
"Clinical_Medicine": { "num": 30, "acc": 0.5 },
"Diagnostics_and_Laboratory_Medicine": { "num": 30, "acc": 0.36667 },
"Pharmacy": { "num": 30, "acc": 0.53333 },
"Public_Health": { "num": 30, "acc": 0.56667 },
"Overall-Humanities and Social Science": { "num": 120, "acc": 0.65833 },
"History": { "num": 30, "acc": 0.6 },
"Literature": { "num": 30, "acc": 0.83333 },
"Sociology": { "num": 30, "acc": 0.56667 },
"Psychology": { "num": 30, "acc": 0.63333 },
"Overall-Tech and Engineering": { "num": 210, "acc": 0.45714 },
"Agriculture": { "num": 30, "acc": 0.56667 },
"Architecture_and_Engineering": { "num": 30, "acc": 0.46667 },
"Computer_Science": { "num": 30, "acc": 0.53333 },
"Electronics": { "num": 30, "acc": 0.36667 },
"Energy_and_Power": { "num": 30, "acc": 0.46667 },
"Materials": { "num": 30, "acc": 0.53333 },
"Mechanical_Engineering": { "num": 30, "acc": 0.26667 },
"Overall": { "num": 900, "acc": 0.52444 }
}
@FlamingoPg I've completed rebasing the branch onto the latest main. The changes are now ready for review.
> assert output_check, (
      f"Output mismatch at seq_length {seq_length}, dtype {dtype}, "
      f"params {params}, num_fused_shared_experts {num_fused_shared_experts}"
  )
E AssertionError: Output mismatch at seq_length 1024, dtype torch.float32, params (384, 1, 1, 8), num_fused_shared_experts 2
E assert False
tests/test_moe_fused_gate.py:102: AssertionError

Some CI is failing, you need to take a look. @ltaodream
Hi, I encountered this issue when the seq_length is very long (e.g., 32768 or 65536).


Motivation
Kimi-VL (supported since #5383) and the Kimi-K2 model cannot currently use the fused MoE gate operator, because the operator does not support more than 32 experts per group. Kimi-VL has 64 experts per group and Kimi-K2 has 384, so the operator cannot be used with either model at present.
Modifications
@ttaohe, @Misaka9468, and I completed this PR, "Add tile-based method of supporting large VPT in moe_fused_gate kernel", together.
Since rebasing rewrote the commit history, GitHub couldn't properly update the original PR, so I closed #6946 and opened this fresh PR with the cleaned-up single commit for easier review. Original PR for reference: #6946 ([Feature] Add tile-based method of supporting large VPT in moe_fused_gate kernel).
Checklist