
Commit 8ae1692

Varun Sundar Rabindranath authored
[torch.compile] Unwrap fused_marlin_moe custom op (#26739)
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
1 parent 8a0af6a commit 8ae1692

10 files changed: +22 -52 lines
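Every caller change in this commit follows the same pattern: instead of dispatching through the registered custom op torch.ops.vllm.fused_marlin_moe, the quantization backends and tests now import fused_marlin_moe from vllm.model_executor.layers.fused_moe.fused_marlin_moe and call it directly, and the custom-op registration (the fake implementation plus direct_register_custom_op) is deleted from fused_marlin_moe.py. A short illustrative sketch of the removed custom-op pattern follows the fused_marlin_moe.py hunks below.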

tests/kernels/moe/test_moe.py
Lines changed: 3 additions & 2 deletions

@@ -26,6 +26,7 @@
     int4_w4a16_moe_quant_config,
     int8_w8a16_moe_quant_config,
 )
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_topk,
     modular_triton_fused_moe,

@@ -724,7 +725,7 @@ def test_fused_marlin_moe(
     with set_current_vllm_config(vllm_config):
         torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, expert_map=e_map)

-        marlin_output = torch.ops.vllm.fused_marlin_moe(
+        marlin_output = fused_marlin_moe(
             a,
             qweight1,
             qweight2,

@@ -837,7 +838,7 @@ def test_fused_marlin_moe_with_bias(m):
     with set_current_vllm_config(vllm_config):
         torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1, b_bias2)

-        marlin_output = torch.ops.vllm.fused_marlin_moe(
+        marlin_output = fused_marlin_moe(
             a,
             qweight1,
             qweight2,

vllm/model_executor/layers/fused_moe/__init__.py
Lines changed: 0 additions & 1 deletion

@@ -51,7 +51,6 @@ def get_config() -> dict[str, Any] | None:

 if HAS_TRITON:
     # import to register the custom ops
-    import vllm.model_executor.layers.fused_moe.fused_marlin_moe  # noqa
     from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
         BatchedDeepGemmExperts,
     )

vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
Lines changed: 0 additions & 39 deletions

@@ -19,7 +19,6 @@
     maybe_warn_marlin_atomic_add,
 )
 from vllm.scalar_type import ScalarType, scalar_types
-from vllm.utils import direct_register_custom_op


 def fused_marlin_moe(

@@ -241,44 +240,6 @@ def fused_marlin_moe(
     return torch.sum(intermediate_cache3.view(-1, topk, K), dim=1, out=output)


-def fused_marlin_moe_fake(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    w1_scale: torch.Tensor,
-    w2_scale: torch.Tensor,
-    gating_output: torch.Tensor | None,
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    quant_type_id: int,
-    apply_router_weight_on_input: bool = False,
-    global_num_experts: int = -1,
-    global_scale1: torch.Tensor | None = None,
-    global_scale2: torch.Tensor | None = None,
-    expert_map: torch.Tensor | None = None,
-    g_idx1: torch.Tensor | None = None,
-    g_idx2: torch.Tensor | None = None,
-    sort_indices1: torch.Tensor | None = None,
-    sort_indices2: torch.Tensor | None = None,
-    w1_zeros: torch.Tensor | None = None,
-    w2_zeros: torch.Tensor | None = None,
-    workspace: torch.Tensor | None = None,
-    intermediate_cache13: torch.Tensor | None = None,
-    intermediate_cache2: torch.Tensor | None = None,
-    is_k_full: bool = True,
-    output: torch.Tensor | None = None,
-    inplace: bool = False,
-) -> torch.Tensor:
-    return torch.empty_like(hidden_states)
-
-
-direct_register_custom_op(
-    op_name="fused_marlin_moe",
-    op_func=fused_marlin_moe,
-    fake_impl=fused_marlin_moe_fake,
-)
-
-
 class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
     def __init__(self, quant_config: FusedMoEQuantConfig):
         # TODO (varun) : Enable activation quantization
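For context, the block deleted above is the custom-op registration: fused_marlin_moe_fake is a shape-only "fake" (meta) implementation, and direct_register_custom_op exposed the function as torch.ops.vllm.fused_marlin_moe so graph tracing could treat it as a single opaque op. After this commit, callers invoke the Python function directly, which is presumably friendlier to torch.compile. The following is a minimal sketch of that general pattern using the standard torch.library API (not vLLM's direct_register_custom_op helper); the toy_moe names and shapes are made up for illustration and are not part of vLLM.

import torch


def toy_moe(hidden_states: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Stand-in for a real fused kernel.
    return hidden_states @ w


# Registering the function as a custom op makes compiled graphs treat calls to it
# as one opaque node; tracing uses the "fake" impl below instead of running the kernel.
toy_moe_op = torch.library.custom_op("demo::toy_moe", toy_moe, mutates_args=())


@toy_moe_op.register_fake
def _(hidden_states: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Describe only the output shape/dtype/device, never touch real data.
    return torch.empty(
        (hidden_states.shape[0], w.shape[1]),
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )


x = torch.randn(4, 8)
w = torch.randn(8, 16)
out_via_op = torch.ops.demo.toy_moe(x, w)  # old style: dispatch through torch.ops
out_direct = toy_moe(x, w)                 # new style: plain Python call
assert torch.allclose(out_via_op, out_direct)

The commit applies the second style to fused_marlin_moe and drops the registration entirely, so the function no longer appears under torch.ops.vllm.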

vllm/model_executor/layers/quantization/awq_marlin.py
Lines changed: 2 additions & 1 deletion

@@ -14,6 +14,7 @@
     FusedMoEConfig,
     FusedMoEQuantConfig,
 )
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE,
     FusedMoEMethodBase,

@@ -604,7 +605,7 @@ def apply(
             indices_type=self.topk_indices_dtype,
         )

-        return torch.ops.vllm.fused_marlin_moe(
+        return fused_marlin_moe(
             x,
             layer.w13_qweight,
             layer.w2_qweight,

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
Lines changed: 4 additions & 3 deletions

@@ -34,6 +34,7 @@
 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
     is_valid_flashinfer_cutlass_fused_moe,
 )
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa
     WNA16_SUPPORTED_BITS,
     WNA16_SUPPORTED_TYPES_MAP,

@@ -462,7 +463,7 @@ def apply(
         #
         if self.use_marlin:
             assert self.fused_experts is None
-            return torch.ops.vllm.fused_marlin_moe(
+            return fused_marlin_moe(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,

@@ -1067,7 +1068,7 @@ def apply(
         if self.use_marlin:
             assert activation == "silu", f"{activation} not supported for Marlin MoE."
             assert self.fused_experts is None
-            return torch.ops.vllm.fused_marlin_moe(
+            return fused_marlin_moe(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,

@@ -1654,7 +1655,7 @@ def apply(
             indices_type=self.topk_indices_dtype,
         )

-        return torch.ops.vllm.fused_marlin_moe(
+        return fused_marlin_moe(
             x,
             layer.w13_weight_packed,
             layer.w2_weight_packed,

vllm/model_executor/layers/quantization/fp8.py
Lines changed: 2 additions & 1 deletion

@@ -26,6 +26,7 @@
     FusedMoEQuantConfig,
     fp8_w8a8_moe_quant_config,
 )
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
 from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
 from vllm.model_executor.layers.linear import (
     LinearBase,

@@ -1196,7 +1197,7 @@ def apply(
         elif self.use_marlin:
             assert activation == "silu", f"{activation} not supported for Marlin MoE."
             assert self.fused_experts is None
-            result = torch.ops.vllm.fused_marlin_moe(
+            result = fused_marlin_moe(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,

vllm/model_executor/layers/quantization/gptq_marlin.py
Lines changed: 2 additions & 1 deletion

@@ -15,6 +15,7 @@
     FusedMoEConfig,
     FusedMoEQuantConfig,
 )
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE,
     FusedMoEMethodBase,

@@ -765,7 +766,7 @@ def apply(
             indices_type=self.topk_indices_dtype,
         )

-        return torch.ops.vllm.fused_marlin_moe(
+        return fused_marlin_moe(
             x,
             layer.w13_qweight,
             layer.w2_qweight,

vllm/model_executor/layers/quantization/modelopt.py
Lines changed: 2 additions & 1 deletion

@@ -21,6 +21,7 @@
 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
     is_valid_flashinfer_cutlass_fused_moe,
 )
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE,
     FusedMoEMethodBase,

@@ -1701,7 +1702,7 @@ def apply(
         #
         if self.use_marlin:
             assert self.fused_experts is None
-            return torch.ops.vllm.fused_marlin_moe(
+            return fused_marlin_moe(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,

vllm/model_executor/layers/quantization/mxfp4.py
Lines changed: 5 additions & 2 deletions

@@ -21,7 +21,10 @@
     mxfp4_w4a16_moe_quant_config,
     ocp_mx_moe_quant_config,
 )
-from vllm.model_executor.layers.fused_moe.fused_marlin_moe import MarlinExperts
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+    MarlinExperts,
+    fused_marlin_moe,
+)
 from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
     OAITritonExperts,
 )

@@ -947,7 +950,7 @@ def apply(
             e_score_correction_bias=e_score_correction_bias,
         )

-        return torch.ops.vllm.fused_marlin_moe(
+        return fused_marlin_moe(
             x,
             layer.w13_weight,
             layer.w2_weight,

vllm/model_executor/layers/quantization/quark/quark_moe.py
Lines changed: 2 additions & 1 deletion

@@ -20,6 +20,7 @@
     fp8_w8a8_moe_quant_config,
     ocp_mx_moe_quant_config,
 )
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
     is_rocm_aiter_moe_enabled,
 )

@@ -402,7 +403,7 @@ def apply(
         )
         if self.use_marlin:
             assert activation == "silu", f"{activation} not supported for Marlin MoE."
-            return torch.ops.vllm.fused_marlin_moe(
+            return fused_marlin_moe(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
