250 changes: 250 additions & 0 deletions tests/kernels/moe/test_modular_oai_triton_moe.py
@@ -0,0 +1,250 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test modular OAI Triton MoE
"""

import pytest
import torch

from vllm.utils.import_utils import has_triton_kernels

if not has_triton_kernels():
    pytest.skip(
        "triton_kernels not found, skipping all related tests",
        allow_module_level=True,
    )

from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
from triton_kernels.numerics import InFlexData
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout
from triton_kernels.testing import assert_close

from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import mxfp4_w4a16_moe_quant_config
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
    OAITritonExperts,
    UnfusedOAITritonExperts,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.utils import shuffle_weight
from vllm.platforms import current_platform

MNK = [
    (1, 512, 384),
    (1, 2880, 2880),
    (2, 512, 384),
    (2, 2880, 2880),
    (32, 2880, 2880),
    (64, 2880, 2880),
]


def unshuffle_weight(w: torch.Tensor):
    # De-interleave the last dim (even columns, then odd columns) to map the
    # shuffled reference weights back to the unshuffled layout.
    first = w[..., ::2]
    second = w[..., 1::2]
    return torch.concat((first, second), dim=-1)


def make_weights(dtype, k, n, e):
    w1 = torch.randn((e, k, 2 * n), dtype=dtype, device="cuda")
    w1_bias = torch.randn((e, 2 * n), dtype=dtype, device="cuda")

    w2 = torch.randn((e, n, k), dtype=dtype, device="cuda")
    w2_bias = torch.randn((e, k), dtype=dtype, device="cuda")

    w1_tri = w1.clone()
    w2_tri = w2.clone()

    # The Triton path keeps biases in fp32.
    w1_bias_tri = w1_bias.clone().to(torch.float32)
    w2_bias_tri = w2_bias.clone().to(torch.float32)

    # Shuffle w1 and its bias into the layout expected by the Triton kernels.
    w1_tri = shuffle_weight(w1_tri)
    w1_bias_tri = shuffle_weight(w1_bias_tri)

    # Quantize the Triton weights to MXFP4, then upcast back so the torch
    # reference operates on the same quantization-rounded values.
    w1_tri, w1_scale_tri = downcast_to_mxfp(w1_tri, torch.uint8, axis=1)
    w1 = upcast_from_mxfp(w1_tri, w1_scale_tri, dtype, axis=1)
    w1 = unshuffle_weight(w1)

    w2_tri, w2_scale_tri = downcast_to_mxfp(w2_tri, torch.uint8, axis=1)
    w2 = upcast_from_mxfp(w2_tri, w2_scale_tri, dtype, axis=1)

    num_warps = 8
    w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
    w_scale_layout, w_scale_layout_opts = (
        layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=1, num_warps=num_warps)
    )

    w1_tri = convert_layout(wrap_torch_tensor(w1_tri, FP4), w_layout, **w_layout_opts)
    w1_scale_tri = convert_layout(
        wrap_torch_tensor(w1_scale_tri),
        w_scale_layout,
        **w_scale_layout_opts,
    )

    w2_tri = convert_layout(wrap_torch_tensor(w2_tri, FP4), w_layout, **w_layout_opts)
    w2_scale_tri = convert_layout(
        wrap_torch_tensor(w2_scale_tri),
        w_scale_layout,
        **w_scale_layout_opts,
    )

    w1_precision_config = PrecisionConfig(
        weight_scale=w1_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData())
    )
    w2_precision_config = PrecisionConfig(
        weight_scale=w2_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData())
    )

    return (
        w1,
        w2,
        w1_bias,
        w2_bias,
        w1_tri,
        w2_tri,
        w1_bias_tri,
        w2_bias_tri,
        w1_precision_config,
        w2_precision_config,
    )


def swiglu(x, alpha: float = 1.702, limit: float = 1.0):
    # Note we add an extra bias of 1 to the linear layer
    x_glu, x_linear = torch.chunk(x, 2, dim=-1)
    if limit is not None:
        x_glu = x_glu.clamp(max=limit)
    out_glu = x_glu * torch.sigmoid(alpha * x_glu)
    if limit is not None:
        x_linear = x_linear.clamp(min=-limit, max=limit)
    return out_glu * (x_linear + 1)


def torch_moe_impl(
    hidden_states: torch.Tensor,  # (M, K)
    w1: torch.Tensor,  # (E, K, 2N)
    w2: torch.Tensor,  # (E, N, K)
    w1_bias: torch.Tensor,  # (E, 2N)
    w2_bias: torch.Tensor,  # (E, K)
    topk_weights: torch.Tensor,  # (M, topk)
    topk_ids: torch.Tensor,  # (M, topk)
):
    # Gather the selected experts per token: (M, topk, K, 2N) and (M, topk, 2N).
    w1 = w1[topk_ids, ...]
    w1_bias = w1_bias[topk_ids, ...]
    hidden_states = torch.einsum("bekc,bk->bec", w1, hidden_states) + w1_bias
    hidden_states = swiglu(hidden_states, limit=7)

    w2 = w2[topk_ids, ...]
    w2_bias = w2_bias[topk_ids, ...]
    hidden_states = torch.einsum("bekc,bek->bec", w2, hidden_states) + w2_bias

    # Weighted sum over the selected experts
    hidden_states = torch.einsum("bec,be->bc", hidden_states, topk_weights)
    return hidden_states


def oai_triton_moe_impl(
    x: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: "PrecisionConfig",
    w2_scale: "PrecisionConfig",
    w1_bias: torch.Tensor | None,
    w2_bias: torch.Tensor | None,
    num_experts: int,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    unfused: bool = False,
) -> torch.Tensor:
    quant_config = mxfp4_w4a16_moe_quant_config(
        w1_bias=w1_bias,
        w2_bias=w2_bias,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
    )

    if unfused:
        fused_experts = UnfusedOAITritonExperts(quant_config)
    else:
        fused_experts = OAITritonExperts(quant_config)

    mk = FusedMoEModularKernel(MoEPrepareAndFinalizeNoEP(), fused_experts)

    return mk.forward(
        hidden_states=x,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=True,
        activation="swigluoai",
        global_num_experts=num_experts,
        expert_map=None,
        apply_router_weight_on_input=False,
    )


@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("m,n,k", MNK)
@pytest.mark.parametrize("num_experts", [32, 128])
@pytest.mark.parametrize("topk", [4])
@pytest.mark.parametrize("unfused", [True, False])
def test_oai_triton_moe(
    dtype: torch.dtype,
    m: int,
    n: int,
    k: int,
    num_experts: int,
    topk: int,
    unfused: bool,
):
    current_platform.seed_everything(0)
    (
        w1,
        w2,
        w1_bias,
        w2_bias,
        w1_tri,
        w2_tri,
        w1_bias_tri,
        w2_bias_tri,
        w1_precision_config,
        w2_precision_config,
    ) = make_weights(dtype, k, n, num_experts)

    x = torch.randn((m, k), dtype=dtype, device="cuda")
    router_logits = torch.randn(m, num_experts, device="cuda", dtype=dtype)
    topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1, sorted=True)
    topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)

    with set_current_vllm_config(VllmConfig()):
        out_ref = torch_moe_impl(x, w1, w2, w1_bias, w2_bias, topk_weights, topk_ids)

        out = oai_triton_moe_impl(
            x,
            w1_tri,
            w2_tri,
            w1_precision_config,
            w2_precision_config,
            w1_bias_tri,
            w2_bias_tri,
            num_experts,
            topk_weights,
            topk_ids,
            unfused,
        )

    assert_close(ref=out_ref, tri=out, maxtol=0.025, rmstol=0.005)
Review comment (Contributor): Nice tests! Thanks @xyang16

35 changes: 26 additions & 9 deletions vllm/lora/layers/fused_moe.py
@@ -20,15 +20,24 @@
     _get_config_dtype_str,
 )
 from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
-    modular_marlin_fused_moe,
+    MarlinExperts,
 )
 from vllm.model_executor.layers.fused_moe.fused_moe import (
-    modular_triton_fused_moe,
+    TritonExperts,
     try_get_optimal_moe_config,
 )
 from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
     FusedMoEModularMethod,
 )
+from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
+    UnfusedOAITritonExperts,
+)
+from vllm.model_executor.layers.fused_moe.modular_kernel import (
+    FusedMoEModularKernel,
+)
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+    MoEPrepareAndFinalizeNoEP,
+)
 
 from .utils import _get_lora_device

Review thread on the fused_moe_modular_method import:

Contributor: @xyang16 can you add a unit test for gpt-oss lora + triton_kernels? The test can be predicated on has_triton_kernels, like in https://github.com/vllm-project/vllm/blob/main/tests/kernels/moe/test_gpt_oss_triton_kernels.py

Author (@xyang16, Nov 24, 2025): Added test_modular_oai_triton_moe.py. Thanks!
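For reference, the module-level gating the reviewer asks for is the same pattern the new test file above uses. A minimal sketch (the import, helper, and skip message are taken from the diff above; any gpt-oss LoRA-specific fixtures for the requested test would be additional assumptions):

import pytest

from vllm.utils.import_utils import has_triton_kernels

# Skip the entire module when the optional triton_kernels dependency is absent,
# mirroring tests/kernels/moe/test_modular_oai_triton_moe.py.
if not has_triton_kernels():
    pytest.skip(
        "triton_kernels not found, skipping all related tests",
        allow_module_level=True,
    )

Using pytest.skip with allow_module_level=True keeps the whole file out of collection when the dependency is missing, instead of failing at import time.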
@@ -114,15 +123,23 @@ def _inject_lora_into_fused_moe(self):
         self.base_layer.ensure_moe_quant_config_init()
         quant_config = self.base_layer.quant_method.moe_quant_config
 
-        m_fused_moe_fn = (
-            modular_triton_fused_moe(
-                quant_config, shared_experts=self.base_layer.shared_experts
+        prepare_finalize = MoEPrepareAndFinalizeNoEP()
+        m_fused_moe_fn = FusedMoEModularKernel(
+            prepare_finalize,
+            self.base_layer.quant_method.select_gemm_impl(
+                prepare_finalize, self.base_layer
+            ),
+            self.base_layer.shared_experts,
+            getattr(self.base_layer, "shared_experts_stream", None),
+        )
+        if quant_config.use_mxfp4_w4a16:
+            assert isinstance(
+                m_fused_moe_fn.fused_experts, (MarlinExperts, UnfusedOAITritonExperts)
             )
-            if not quant_config.use_mxfp4_w4a16
-            else modular_marlin_fused_moe(
-                quant_config, shared_experts=self.base_layer.shared_experts
+        else:
+            assert isinstance(
+                m_fused_moe_fn.fused_experts, (MarlinExperts, TritonExperts)
             )
-        )
 
         def fwd_decorator(layer, func):
             def wrapper(*args, **kwargs):

Review thread on the FusedMoEModularKernel construction:

Contributor: I think we should assert below that we are getting the Experts that we are expecting for the different cases.

Author (@xyang16): Added assert.

Contributor: Thanks. Can you also assert for the TritonExperts case?

Contributor: Sorry, I see @jeejeelee's comment about this here: #28971 (comment). @jeejeelee, do you think an else assert like

            assert isinstance(
                m_fused_moe_fn.fused_experts, (MarlinExperts, TritonExperts)
            )

makes sense?

Author (@xyang16): I have added an else assert. Thanks!