diff --git a/tests/kernels/moe/test_modular_oai_triton_moe.py b/tests/kernels/moe/test_modular_oai_triton_moe.py
new file mode 100644
index 000000000000..3361d85e9250
--- /dev/null
+++ b/tests/kernels/moe/test_modular_oai_triton_moe.py
@@ -0,0 +1,250 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Test modular OAI Triton MoE
+"""
+
+import pytest
+import torch
+
+from vllm.utils.import_utils import has_triton_kernels
+
+if not has_triton_kernels():
+    pytest.skip(
+        "triton_kernels not found, skipping all related tests",
+        allow_module_level=True,
+    )
+
+from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
+from triton_kernels.numerics import InFlexData
+from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
+from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
+from triton_kernels.tensor_details import layout
+from triton_kernels.testing import assert_close
+
+from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.model_executor.layers.fused_moe.config import mxfp4_w4a16_moe_quant_config
+from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
+    OAITritonExperts,
+    UnfusedOAITritonExperts,
+)
+from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+    MoEPrepareAndFinalizeNoEP,
+)
+from vllm.model_executor.layers.utils import shuffle_weight
+from vllm.platforms import current_platform
+
+MNK = [
+    (1, 512, 384),
+    (1, 2880, 2880),
+    (2, 512, 384),
+    (2, 2880, 2880),
+    (32, 2880, 2880),
+    (64, 2880, 2880),
+]
+
+
+def unshuffle_weight(w: torch.Tensor):
+    first = w[..., ::2]
+    second = w[..., 1::2]
+    return torch.concat((first, second), dim=-1)
+
+
+def make_weights(dtype, k, n, e):
+    w1 = torch.randn((e, k, 2 * n), dtype=dtype, device="cuda")
+    w1_bias = torch.randn((e, 2 * n), dtype=dtype, device="cuda")
+
+    w2 = torch.randn((e, n, k), dtype=dtype, device="cuda")
+    w2_bias = torch.randn((e, k), dtype=dtype, device="cuda")
+
+    w1_tri = w1.clone()
+    w2_tri = w2.clone()
+
+    w1_bias_tri = w1_bias.clone()
+    w2_bias_tri = w2_bias.clone()
+    w1_bias_tri = w1_bias_tri.to(torch.float32)
+    w2_bias_tri = w2_bias_tri.to(torch.float32)
+
+    # shuffle weights
+    w1_tri = shuffle_weight(w1_tri)
+    w1_bias_tri = shuffle_weight(w1_bias_tri)
+
+    # quantize triton weights
+    w1_tri, w1_scale_tri = downcast_to_mxfp(w1_tri, torch.uint8, axis=1)
+    w1 = upcast_from_mxfp(w1_tri, w1_scale_tri, dtype, axis=1)
+    w1 = unshuffle_weight(w1)
+
+    w2_tri, w2_scale_tri = downcast_to_mxfp(w2_tri, torch.uint8, axis=1)
+    w2 = upcast_from_mxfp(w2_tri, w2_scale_tri, dtype, axis=1)
+
+    num_warps = 8
+    w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
+    w_scale_layout, w_scale_layout_opts = (
+        layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=1, num_warps=num_warps)
+    )
+
+    w1_tri = convert_layout(wrap_torch_tensor(w1_tri, FP4), w_layout, **w_layout_opts)
+    w1_scale_tri = convert_layout(
+        wrap_torch_tensor(w1_scale_tri),
+        w_scale_layout,
+        **w_scale_layout_opts,
+    )
+
+    w2_tri = convert_layout(wrap_torch_tensor(w2_tri, FP4), w_layout, **w_layout_opts)
+    w2_scale_tri = convert_layout(
+        wrap_torch_tensor(w2_scale_tri),
+        w_scale_layout,
+        **w_scale_layout_opts,
+    )
+
+    w1_precision_config = PrecisionConfig(
+        weight_scale=w1_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData())
+    )
+    w2_precision_config = PrecisionConfig(
+        weight_scale=w2_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData())
+    )
+
+    return (
+        w1,
+        w2,
+        w1_bias,
+        w2_bias,
+        w1_tri,
+        w2_tri,
+        w1_bias_tri,
+        w2_bias_tri,
+        w1_precision_config,
+        w2_precision_config,
+    )
+
+
+def swiglu(x, alpha: float = 1.702, limit: float = 1.0):
+    # Note we add an extra bias of 1 to the linear layer
+    x_glu, x_linear = torch.chunk(x, 2, dim=-1)
+    if limit is not None:
+        x_glu = x_glu.clamp(max=limit)
+    out_glu = x_glu * torch.sigmoid(alpha * x_glu)
+    if limit is not None:
+        x_linear = x_linear.clamp(min=-limit, max=limit)
+    return out_glu * (x_linear + 1)
+
+
+def torch_moe_impl(
+    hidden_states: torch.Tensor,  # (M, K)
+    w1: torch.Tensor,  # (E, K, 2N)
+    w2: torch.Tensor,  # (E, N, K)
+    w1_bias: torch.Tensor,  # (E, 2N)
+    w2_bias: torch.Tensor,  # (E, K)
+    topk_weights: torch.Tensor,  # (M, topk)
+    topk_ids: torch.Tensor,  # (M, topk)
+):
+    w1 = w1[topk_ids, ...]
+    w1_bias = w1_bias[topk_ids, ...]
+    hidden_states = torch.einsum("bekc,bk->bec", w1, hidden_states) + w1_bias
+    hidden_states = swiglu(hidden_states, limit=7)
+
+    w2 = w2[topk_ids, ...]
+    w2_bias = w2_bias[topk_ids, ...]
+    hidden_states = torch.einsum("bekc,bek->bec", w2, hidden_states) + w2_bias
+
+    # Weighted sum of experts
+    hidden_states = torch.einsum("bec,be->bc", hidden_states, topk_weights)
+    return hidden_states
+
+
+def oai_triton_moe_impl(
+    x: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w1_scale: "PrecisionConfig",
+    w2_scale: "PrecisionConfig",
+    w1_bias: torch.Tensor | None,
+    w2_bias: torch.Tensor | None,
+    num_experts: int,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    unfused: bool = False,
+) -> torch.Tensor:
+    quant_config = mxfp4_w4a16_moe_quant_config(
+        w1_bias=w1_bias,
+        w2_bias=w2_bias,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+    )
+
+    if unfused:
+        fused_experts = UnfusedOAITritonExperts(quant_config)
+    else:
+        fused_experts = OAITritonExperts(quant_config)
+
+    mk = FusedMoEModularKernel(MoEPrepareAndFinalizeNoEP(), fused_experts)
+
+    return mk.forward(
+        hidden_states=x,
+        w1=w1,
+        w2=w2,
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        inplace=True,
+        activation="swigluoai",
+        global_num_experts=num_experts,
+        expert_map=None,
+        apply_router_weight_on_input=False,
+    )
+
+
+@pytest.mark.skipif(
+    not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platforms."
+)
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("m,n,k", MNK)
+@pytest.mark.parametrize("num_experts", [32, 128])
+@pytest.mark.parametrize("topk", [4])
+@pytest.mark.parametrize("unfused", [True, False])
+def test_oai_triton_moe(
+    dtype: torch.dtype,
+    m: int,
+    n: int,
+    k: int,
+    num_experts: int,
+    topk: int,
+    unfused: bool,
+):
+    current_platform.seed_everything(0)
+    (
+        w1,
+        w2,
+        w1_bias,
+        w2_bias,
+        w1_tri,
+        w2_tri,
+        w1_bias_tri,
+        w2_bias_tri,
+        w1_precision_config,
+        w2_precision_config,
+    ) = make_weights(dtype, k, n, num_experts)
+
+    x = torch.randn((m, k), dtype=dtype, device="cuda")
+    router_logits = torch.randn(m, num_experts, device="cuda", dtype=dtype)
+    topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1, sorted=True)
+    topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)
+
+    with set_current_vllm_config(VllmConfig()):
+        out_ref = torch_moe_impl(x, w1, w2, w1_bias, w2_bias, topk_weights, topk_ids)
+
+        out = oai_triton_moe_impl(
+            x,
+            w1_tri,
+            w2_tri,
+            w1_precision_config,
+            w2_precision_config,
+            w1_bias_tri,
+            w2_bias_tri,
+            num_experts,
+            topk_weights,
+            topk_ids,
+            unfused,
+        )
+
+    assert_close(ref=out_ref, tri=out, maxtol=0.025, rmstol=0.005)
diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py
index 3ad19370962a..24cab79a7244 100644
--- a/vllm/lora/layers/fused_moe.py
+++ b/vllm/lora/layers/fused_moe.py
@@ -20,15 +20,24 @@
     _get_config_dtype_str,
 )
 from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
-    modular_marlin_fused_moe,
+    MarlinExperts,
 )
 from vllm.model_executor.layers.fused_moe.fused_moe import (
-    modular_triton_fused_moe,
+    TritonExperts,
     try_get_optimal_moe_config,
 )
 from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
     FusedMoEModularMethod,
 )
+from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
+    UnfusedOAITritonExperts,
+)
+from vllm.model_executor.layers.fused_moe.modular_kernel import (
+    FusedMoEModularKernel,
+)
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+    MoEPrepareAndFinalizeNoEP,
+)
 
 from .utils import _get_lora_device
 
@@ -114,15 +123,23 @@
 
         self.base_layer.ensure_moe_quant_config_init()
         quant_config = self.base_layer.quant_method.moe_quant_config
-        m_fused_moe_fn = (
-            modular_triton_fused_moe(
-                quant_config, shared_experts=self.base_layer.shared_experts
+        prepare_finalize = MoEPrepareAndFinalizeNoEP()
+        m_fused_moe_fn = FusedMoEModularKernel(
+            prepare_finalize,
+            self.base_layer.quant_method.select_gemm_impl(
+                prepare_finalize, self.base_layer
+            ),
+            self.base_layer.shared_experts,
+            getattr(self.base_layer, "shared_experts_stream", None),
+        )
+        if quant_config.use_mxfp4_w4a16:
+            assert isinstance(
+                m_fused_moe_fn.fused_experts, (MarlinExperts, UnfusedOAITritonExperts)
             )
-            if not quant_config.use_mxfp4_w4a16
-            else modular_marlin_fused_moe(
-                quant_config, shared_experts=self.base_layer.shared_experts
+        else:
+            assert isinstance(
+                m_fused_moe_fn.fused_experts, (MarlinExperts, TritonExperts)
             )
-        )
 
         def fwd_decorator(layer, func):
             def wrapper(*args, **kwargs):
diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
index 128507639fdf..0b006e15632e 100644
--- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
+++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
@@ -5,6 +5,7 @@
 import torch
 
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import (
     FUSED_MOE_UNQUANTIZED_CONFIG,
@@ -376,3 +377,148 @@
             intermediate_cache=workspace2,
             a1q_scale=a1q_scale,
         )
+
+
+class UnfusedOAITritonExperts(BaseOAITritonExperts):
+    """
+    A Triton-based MoE expert class that operates on the standard expert
+    format and explicitly keeps the activation and reduction (moe_sum) steps
+    unfused from the matmul_ogs kernel. This exposes injection points
+    for the activation and moe_sum.
+
+    One use case is injecting LoRA modules around the activation and moe_sum.
+    """
+
+    def __init__(self, quant_config: FusedMoEQuantConfig):
+        # TODO (varun) : Enable activation quantization
+        assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
+        super().__init__(quant_config)
+
+    @property
+    def activation_formats(
+        self,
+    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
+        return (
+            mk.FusedMoEActivationFormat.Standard,
+            mk.FusedMoEActivationFormat.Standard,
+        )
+
+    def supports_chunking(self) -> bool:
+        return True
+
+    def workspace_shapes(
+        self,
+        M: int,
+        N: int,
+        K: int,
+        topk: int,
+        global_num_experts: int,
+        local_num_experts: int,
+        expert_tokens_meta: mk.ExpertTokensMetadata | None,
+    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
+        # workspaces are allocated inside the kernel
+        workspace1 = (M * topk, N // 2)
+        workspace2 = (M * topk, max(N, K))
+        output = (M, K)
+        return (workspace1, workspace2, output)
+
+    def moe_sum(self, input: torch.Tensor, output: torch.Tensor):
+        ops.moe_sum(input, output)
+
+    def apply(
+        self,
+        output: torch.Tensor,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: torch.Tensor | None,
+        a1q_scale: torch.Tensor | None,
+        a2_scale: torch.Tensor | None,
+        workspace13: torch.Tensor,
+        workspace2: torch.Tensor,
+        expert_tokens_meta: mk.ExpertTokensMetadata | None,
+        apply_router_weight_on_input: bool,
+    ):
+        if self.quant_config is None:
+            self.quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
+
+        if expert_map is not None:
+            topk_ids = expert_map[topk_ids]
+
+        local_num_experts = w1.size(0)
+        if global_num_experts == -1:
+            global_num_experts = local_num_experts
+
+        routing_data, gather_indx, scatter_indx = self._make_routing_data(
+            topk_ids, topk_weights, local_num_experts
+        )
+
+        topk = topk_ids.size(1)
+
+        # type check, uint8 means mxfp4
+        assert hidden_states.dtype == torch.bfloat16
+        assert (
+            self.quant_config.w1_bias is None
+            or self.quant_config.w1_bias.dtype == torch.float32
+        )
+        assert (
+            self.quant_config.w2_bias is None
+            or self.quant_config.w2_bias.dtype == torch.float32
+        )
+
+        # Shape check, only check non-mxfp4
+        assert hidden_states.ndim == 2
+        assert hidden_states.shape[-1] == w1.shape[-2]
+        assert w2.shape[-1] == w1.shape[1]
+
+        batch_dim = 1
+        M, K = hidden_states.shape
+        E, _, N = w1.shape
+
+        if global_num_experts == -1:
+            global_num_experts = E
+
+        # Note that the output tensor might be in workspace13
+        intermediate_cache1 = _resize_cache(workspace2, (batch_dim, M * topk, N))
+        intermediate_cache3 = _resize_cache(workspace2, (batch_dim, M * topk, K))
+        intermediate_cache2 = _resize_cache(workspace13, (M * topk, N // 2))
+
+        gammas = routing_data.gate_scal if routing_data else None
+
+        matmul_ogs(
+            hidden_states,
+            w1,
+            self.quant_config.w1_bias,
+            routing_data,
+            gather_indx=gather_indx,
+            precision_config=self.quant_config.w1_precision,
+            gammas=gammas if apply_router_weight_on_input else None,
+            fused_activation=None,
+            y=intermediate_cache1,
+        )
+
+        self.activation(
+            activation, intermediate_cache2, intermediate_cache1.view(-1, N)
+        )
+
+        # matmul_ogs' grouped reduction fuses the sum across multiple experts:
+        # y[dst_ind // n_expts_act, :] += x[src_ind, :]
+        # Need to set n_expts_act to 1 to unfuse moe_sum
+        routing_data.n_expts_act = 1
+
+        matmul_ogs(
+            intermediate_cache2,
+            w2,
+            self.quant_config.w2_bias,
+            routing_data,
+            scatter_indx=scatter_indx,
+            precision_config=self.quant_config.w2_precision,
+            gammas=None if apply_router_weight_on_input else gammas,
+            y=intermediate_cache3,
+        )
+
+        self.moe_sum(intermediate_cache3.view(-1, topk, K), output)
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index d975131f7cff..6c888e97391a 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -29,6 +29,7 @@
 )
 from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
     OAITritonExperts,
+    UnfusedOAITritonExperts,
 )
 from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
 from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
@@ -82,8 +83,21 @@
     if not current_platform.is_cuda():
         return Mxfp4Backend.NONE
 
-    logger.info_once("[get_mxfp4_backend_with_lora] Using Marlin backend")
-    return Mxfp4Backend.MARLIN
+    # If FlashInfer is not available, try either Marlin or Triton
+    triton_kernels_supported = (
+        has_triton_kernels()
+        and is_torch_equal_or_newer("2.8.0")
+        # NOTE: triton_kernels are only confirmed to work on SM90 and SM100
+        # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
+        # SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
+        and (9, 0) <= current_platform.get_device_capability() < (11, 0)
+    )
+    if envs.VLLM_MXFP4_USE_MARLIN or not triton_kernels_supported:
+        logger.info_once("[get_mxfp4_backend_with_lora] Using Marlin backend")
+        return Mxfp4Backend.MARLIN
+
+    logger.info_once("[get_mxfp4_backend_with_lora] Using Triton backend")
+    return Mxfp4Backend.TRITON
 
 
 def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
@@ -855,6 +869,8 @@
         elif self.mxfp4_backend == Mxfp4Backend.MARLIN:
             return MarlinExperts(self.moe_quant_config)
         elif self.mxfp4_backend == Mxfp4Backend.TRITON:
+            if self.moe.is_lora_enabled:
+                return UnfusedOAITritonExperts(self.moe_quant_config)
             return OAITritonExperts(self.moe_quant_config)
         else:
             raise NotImplementedError(
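
The pivotal piece of UnfusedOAITritonExperts.apply is the n_expts_act handling: the matmul_ogs grouped reduction normally scatters the second GEMM's rows as y[dst_ind // n_expts_act, :] += x[src_ind, :], folding a token's topk expert rows into one output row. Forcing n_expts_act = 1 keeps one row per (token, expert) pair, so the reduction becomes the explicit moe_sum call that a LoRA hook can wrap. The snippet below is a minimal plain-PyTorch sketch of that equivalence, not the Triton kernel itself; it assumes rows are laid out token-major, whereas the real kernel orders them through scatter_indx and applies the routing weights (gammas) inside the GEMM.

import torch

# Toy sizes; any values work.
M, topk, K = 4, 2, 8
y = torch.randn(M * topk, K)    # per-(token, expert) outputs of the second GEMM
gate = torch.rand(M * topk, 1)  # routing weight applied to each row

# Fused-style reduction: the destination row is src_ind // n_expts_act, so the
# topk rows belonging to the same token are accumulated in place.
fused = torch.zeros(M, K)
fused.index_add_(0, torch.arange(M * topk) // topk, y * gate)

# Unfused path: with n_expts_act == 1 every row keeps its own slot, and the
# reduction becomes an explicit sum -- the step exposed as moe_sum above.
unfused = (y * gate).view(M, topk, K).sum(dim=1)

torch.testing.assert_close(fused, unfused)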