4 changes: 2 additions & 2 deletions unsloth/kernels/moe/grouped_gemm/interface.py
@@ -114,7 +114,7 @@ def grouped_gemm_forward(
- `permute_x`: fuse the permutation of hidden states from token order (original order) to grouped expert order, typically only needed for the first grouped GEMM in an MoE MLP.
- When `permute_x` is True, `X` is expected to be of shape (num_tokens, K).
- When `permute_x` is False, `X` is expected to be of shape (total_tokens, K) where `total_tokens = num_tokens * topk` AND already permuted to grouped expert order, i.e., hidden states are sorted such that tokens assigned to each expert are contiguous.
- `permute_y`: fused the permuation of the output from expert grouped order back to original token order, typically only needed for the second grouped GEMM in an MoE MLP.
- `permute_y`: fused the permutation of the output from expert grouped order back to original token order, typically only needed for the second grouped GEMM in an MoE MLP.
- `fuse_mul_pre`: fuse the multiplication of the routed input with topk_weights, only done in the first grouped GEMM in an MoE MLP as for Llama4. Do not use, since results in performance regression as it interrupts the GEMM mainloop.
- `fuse_mul_post`: fuse the multiplication of the routed output with topk_weights, used only when `permute_y` is True. NOTE: this should only be used when using this kernel for inference, not for training.
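To make the `permute_x` / `permute_y` semantics above concrete, here is a minimal torch sketch of the reordering the kernel fuses. The router setup and helper names (`gather_idx`, `flat_expert_ids`) are illustrative assumptions for this sketch, not part of the kernel's API.

```python
import torch

# Toy sizes; purely illustrative.
num_tokens, K, topk, num_experts = 8, 16, 2, 4
X = torch.randn(num_tokens, K)                                   # token-order hidden states
expert_ids = torch.topk(torch.randn(num_tokens, num_experts), topk).indices  # router choices

# Grouped expert order: rows routed to the same expert become contiguous.
flat_expert_ids = expert_ids.flatten()                           # (num_tokens * topk,)
gather_idx = torch.argsort(flat_expert_ids, stable=True)         # token order -> expert-grouped order

# What `permute_x=True` fuses: replicate each token topk times and permute into
# expert-grouped order before the first grouped GEMM.
X_grouped = X.repeat_interleave(topk, dim=0)[gather_idx]         # (total_tokens, K)

# ... per-expert GEMMs run over contiguous row blocks of X_grouped ...
Y_grouped = X_grouped                                            # stand-in for expert outputs

# What `permute_y=True` fuses: scatter the second GEMM's output back to the
# original token order.
Y = torch.empty_like(Y_grouped)
Y[gather_idx] = Y_grouped
```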

@@ -881,7 +881,7 @@ def grouped_gemm(
- `permute_x`: fuse the permutation of hidden states from token order (original order) to grouped expert order, typically only needed for the first grouped GEMM in an MoE MLP.
- When `permute_x` is True, `X` is expected to be of shape (num_tokens, K).
- When `permute_x` is False, `X` is expected to be of shape (total_tokens, K) where `total_tokens = num_tokens * topk` AND already permuted to grouped expert order, i.e., hidden states are sorted such that tokens assigned to each expert are contiguous.
- `permute_y`: fused the permuation of the output from expert grouped order back to original token order, typically only needed for the second grouped GEMM in an MoE MLP.
- `permute_y`: fused the permutation of the output from expert grouped order back to original token order, typically only needed for the second grouped GEMM in an MoE MLP.
- `fuse_mul`: fuse the multiplication of the routed output with topk_weights, used only when `permute_y` is True. NOTE: this should only be used when using this kernel for inference, not for training.

X: (M, K) hidden states where M is the num_tokens if `permute_x` is True, otherwise `total_tokens` where `total_tokens = num_tokens * topk`.
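The shape convention for `X` and the typical flag pattern for the two grouped GEMMs of an MoE MLP can be spelled out with a small worked example. The numbers below are illustrative only, and the full `grouped_gemm` signature lives in interface.py rather than in this excerpt.

```python
# Illustrative shapes only.
num_tokens, topk, K = 128, 2, 4096
total_tokens = num_tokens * topk                 # 256

# First grouped GEMM (typically the up/gate projection), per the docstring:
#   permute_x=True  -> X has shape (num_tokens, K) = (128, 4096); the kernel
#                      fuses the token-order -> expert-grouped permutation.
# Second grouped GEMM (typically the down projection):
#   permute_x=False -> X has shape (total_tokens, K) = (256, 4096), already
#                      sorted so each expert's rows are contiguous.
#   permute_y=True  -> the output is returned in original token order.
#   fuse_mul=True   -> rows are additionally scaled by topk_weights
#                      (inference only, per the NOTE above).
```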
@@ -25,7 +25,7 @@
"""
Reference implementation of HF Qwen3 MoE block using grouped gemm.

The Qwen3MoeGroupedGEMMBlock is a reference torch-native implemention.
The Qwen3MoeGroupedGEMMBlock is a reference torch-native implementation.
Qwen3MoeFusedGroupedGEMMBlock is a version using the triton grouped gemm kernel.

NOTE: This is NOT to be used for production as it contains many extra checks and saves all intermediate results for debugging.
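Since the reference block exists for debugging, the natural use is to cross-check the fused kernel's output against it. The constructor and forward signatures below are assumptions for illustration; this diff does not show them, and the weight-copy step assumes the two blocks share a compatible parameter layout.

```python
import torch

# Hypothetical parity check; constructor/forward signatures are assumed, not
# taken from this diff.
ref_block = Qwen3MoeGroupedGEMMBlock(config).cuda()            # torch-native reference
fused_block = Qwen3MoeFusedGroupedGEMMBlock(config).cuda()     # triton grouped GEMM version
fused_block.load_state_dict(ref_block.state_dict())            # assumes compatible parameter layouts

x = torch.randn(2, 64, config.hidden_size, device="cuda", dtype=torch.bfloat16)
torch.testing.assert_close(fused_block(x), ref_block(x), rtol=1e-2, atol=1e-2)
```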