From b33a1df84d4316728a16b1eb18bc53e6116957e3 Mon Sep 17 00:00:00 2001 From: kilavvy <140459108+kilavvy@users.noreply.github.com> Date: Sun, 22 Jun 2025 10:01:53 +0200 Subject: [PATCH 1/2] Update qwen3_moe.py --- unsloth/kernels/moe/grouped_gemm/reference/layers/qwen3_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/kernels/moe/grouped_gemm/reference/layers/qwen3_moe.py b/unsloth/kernels/moe/grouped_gemm/reference/layers/qwen3_moe.py index 2bc9cc624..37f001aef 100644 --- a/unsloth/kernels/moe/grouped_gemm/reference/layers/qwen3_moe.py +++ b/unsloth/kernels/moe/grouped_gemm/reference/layers/qwen3_moe.py @@ -25,7 +25,7 @@ """ Reference implementation of HF Qwen3 MoE block using grouped gemm. -The Qwen3MoeGroupedGEMMBlock is a reference torch-native implemention. +The Qwen3MoeGroupedGEMMBlock is a reference torch-native implementation. Qwen3MoeFusedGroupedGEMMBlock is a version using the triton grouped gemm kernel. NOTE: This is NOT to be used for production as it contains many extra checks and saves all intermediate results for debugging. From 25fa84bc7403f4955cb8ecea153f817325a12402 Mon Sep 17 00:00:00 2001 From: kilavvy <140459108+kilavvy@users.noreply.github.com> Date: Sun, 22 Jun 2025 10:02:29 +0200 Subject: [PATCH 2/2] Update interface.py --- unsloth/kernels/moe/grouped_gemm/interface.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/moe/grouped_gemm/interface.py b/unsloth/kernels/moe/grouped_gemm/interface.py index 3cb186984..99c58b36e 100644 --- a/unsloth/kernels/moe/grouped_gemm/interface.py +++ b/unsloth/kernels/moe/grouped_gemm/interface.py @@ -114,7 +114,7 @@ def grouped_gemm_forward( - `permute_x`: fuse the permutation of hidden states from token order (original order) to grouped expert order, typically only needed for the first grouped GEMM in an MoE MLP. - When `permute_x` is True, `X` is expected to be of shape (num_tokens, K). - When `permute_x` is False, `X` is expected to be of shape (total_tokens, K) where `total_tokens = num_tokens * topk` AND already permuted to grouped expert order, i.e., hidden states are sorted such that tokens assigned to each expert are contiguous. - - `permute_y`: fused the permuation of the output from expert grouped order back to original token order, typically only needed for the second grouped GEMM in an MoE MLP. + - `permute_y`: fused the permutation of the output from expert grouped order back to original token order, typically only needed for the second grouped GEMM in an MoE MLP. - `fuse_mul_pre`: fuse the multiplication of the routed input with topk_weights, only done in the first grouped GEMM in an MoE MLP as for Llama4. Do not use, since results in performance regression as it interrupts the GEMM mainloop. - `fuse_mul_post`: fuse the multiplication of the routed output with topk_weights, used only when `permute_y` is True. NOTE: this should only be used when using this kernel for inference, not for training. @@ -881,7 +881,7 @@ def grouped_gemm( - `permute_x`: fuse the permutation of hidden states from token order (original order) to grouped expert order, typically only needed for the first grouped GEMM in an MoE MLP. - When `permute_x` is True, `X` is expected to be of shape (num_tokens, K). - When `permute_x` is False, `X` is expected to be of shape (total_tokens, K) where `total_tokens = num_tokens * topk` AND already permuted to grouped expert order, i.e., hidden states are sorted such that tokens assigned to each expert are contiguous. - - `permute_y`: fused the permuation of the output from expert grouped order back to original token order, typically only needed for the second grouped GEMM in an MoE MLP. + - `permute_y`: fused the permutation of the output from expert grouped order back to original token order, typically only needed for the second grouped GEMM in an MoE MLP. - `fuse_mul`: fuse the multiplication of the routed output with topk_weights, used only when `permute_y` is True. NOTE: this should only be used when using this kernel for inference, not for training. X: (M, K) hidden states where M is the num_tokens if `permute_x` is True, otherwise `total_tokens` where `total_tokens = num_tokens * topk`.