diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py
index 6b391c173f0b..966e2f8f3b13 100644
--- a/tests/kernels/moe/test_moe.py
+++ b/tests/kernels/moe/test_moe.py
@@ -26,6 +26,7 @@
     int4_w4a16_moe_quant_config,
     int8_w8a16_moe_quant_config,
 )
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_topk,
     modular_triton_fused_moe,
@@ -724,7 +725,7 @@ def test_fused_marlin_moe(
     with set_current_vllm_config(vllm_config):
         torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, expert_map=e_map)
 
-    marlin_output = torch.ops.vllm.fused_marlin_moe(
+    marlin_output = fused_marlin_moe(
         a,
         qweight1,
         qweight2,
@@ -837,7 +838,7 @@ def test_fused_marlin_moe_with_bias(m):
     with set_current_vllm_config(vllm_config):
         torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1, b_bias2)
 
-    marlin_output = torch.ops.vllm.fused_marlin_moe(
+    marlin_output = fused_marlin_moe(
         a,
         qweight1,
         qweight2,
diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py
index 247919dcc844..cb31045971bd 100644
--- a/vllm/model_executor/layers/fused_moe/__init__.py
+++ b/vllm/model_executor/layers/fused_moe/__init__.py
@@ -51,7 +51,6 @@ def get_config() -> dict[str, Any] | None:
 
 if HAS_TRITON:
     # import to register the custom ops
-    import vllm.model_executor.layers.fused_moe.fused_marlin_moe  # noqa
     from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
         BatchedDeepGemmExperts,
     )
diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
index 58ed826ba037..57e17f324d2e 100644
--- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -19,7 +19,6 @@
     maybe_warn_marlin_atomic_add,
 )
 from vllm.scalar_type import ScalarType, scalar_types
-from vllm.utils import direct_register_custom_op
 
 
 def fused_marlin_moe(
@@ -241,44 +240,6 @@ def fused_marlin_moe(
     return torch.sum(intermediate_cache3.view(-1, topk, K), dim=1, out=output)
 
 
-def fused_marlin_moe_fake(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    w1_scale: torch.Tensor,
-    w2_scale: torch.Tensor,
-    gating_output: torch.Tensor | None,
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    quant_type_id: int,
-    apply_router_weight_on_input: bool = False,
-    global_num_experts: int = -1,
-    global_scale1: torch.Tensor | None = None,
-    global_scale2: torch.Tensor | None = None,
-    expert_map: torch.Tensor | None = None,
-    g_idx1: torch.Tensor | None = None,
-    g_idx2: torch.Tensor | None = None,
-    sort_indices1: torch.Tensor | None = None,
-    sort_indices2: torch.Tensor | None = None,
-    w1_zeros: torch.Tensor | None = None,
-    w2_zeros: torch.Tensor | None = None,
-    workspace: torch.Tensor | None = None,
-    intermediate_cache13: torch.Tensor | None = None,
-    intermediate_cache2: torch.Tensor | None = None,
-    is_k_full: bool = True,
-    output: torch.Tensor | None = None,
-    inplace: bool = False,
-) -> torch.Tensor:
-    return torch.empty_like(hidden_states)
-
-
-direct_register_custom_op(
-    op_name="fused_marlin_moe",
-    op_func=fused_marlin_moe,
-    fake_impl=fused_marlin_moe_fake,
-)
-
-
 class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
     def __init__(self, quant_config: FusedMoEQuantConfig):
         # TODO (varun) : Enable activation quantization
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index e1633d392dbf..d96c657e0119 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -14,6 +14,7 @@
     FusedMoEConfig,
     FusedMoEQuantConfig,
 )
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE,
     FusedMoEMethodBase,
@@ -604,7 +605,7 @@ def apply(
             indices_type=self.topk_indices_dtype,
         )
 
-        return torch.ops.vllm.fused_marlin_moe(
+        return fused_marlin_moe(
             x,
             layer.w13_qweight,
             layer.w2_qweight,
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 28383491207e..315356b474d9 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -34,6 +34,7 @@
 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
     is_valid_flashinfer_cutlass_fused_moe,
 )
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa
     WNA16_SUPPORTED_BITS,
     WNA16_SUPPORTED_TYPES_MAP,
@@ -462,7 +463,7 @@ def apply(
         #
         if self.use_marlin:
             assert self.fused_experts is None
-            return torch.ops.vllm.fused_marlin_moe(
+            return fused_marlin_moe(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
@@ -1067,7 +1068,7 @@ def apply(
         if self.use_marlin:
             assert activation == "silu", f"{activation} not supported for Marlin MoE."
             assert self.fused_experts is None
-            return torch.ops.vllm.fused_marlin_moe(
+            return fused_marlin_moe(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
@@ -1654,7 +1655,7 @@ def apply(
             indices_type=self.topk_indices_dtype,
         )
 
-        return torch.ops.vllm.fused_marlin_moe(
+        return fused_marlin_moe(
             x,
             layer.w13_weight_packed,
             layer.w2_weight_packed,
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 9a03105fafbf..02b1896a8996 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -26,6 +26,7 @@
     FusedMoEQuantConfig,
     fp8_w8a8_moe_quant_config,
 )
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
 from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
 from vllm.model_executor.layers.linear import (
     LinearBase,
@@ -1196,7 +1197,7 @@ def apply(
         elif self.use_marlin:
             assert activation == "silu", f"{activation} not supported for Marlin MoE."
             assert self.fused_experts is None
-            result = torch.ops.vllm.fused_marlin_moe(
+            result = fused_marlin_moe(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index dd86c990259f..b22c3c125ead 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -15,6 +15,7 @@
     FusedMoEConfig,
     FusedMoEQuantConfig,
 )
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE,
     FusedMoEMethodBase,
@@ -765,7 +766,7 @@ def apply(
             indices_type=self.topk_indices_dtype,
         )
 
-        return torch.ops.vllm.fused_marlin_moe(
+        return fused_marlin_moe(
             x,
             layer.w13_qweight,
             layer.w2_qweight,
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 7c7769455e8a..0f0638899bf1 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -21,6 +21,7 @@
 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
     is_valid_flashinfer_cutlass_fused_moe,
 )
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE,
     FusedMoEMethodBase,
@@ -1701,7 +1702,7 @@ def apply(
         #
         if self.use_marlin:
             assert self.fused_experts is None
-            return torch.ops.vllm.fused_marlin_moe(
+            return fused_marlin_moe(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index 5d78b82e3ee7..a7f9fdcb5513 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -21,7 +21,10 @@
     mxfp4_w4a16_moe_quant_config,
     ocp_mx_moe_quant_config,
 )
-from vllm.model_executor.layers.fused_moe.fused_marlin_moe import MarlinExperts
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+    MarlinExperts,
+    fused_marlin_moe,
+)
 from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
     OAITritonExperts,
 )
@@ -947,7 +950,7 @@ def apply(
             e_score_correction_bias=e_score_correction_bias,
         )
 
-        return torch.ops.vllm.fused_marlin_moe(
+        return fused_marlin_moe(
             x,
             layer.w13_weight,
             layer.w2_weight,
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index 778317e3a959..c13cf7007e68 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -20,6 +20,7 @@
     fp8_w8a8_moe_quant_config,
     ocp_mx_moe_quant_config,
 )
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
     is_rocm_aiter_moe_enabled,
 )
@@ -402,7 +403,7 @@ def apply(
         )
         if self.use_marlin:
             assert activation == "silu", f"{activation} not supported for Marlin MoE."
-            return torch.ops.vllm.fused_marlin_moe(
+            return fused_marlin_moe(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,