[LoRA] Support FusedMoE LoRA Triton kernel for mxfp4 #28971
New file `test_modular_oai_triton_moe.py` (+250 lines):

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test modular OAI Triton MoE
"""

import pytest
import torch

from vllm.utils.import_utils import has_triton_kernels

if not has_triton_kernels():
    pytest.skip(
        "triton_kernels not found, skipping all related tests",
        allow_module_level=True,
    )

from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
from triton_kernels.numerics import InFlexData
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout
from triton_kernels.testing import assert_close

from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import mxfp4_w4a16_moe_quant_config
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
    OAITritonExperts,
    UnfusedOAITritonExperts,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.utils import shuffle_weight
from vllm.platforms import current_platform

MNK = [
    (1, 512, 384),
    (1, 2880, 2880),
    (2, 512, 384),
    (2, 2880, 2880),
    (32, 2880, 2880),
    (64, 2880, 2880),
]


def unshuffle_weight(w: torch.Tensor):
    first = w[..., ::2]
    second = w[..., 1::2]
    return torch.concat((first, second), dim=-1)


def make_weights(dtype, k, n, e):
    w1 = torch.randn((e, k, 2 * n), dtype=dtype, device="cuda")
    w1_bias = torch.randn((e, 2 * n), dtype=dtype, device="cuda")

    w2 = torch.randn((e, n, k), dtype=dtype, device="cuda")
    w2_bias = torch.randn((e, k), dtype=dtype, device="cuda")

    w1_tri = w1.clone()
    w2_tri = w2.clone()

    w1_bias_tri = w1_bias.clone()
    w2_bias_tri = w2_bias.clone()
    w1_bias_tri = w1_bias_tri.to(torch.float32)
    w2_bias_tri = w2_bias_tri.to(torch.float32)

    # shuffle weights
    w1_tri = shuffle_weight(w1_tri)
    w1_bias_tri = shuffle_weight(w1_bias_tri)

    # quant triton_weights
    w1_tri, w1_scale_tri = downcast_to_mxfp(w1_tri, torch.uint8, axis=1)
    w1 = upcast_from_mxfp(w1_tri, w1_scale_tri, dtype, axis=1)
    w1 = unshuffle_weight(w1)

    w2_tri, w2_scale_tri = downcast_to_mxfp(w2_tri, torch.uint8, axis=1)
    w2 = upcast_from_mxfp(w2_tri, w2_scale_tri, dtype, axis=1)

    num_warps = 8
    w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
    w_scale_layout, w_scale_layout_opts = (
        layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=1, num_warps=num_warps)
    )

    w1_tri = convert_layout(wrap_torch_tensor(w1_tri, FP4), w_layout, **w_layout_opts)
    w1_scale_tri = convert_layout(
        wrap_torch_tensor(w1_scale_tri),
        w_scale_layout,
        **w_scale_layout_opts,
    )

    w2_tri = convert_layout(wrap_torch_tensor(w2_tri, FP4), w_layout, **w_layout_opts)
    w2_scale_tri = convert_layout(
        wrap_torch_tensor(w2_scale_tri),
        w_scale_layout,
        **w_scale_layout_opts,
    )

    w1_precision_config = PrecisionConfig(
        weight_scale=w1_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData())
    )
    w2_precision_config = PrecisionConfig(
        weight_scale=w2_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData())
    )

    return (
        w1,
        w2,
        w1_bias,
        w2_bias,
        w1_tri,
        w2_tri,
        w1_bias_tri,
        w2_bias_tri,
        w1_precision_config,
        w2_precision_config,
    )


def swiglu(x, alpha: float = 1.702, limit: float = 1.0):
    # Note we add an extra bias of 1 to the linear layer
    x_glu, x_linear = torch.chunk(x, 2, dim=-1)
    if limit is not None:
        x_glu = x_glu.clamp(max=limit)
    out_glu = x_glu * torch.sigmoid(alpha * x_glu)
    if limit is not None:
        x_linear = x_linear.clamp(min=-limit, max=limit)
    return out_glu * (x_linear + 1)


def torch_moe_impl(
    hidden_states: torch.Tensor,  # (M, K)
    w1: torch.Tensor,  # (E, K, 2N)
    w2: torch.Tensor,  # (E, N, K)
    w1_bias: torch.Tensor,  # (E, 2N)
    w2_bias: torch.Tensor,  # (E, K)
    topk_weights: torch.Tensor,  # (M, topk)
    topk_ids: torch.Tensor,  # (M, topk)
):
    w1 = w1[topk_ids, ...]
    w1_bias = w1_bias[topk_ids, ...]
    hidden_states = torch.einsum("bekc,bk->bec", w1, hidden_states) + w1_bias
    hidden_states = swiglu(hidden_states, limit=7)

    w2 = w2[topk_ids, ...]
    w2_bias = w2_bias[topk_ids, ...]
    hidden_states = torch.einsum("bekc,bek->bec", w2, hidden_states) + w2_bias

    # Weighted sum of experts
    hidden_states = torch.einsum("bec,be->bc", hidden_states, topk_weights)
    return hidden_states


def oai_triton_moe_impl(
    x: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: "PrecisionConfig",
    w2_scale: "PrecisionConfig",
    w1_bias: torch.Tensor | None,
    w2_bias: torch.Tensor | None,
    num_experts: int,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    unfused: bool = False,
) -> torch.Tensor:
    quant_config = mxfp4_w4a16_moe_quant_config(
        w1_bias=w1_bias,
        w2_bias=w2_bias,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
    )

    if unfused:
        fused_experts = UnfusedOAITritonExperts(quant_config)
    else:
        fused_experts = OAITritonExperts(quant_config)

    mk = FusedMoEModularKernel(MoEPrepareAndFinalizeNoEP(), fused_experts)

    return mk.forward(
        hidden_states=x,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=True,
        activation="swigluoai",
        global_num_experts=num_experts,
        expert_map=None,
        apply_router_weight_on_input=False,
    )


@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("m,n,k", MNK)
@pytest.mark.parametrize("num_experts", [32, 128])
@pytest.mark.parametrize("topk", [4])
@pytest.mark.parametrize("unfused", [True, False])
def test_oai_triton_moe(
    dtype: torch.dtype,
    m: int,
    n: int,
    k: int,
    num_experts: int,
    topk: int,
    unfused: bool,
):
    current_platform.seed_everything(0)
    (
        w1,
        w2,
        w1_bias,
        w2_bias,
        w1_tri,
        w2_tri,
        w1_bias_tri,
        w2_bias_tri,
        w1_precision_config,
        w2_precision_config,
    ) = make_weights(dtype, k, n, num_experts)

    x = torch.randn((m, k), dtype=dtype, device="cuda")
    router_logits = torch.randn(m, num_experts, device="cuda", dtype=dtype)
    topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1, sorted=True)
    topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)

    with set_current_vllm_config(VllmConfig()):
        out_ref = torch_moe_impl(x, w1, w2, w1_bias, w2_bias, topk_weights, topk_ids)

        out = oai_triton_moe_impl(
            x,
            w1_tri,
            w2_tri,
            w1_precision_config,
            w2_precision_config,
            w1_bias_tri,
            w2_bias_tri,
            num_experts,
            topk_weights,
            topk_ids,
            unfused,
        )

    assert_close(ref=out_ref, tri=out, maxtol=0.025, rmstol=0.005)
```
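One detail worth calling out in `make_weights`: the reference weights are themselves the result of the mxfp4 downcast/upcast round trip that produces the Triton-side tensors, so `assert_close` should be measuring kernel error rather than quantization error. Below is a minimal sketch of that round trip, using the same `triton_kernels` helpers as the test; the shapes are illustrative, not taken from the PR, and it assumes a CUDA device like the test does.

```python
# Sketch only: quantize a weight to mxfp4 along the K axis, then dequantize it
# so the reference matmul sees exactly what the quantized kernel sees.
import torch
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp

e, k, n = 4, 512, 384  # hypothetical expert count and hidden sizes
w = torch.randn((e, k, 2 * n), dtype=torch.bfloat16, device="cuda")

w_q, w_scale = downcast_to_mxfp(w, torch.uint8, axis=1)          # packed FP4 values + block scales
w_ref = upcast_from_mxfp(w_q, w_scale, torch.bfloat16, axis=1)   # dequantized copy for the reference path

# w_q / w_scale feed the Triton kernel; w_ref feeds torch_moe_impl, so any
# mismatch flagged by the test comes from the kernel, not from rounding to mxfp4.
```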
The second diff updates `_inject_lora_into_fused_moe` in the LoRA FusedMoE layer. The imports switch from the `modular_*_fused_moe` helpers to the underlying experts classes and modular-kernel building blocks:

```diff
@@ -20,15 +20,24 @@
     _get_config_dtype_str,
 )
 from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
-    modular_marlin_fused_moe,
+    MarlinExperts,
 )
 from vllm.model_executor.layers.fused_moe.fused_moe import (
-    modular_triton_fused_moe,
+    TritonExperts,
     try_get_optimal_moe_config,
 )
 from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
     FusedMoEModularMethod,
 )
+from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
+    UnfusedOAITritonExperts,
+)
+from vllm.model_executor.layers.fused_moe.modular_kernel import (
+    FusedMoEModularKernel,
+)
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+    MoEPrepareAndFinalizeNoEP,
+)

 from .utils import _get_lora_device
```

Review thread on this hunk:

> **Contributor:** @xyang16 can you add a unit test for gpt-oss lora + triton_kernels. The test can be predicated on …
>
> **Author:** Added test_modular_oai_triton_moe.py. Thanks!

The kernel construction in `_inject_lora_into_fused_moe` is rewritten to build a `FusedMoEModularKernel` from `select_gemm_impl`, with asserts on the experts implementation that gets selected:

```diff
@@ -114,15 +123,23 @@ def _inject_lora_into_fused_moe(self):
         self.base_layer.ensure_moe_quant_config_init()
         quant_config = self.base_layer.quant_method.moe_quant_config

-        m_fused_moe_fn = (
-            modular_triton_fused_moe(
-                quant_config, shared_experts=self.base_layer.shared_experts
-            )
-            if not quant_config.use_mxfp4_w4a16
-            else modular_marlin_fused_moe(
-                quant_config, shared_experts=self.base_layer.shared_experts
-            )
-        )
+        prepare_finalize = MoEPrepareAndFinalizeNoEP()
+        m_fused_moe_fn = FusedMoEModularKernel(
+            prepare_finalize,
+            self.base_layer.quant_method.select_gemm_impl(
+                prepare_finalize, self.base_layer
+            ),
+            self.base_layer.shared_experts,
+            getattr(self.base_layer, "shared_experts_stream", None),
+        )
+        if quant_config.use_mxfp4_w4a16:
+            assert isinstance(
+                m_fused_moe_fn.fused_experts, (MarlinExperts, UnfusedOAITritonExperts)
+            )
+        else:
+            assert isinstance(
+                m_fused_moe_fn.fused_experts, (MarlinExperts, TritonExperts)
+            )

         def fwd_decorator(layer, func):
             def wrapper(*args, **kwargs):
```

Review thread on this hunk:

> **Contributor:** I think we should assert below that we are getting the Experts that we are expecting for the different cases.
>
> **Author:** Added assert.
>
> **Contributor:** Thanks. Can you also assert for the …
>
> **Contributor:** Sorry, I see @jeejeelee's comment about this here #28971 (comment) - @jeejeelee do you think an else assert like … makes sense?
>
> **Collaborator:** @varun-sundar-rabindranath makes sense
>
> **Author:** I have added an else assert. Thanks!
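For readers following the truncated thread above, the agreed outcome is the `if`/`else` pair of `isinstance` asserts shown in the diff. The sketch below restates that dispatch check as a hypothetical standalone helper (the function name is mine, not vLLM's; class names and the `use_mxfp4_w4a16` flag are taken from the diff). Which experts class `select_gemm_impl` actually returns is assumed to depend on the quant config and, presumably, on whether `triton_kernels` is available.

```python
# Hypothetical helper restating the asserts added in _inject_lora_into_fused_moe;
# a sketch for reading the review thread, not new vLLM code.
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import MarlinExperts
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
    UnfusedOAITritonExperts,
)


def assert_expected_experts(m_fused_moe_fn, quant_config) -> None:
    if quant_config.use_mxfp4_w4a16:
        # mxfp4 path: Marlin fallback or the OAI Triton-kernels experts
        assert isinstance(
            m_fused_moe_fn.fused_experts, (MarlinExperts, UnfusedOAITritonExperts)
        )
    else:
        # other quant configs: Marlin or the standard Triton experts
        assert isinstance(
            m_fused_moe_fn.fused_experts, (MarlinExperts, TritonExperts)
        )
```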
> Nice tests! Thanks @xyang16