diff --git a/tests/kernels/moe/test_modular_oai_triton_moe.py b/tests/kernels/moe/test_modular_oai_triton_moe.py
deleted file mode 100644
index 3361d85e9250..000000000000
--- a/tests/kernels/moe/test_modular_oai_triton_moe.py
+++ /dev/null
@@ -1,250 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""
-Test modular OAI Triton MoE
-"""
-
-import pytest
-import torch
-
-from vllm.utils.import_utils import has_triton_kernels
-
-if not has_triton_kernels():
-    pytest.skip(
-        "triton_kernels not found, skipping all related tests",
-        allow_module_level=True,
-    )
-
-from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
-from triton_kernels.numerics import InFlexData
-from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
-from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
-from triton_kernels.tensor_details import layout
-from triton_kernels.testing import assert_close
-
-from vllm.config import VllmConfig, set_current_vllm_config
-from vllm.model_executor.layers.fused_moe.config import mxfp4_w4a16_moe_quant_config
-from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
-    OAITritonExperts,
-    UnfusedOAITritonExperts,
-)
-from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
-from vllm.model_executor.layers.fused_moe.prepare_finalize import (
-    MoEPrepareAndFinalizeNoEP,
-)
-from vllm.model_executor.layers.utils import shuffle_weight
-from vllm.platforms import current_platform
-
-MNK = [
-    (1, 512, 384),
-    (1, 2880, 2880),
-    (2, 512, 384),
-    (2, 2880, 2880),
-    (32, 2880, 2880),
-    (64, 2880, 2880),
-]
-
-
-def unshuffle_weight(w: torch.Tensor):
-    first = w[..., ::2]
-    second = w[..., 1::2]
-    return torch.concat((first, second), dim=-1)
-
-
-def make_weights(dtype, k, n, e):
-    w1 = torch.randn((e, k, 2 * n), dtype=dtype, device="cuda")
-    w1_bias = torch.randn((e, 2 * n), dtype=dtype, device="cuda")
-
-    w2 = torch.randn((e, n, k), dtype=dtype, device="cuda")
-    w2_bias = torch.randn((e, k), dtype=dtype, device="cuda")
-
-    w1_tri = w1.clone()
-    w2_tri = w2.clone()
-
-    w1_bias_tri = w1_bias.clone()
-    w2_bias_tri = w2_bias.clone()
-    w1_bias_tri = w1_bias_tri.to(torch.float32)
-    w2_bias_tri = w2_bias_tri.to(torch.float32)
-
-    # shuffle weights
-    w1_tri = shuffle_weight(w1_tri)
-    w1_bias_tri = shuffle_weight(w1_bias_tri)
-
-    # quant triton_weights
-    w1_tri, w1_scale_tri = downcast_to_mxfp(w1_tri, torch.uint8, axis=1)
-    w1 = upcast_from_mxfp(w1_tri, w1_scale_tri, dtype, axis=1)
-    w1 = unshuffle_weight(w1)
-
-    w2_tri, w2_scale_tri = downcast_to_mxfp(w2_tri, torch.uint8, axis=1)
-    w2 = upcast_from_mxfp(w2_tri, w2_scale_tri, dtype, axis=1)
-
-    num_warps = 8
-    w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
-    w_scale_layout, w_scale_layout_opts = (
-        layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=1, num_warps=num_warps)
-    )
-
-    w1_tri = convert_layout(wrap_torch_tensor(w1_tri, FP4), w_layout, **w_layout_opts)
-    w1_scale_tri = convert_layout(
-        wrap_torch_tensor(w1_scale_tri),
-        w_scale_layout,
-        **w_scale_layout_opts,
-    )
-
-    w2_tri = convert_layout(wrap_torch_tensor(w2_tri, FP4), w_layout, **w_layout_opts)
-    w2_scale_tri = convert_layout(
-        wrap_torch_tensor(w2_scale_tri),
-        w_scale_layout,
-        **w_scale_layout_opts,
-    )
-
-    w1_precision_config = PrecisionConfig(
-        weight_scale=w1_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData())
-    )
-    w2_precision_config = PrecisionConfig(
-        weight_scale=w2_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData())
-    )
-
-    return (
-        w1,
-        w2,
-        w1_bias,
-        w2_bias,
-        w1_tri,
-        w2_tri,
-        w1_bias_tri,
-        w2_bias_tri,
-        w1_precision_config,
-        w2_precision_config,
-    )
-
-
-def swiglu(x, alpha: float = 1.702, limit: float = 1.0):
-    # Note we add an extra bias of 1 to the linear layer
-    x_glu, x_linear = torch.chunk(x, 2, dim=-1)
-    if limit is not None:
-        x_glu = x_glu.clamp(max=limit)
-    out_glu = x_glu * torch.sigmoid(alpha * x_glu)
-    if limit is not None:
-        x_linear = x_linear.clamp(min=-limit, max=limit)
-    return out_glu * (x_linear + 1)
-
-
-def torch_moe_impl(
-    hidden_states: torch.Tensor,  # (M, K)
-    w1: torch.Tensor,  # (E, K, 2N)
-    w2: torch.Tensor,  # (E, N, K)
-    w1_bias: torch.Tensor,  # (E, 2N)
-    w2_bias: torch.Tensor,  # (E, K)
-    topk_weights: torch.Tensor,  # (M, topk)
-    topk_ids: torch.Tensor,  # (M, topk)
-):
-    w1 = w1[topk_ids, ...]
-    w1_bias = w1_bias[topk_ids, ...]
-    hidden_states = torch.einsum("bekc,bk->bec", w1, hidden_states) + w1_bias
-    hidden_states = swiglu(hidden_states, limit=7)
-
-    w2 = w2[topk_ids, ...]
-    w2_bias = w2_bias[topk_ids, ...]
-    hidden_states = torch.einsum("bekc,bek->bec", w2, hidden_states) + w2_bias
-
-    # Weighted sum of experts
-    hidden_states = torch.einsum("bec,be->bc", hidden_states, topk_weights)
-    return hidden_states
-
-
-def oai_triton_moe_impl(
-    x: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    w1_scale: "PrecisionConfig",
-    w2_scale: "PrecisionConfig",
-    w1_bias: torch.Tensor | None,
-    w2_bias: torch.Tensor | None,
-    num_experts: int,
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    unfused: bool = False,
-) -> torch.Tensor:
-    quant_config = mxfp4_w4a16_moe_quant_config(
-        w1_bias=w1_bias,
-        w2_bias=w2_bias,
-        w1_scale=w1_scale,
-        w2_scale=w2_scale,
-    )
-
-    if unfused:
-        fused_experts = UnfusedOAITritonExperts(quant_config)
-    else:
-        fused_experts = OAITritonExperts(quant_config)
-
-    mk = FusedMoEModularKernel(MoEPrepareAndFinalizeNoEP(), fused_experts)
-
-    return mk.forward(
-        hidden_states=x,
-        w1=w1,
-        w2=w2,
-        topk_weights=topk_weights,
-        topk_ids=topk_ids,
-        inplace=True,
-        activation="swigluoai",
-        global_num_experts=num_experts,
-        expert_map=None,
-        apply_router_weight_on_input=False,
-    )
-
-
-@pytest.mark.skipif(
-    not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
-)
-@pytest.mark.parametrize("dtype", [torch.bfloat16])
-@pytest.mark.parametrize("m,n,k", MNK)
-@pytest.mark.parametrize("num_experts", [32, 128])
-@pytest.mark.parametrize("topk", [4])
-@pytest.mark.parametrize("unfused", [True, False])
-def test_oai_triton_moe(
-    dtype: torch.dtype,
-    m: int,
-    n: int,
-    k: int,
-    num_experts: int,
-    topk: int,
-    unfused: bool,
-):
-    current_platform.seed_everything(0)
-    (
-        w1,
-        w2,
-        w1_bias,
-        w2_bias,
-        w1_tri,
-        w2_tri,
-        w1_bias_tri,
-        w2_bias_tri,
-        w1_precision_config,
-        w2_precision_config,
-    ) = make_weights(dtype, k, n, num_experts)
-
-    x = torch.randn((m, k), dtype=dtype, device="cuda")
-    router_logits = torch.randn(m, num_experts, device="cuda", dtype=dtype)
-    topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1, sorted=True)
-    topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)
-
-    with set_current_vllm_config(VllmConfig()):
-        out_ref = torch_moe_impl(x, w1, w2, w1_bias, w2_bias, topk_weights, topk_ids)
-
-    out = oai_triton_moe_impl(
-        x,
-        w1_tri,
-        w2_tri,
-        w1_precision_config,
-        w2_precision_config,
-        w1_bias_tri,
-        w2_bias_tri,
-        num_experts,
-        topk_weights,
-        topk_ids,
-        unfused,
-    )
-
-    assert_close(ref=out_ref, tri=out, maxtol=0.025, rmstol=0.005)
diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py
index 24cab79a7244..3ad19370962a 100644
--- a/vllm/lora/layers/fused_moe.py
+++ b/vllm/lora/layers/fused_moe.py
@@ -20,24 +20,15 @@
     _get_config_dtype_str,
 )
 from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
-    MarlinExperts,
+    modular_marlin_fused_moe,
 )
 from vllm.model_executor.layers.fused_moe.fused_moe import (
-    TritonExperts,
+    modular_triton_fused_moe,
     try_get_optimal_moe_config,
 )
 from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
     FusedMoEModularMethod,
 )
-from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
-    UnfusedOAITritonExperts,
-)
-from vllm.model_executor.layers.fused_moe.modular_kernel import (
-    FusedMoEModularKernel,
-)
-from vllm.model_executor.layers.fused_moe.prepare_finalize import (
-    MoEPrepareAndFinalizeNoEP,
-)
 
 from .utils import _get_lora_device
 
@@ -123,23 +114,15 @@ def _inject_lora_into_fused_moe(self):
         self.base_layer.ensure_moe_quant_config_init()
         quant_config = self.base_layer.quant_method.moe_quant_config
 
-        prepare_finalize = MoEPrepareAndFinalizeNoEP()
-        m_fused_moe_fn = FusedMoEModularKernel(
-            prepare_finalize,
-            self.base_layer.quant_method.select_gemm_impl(
-                prepare_finalize, self.base_layer
-            ),
-            self.base_layer.shared_experts,
-            getattr(self.base_layer, "shared_experts_stream", None),
-        )
-
-        if quant_config.use_mxfp4_w4a16:
-            assert isinstance(
-                m_fused_moe_fn.fused_experts, (MarlinExperts, UnfusedOAITritonExperts)
+        m_fused_moe_fn = (
+            modular_triton_fused_moe(
+                quant_config, shared_experts=self.base_layer.shared_experts
             )
-        else:
-            assert isinstance(
-                m_fused_moe_fn.fused_experts, (MarlinExperts, TritonExperts)
+            if not quant_config.use_mxfp4_w4a16
+            else modular_marlin_fused_moe(
+                quant_config, shared_experts=self.base_layer.shared_experts
             )
+        )
 
         def fwd_decorator(layer, func):
             def wrapper(*args, **kwargs):
diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
index 0b006e15632e..128507639fdf 100644
--- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
+++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
@@ -5,7 +5,6 @@
 import torch
 
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
-from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import (
     FUSED_MOE_UNQUANTIZED_CONFIG,
@@ -377,148 +376,3 @@ def apply(
             intermediate_cache=workspace2,
             a1q_scale=a1q_scale,
         )
-
-
-class UnfusedOAITritonExperts(BaseOAITritonExperts):
-    """
-    A Triton based MoE expert class that operates on expert standard
-    format and explicitly keeps the activation and reduction (moe_sum) steps
-    unfused from the matmul_ogs kernel. This exposes injection points
-    for activation and moe_sum.
-
-    One use case for it is to inject LoRA modules on the activation and moe_sum.
-    """
-
-    def __init__(self, quant_config: FusedMoEQuantConfig):
-        # TODO (varun) : Enable activation quantization
-        assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
-        super().__init__(quant_config)
-
-    @property
-    def activation_formats(
-        self,
-    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
-        return (
-            mk.FusedMoEActivationFormat.Standard,
-            mk.FusedMoEActivationFormat.Standard,
-        )
-
-    def supports_chunking(self) -> bool:
-        return True
-
-    def workspace_shapes(
-        self,
-        M: int,
-        N: int,
-        K: int,
-        topk: int,
-        global_num_experts: int,
-        local_num_experts: int,
-        expert_tokens_meta: mk.ExpertTokensMetadata | None,
-    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
-        # workspace are allocated inside the kernel
-        workspace1 = (M * topk, N // 2)
-        workspace2 = (M * topk, max(N, K))
-        output = (M, K)
-        return (workspace1, workspace2, output)
-
-    def moe_sum(self, input: torch.Tensor, output: torch.Tensor):
-        ops.moe_sum(input, output)
-
-    def apply(
-        self,
-        output: torch.Tensor,
-        hidden_states: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-        activation: str,
-        global_num_experts: int,
-        expert_map: torch.Tensor | None,
-        a1q_scale: torch.Tensor | None,
-        a2_scale: torch.Tensor | None,
-        workspace13: torch.Tensor,
-        workspace2: torch.Tensor,
-        expert_tokens_meta: mk.ExpertTokensMetadata | None,
-        apply_router_weight_on_input: bool,
-    ):
-        if self.quant_config is None:
-            self.quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
-
-        if expert_map is not None:
-            topk_ids = expert_map[topk_ids]
-
-        local_num_experts = w1.size(0)
-        if global_num_experts == -1:
-            global_num_experts = local_num_experts
-
-        routing_data, gather_indx, scatter_indx = self._make_routing_data(
-            topk_ids, topk_weights, local_num_experts
-        )
-
-        topk = topk_ids.size(1)
-
-        # type check, uint8 means mxfp4
-        assert hidden_states.dtype == torch.bfloat16
-        assert (
-            self.quant_config.w1_bias is None
-            or self.quant_config.w1_bias.dtype == torch.float32
-        )
-        assert (
-            self.quant_config.w2_bias is None
-            or self.quant_config.w2_bias.dtype == torch.float32
-        )
-
-        # Shape check, only check non-mxfp4
-        assert hidden_states.ndim == 2
-        assert hidden_states.shape[-1] == w1.shape[-2]
-        assert w2.shape[-1] == w1.shape[1]
-
-        batch_dim = 1
-        M, K = hidden_states.shape
-        E, _, N = w1.shape
-
-        if global_num_experts == -1:
-            global_num_experts = E
-
-        # Note that the output tensor might be in workspace13
-        intermediate_cache1 = _resize_cache(workspace2, (batch_dim, M * topk, N))
-        intermediate_cache3 = _resize_cache(workspace2, (batch_dim, M * topk, K))
-        intermediate_cache2 = _resize_cache(workspace13, (M * topk, N // 2))
-
-        gammas = routing_data.gate_scal if routing_data else None
-
-        matmul_ogs(
-            hidden_states,
-            w1,
-            self.quant_config.w1_bias,
-            routing_data,
-            gather_indx=gather_indx,
-            precision_config=self.quant_config.w1_precision,
-            gammas=gammas if apply_router_weight_on_input else None,
-            fused_activation=None,
-            y=intermediate_cache1,
-        )
-
-        self.activation(
-            activation, intermediate_cache2, intermediate_cache1.view(-1, N)
-        )
-
-        # matmul_ogs grouped reduction fuse sum across multiple experts:
-        # y[dst_ind // n_expts_act, :] += x[src_ind, :]
-        # Need to set n_expts_act to 1 to unfuse moe_sum
-        routing_data.n_expts_act = 1
-
-        matmul_ogs(
-            intermediate_cache2,
-            w2,
-            self.quant_config.w2_bias,
-            routing_data,
-            scatter_indx=scatter_indx,
-            precision_config=self.quant_config.w2_precision,
-            gammas=None if apply_router_weight_on_input else gammas,
-            y=intermediate_cache3,
-        )
-
-        self.moe_sum(intermediate_cache3.view(-1, topk, K), output)
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index 74036753496d..bc241ac692e2 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -30,7 +30,6 @@
 )
 from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
     OAITritonExperts,
-    UnfusedOAITritonExperts,
 )
 from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
 from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
@@ -84,21 +83,8 @@ def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
     if not current_platform.is_cuda():
         return Mxfp4Backend.NONE
 
-    # If FlashInfer is not available, try either Marlin or Triton
-    triton_kernels_supported = (
-        has_triton_kernels()
-        and is_torch_equal_or_newer("2.8.0")
-        # NOTE: triton_kernels are only confirmed to work on SM90 and SM100
-        # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
-        # SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
-        and (9, 0) <= current_platform.get_device_capability() < (11, 0)
-    )
-    if envs.VLLM_MXFP4_USE_MARLIN or not triton_kernels_supported:
-        logger.info_once("[get_mxfp4_backend_with_lora] Using Marlin backend")
-        return Mxfp4Backend.MARLIN
-
-    logger.info_once("[get_mxfp4_backend_with_lora] Using Triton backend")
-    return Mxfp4Backend.TRITON
+    logger.info_once("[get_mxfp4_backend_with_lora] Using Marlin backend")
+    return Mxfp4Backend.MARLIN
 
 
 def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
@@ -868,8 +854,6 @@ def select_gemm_impl(
         elif self.mxfp4_backend == Mxfp4Backend.MARLIN:
             return MarlinExperts(self.moe_quant_config)
         elif self.mxfp4_backend == Mxfp4Backend.TRITON:
-            if self.moe.is_lora_enabled:
-                return UnfusedOAITritonExperts(self.moe_quant_config)
             return OAITritonExperts(self.moe_quant_config)
         else:
             raise NotImplementedError(