5 changes: 3 additions & 2 deletions tests/kernels/moe/test_moe.py
@@ -26,6 +26,7 @@
int4_w4a16_moe_quant_config,
int8_w8a16_moe_quant_config,
)
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk,
modular_triton_fused_moe,
@@ -724,7 +725,7 @@ def test_fused_marlin_moe(
with set_current_vllm_config(vllm_config):
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, expert_map=e_map)

-marlin_output = torch.ops.vllm.fused_marlin_moe(
+marlin_output = fused_marlin_moe(
a,
qweight1,
qweight2,
@@ -837,7 +838,7 @@ def test_fused_marlin_moe_with_bias(m):
with set_current_vllm_config(vllm_config):
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1, b_bias2)

-marlin_output = torch.ops.vllm.fused_marlin_moe(
+marlin_output = fused_marlin_moe(
a,
qweight1,
qweight2,
1 change: 0 additions & 1 deletion vllm/model_executor/layers/fused_moe/__init__.py
@@ -51,7 +51,6 @@ def get_config() -> dict[str, Any] | None:

if HAS_TRITON:
# import to register the custom ops
-import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts,
)
39 changes: 0 additions & 39 deletions vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -19,7 +19,6 @@
maybe_warn_marlin_atomic_add,
)
from vllm.scalar_type import ScalarType, scalar_types
-from vllm.utils import direct_register_custom_op


def fused_marlin_moe(
@@ -241,44 +240,6 @@ def fused_marlin_moe(
return torch.sum(intermediate_cache3.view(-1, topk, K), dim=1, out=output)


-def fused_marlin_moe_fake(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    w1_scale: torch.Tensor,
-    w2_scale: torch.Tensor,
-    gating_output: torch.Tensor | None,
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    quant_type_id: int,
-    apply_router_weight_on_input: bool = False,
-    global_num_experts: int = -1,
-    global_scale1: torch.Tensor | None = None,
-    global_scale2: torch.Tensor | None = None,
-    expert_map: torch.Tensor | None = None,
-    g_idx1: torch.Tensor | None = None,
-    g_idx2: torch.Tensor | None = None,
-    sort_indices1: torch.Tensor | None = None,
-    sort_indices2: torch.Tensor | None = None,
-    w1_zeros: torch.Tensor | None = None,
-    w2_zeros: torch.Tensor | None = None,
-    workspace: torch.Tensor | None = None,
-    intermediate_cache13: torch.Tensor | None = None,
-    intermediate_cache2: torch.Tensor | None = None,
-    is_k_full: bool = True,
-    output: torch.Tensor | None = None,
-    inplace: bool = False,
-) -> torch.Tensor:
-    return torch.empty_like(hidden_states)
-
-
-direct_register_custom_op(
-    op_name="fused_marlin_moe",
-    op_func=fused_marlin_moe,
-    fake_impl=fused_marlin_moe_fake,
-)


class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self, quant_config: FusedMoEQuantConfig):
# TODO (varun) : Enable activation quantization
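The block deleted above is the custom-op registration for this kernel: fused_marlin_moe_fake returned an empty tensor of the right shape, and direct_register_custom_op (vLLM's helper over torch.library) exposed the pair as torch.ops.vllm.fused_marlin_moe, with the fake implementation standing in for the kernel when the op is traced (for example under torch.compile). Below is a minimal, generic sketch of that register-with-a-fake-impl pattern using the public torch.library API (PyTorch 2.4+); the demo_lib::scaled_silu op and its helper functions are invented for illustration and are not vLLM code. With the registration removed, callers import fused_marlin_moe and call it as an ordinary Python function, which is what the remaining file diffs switch to.

import torch
from torch.library import custom_op


def _scaled_silu_impl(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Real implementation: performs the actual computation.
    return torch.nn.functional.silu(x) * scale


def _scaled_silu_fake(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Fake (meta) implementation: only describes the output shape/dtype so the
    # op can be traced without running the kernel (the role fused_marlin_moe_fake played).
    return torch.empty_like(x)


# Register the op under a hypothetical namespace, then attach the fake impl.
scaled_silu_op = custom_op("demo_lib::scaled_silu", mutates_args=())(_scaled_silu_impl)
scaled_silu_op.register_fake(_scaled_silu_fake)

x = torch.randn(4, 8)
out_via_op = torch.ops.demo_lib.scaled_silu(x, 2.0)  # dispatcher path (the old torch.ops.vllm style)
out_direct = _scaled_silu_impl(x, 2.0)               # plain Python call (the style this PR adopts)
assert torch.allclose(out_via_op, out_direct)
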
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/quantization/awq_marlin.py
@@ -14,6 +14,7 @@
FusedMoEConfig,
FusedMoEQuantConfig,
)
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEMethodBase,
@@ -604,7 +605,7 @@ def apply(
indices_type=self.topk_indices_dtype,
)

-return torch.ops.vllm.fused_marlin_moe(
+return fused_marlin_moe(
x,
layer.w13_qweight,
layer.w2_qweight,
@@ -34,6 +34,7 @@
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
is_valid_flashinfer_cutlass_fused_moe,
)
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
WNA16_SUPPORTED_BITS,
WNA16_SUPPORTED_TYPES_MAP,
@@ -462,7 +463,7 @@ def apply(
#
if self.use_marlin:
assert self.fused_experts is None
-return torch.ops.vllm.fused_marlin_moe(
+return fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,
@@ -1067,7 +1068,7 @@ def apply(
if self.use_marlin:
assert activation == "silu", f"{activation} not supported for Marlin MoE."
assert self.fused_experts is None
-return torch.ops.vllm.fused_marlin_moe(
+return fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,
@@ -1654,7 +1655,7 @@ def apply(
indices_type=self.topk_indices_dtype,
)

-return torch.ops.vllm.fused_marlin_moe(
+return fused_marlin_moe(
x,
layer.w13_weight_packed,
layer.w2_weight_packed,
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/quantization/fp8.py
@@ -26,6 +26,7 @@
FusedMoEQuantConfig,
fp8_w8a8_moe_quant_config,
)
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.linear import (
LinearBase,
@@ -1196,7 +1197,7 @@ def apply(
elif self.use_marlin:
assert activation == "silu", f"{activation} not supported for Marlin MoE."
assert self.fused_experts is None
-result = torch.ops.vllm.fused_marlin_moe(
+result = fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -15,6 +15,7 @@
FusedMoEConfig,
FusedMoEQuantConfig,
)
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEMethodBase,
@@ -765,7 +766,7 @@ def apply(
indices_type=self.topk_indices_dtype,
)

-return torch.ops.vllm.fused_marlin_moe(
+return fused_marlin_moe(
x,
layer.w13_qweight,
layer.w2_qweight,
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/quantization/modelopt.py
@@ -21,6 +21,7 @@
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
is_valid_flashinfer_cutlass_fused_moe,
)
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEMethodBase,
@@ -1701,7 +1702,7 @@ def apply(
#
if self.use_marlin:
assert self.fused_experts is None
-return torch.ops.vllm.fused_marlin_moe(
+return fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,
7 changes: 5 additions & 2 deletions vllm/model_executor/layers/quantization/mxfp4.py
@@ -21,7 +21,10 @@
mxfp4_w4a16_moe_quant_config,
ocp_mx_moe_quant_config,
)
-from vllm.model_executor.layers.fused_moe.fused_marlin_moe import MarlinExperts
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+    MarlinExperts,
+    fused_marlin_moe,
+)
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
OAITritonExperts,
)
@@ -947,7 +950,7 @@ def apply(
e_score_correction_bias=e_score_correction_bias,
)

-return torch.ops.vllm.fused_marlin_moe(
+return fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,
@@ -20,6 +20,7 @@
fp8_w8a8_moe_quant_config,
ocp_mx_moe_quant_config,
)
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled,
)
@@ -402,7 +403,7 @@ def apply(
)
if self.use_marlin:
assert activation == "silu", f"{activation} not supported for Marlin MoE."
-return torch.ops.vllm.fused_marlin_moe(
+return fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,