From 59a83aeb9e129ddee2cbbba01078ab0451da6b23 Mon Sep 17 00:00:00 2001 From: Varun Date: Mon, 30 Jun 2025 09:20:12 -0700 Subject: [PATCH 1/2] fixes Signed-off-by: Varun --- .../layers/fused_moe/batched_deep_gemm_moe.py | 6 +- .../batched_triton_or_deep_gemm_moe.py | 4 +- .../layers/fused_moe/cutlass_moe.py | 8 +- .../layers/fused_moe/deep_gemm_moe.py | 2 +- .../fused_moe/deepep_ht_prepare_finalize.py | 26 ++++--- .../fused_moe/deepep_ll_prepare_finalize.py | 10 ++- .../layers/fused_moe/fused_batched_moe.py | 76 ++++++++----------- .../layers/fused_moe/fused_moe.py | 2 +- .../layers/fused_moe/modular_kernel.py | 42 +++++++--- .../layers/fused_moe/pplx_prepare_finalize.py | 10 ++- .../layers/fused_moe/prepare_finalize.py | 5 +- .../layers/fused_moe/triton_deep_gemm_moe.py | 4 +- 12 files changed, 116 insertions(+), 79 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index a8788e340fc8..22de5a026cf0 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -260,8 +260,11 @@ def apply( a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, - expert_num_tokens: Optional[torch.Tensor], + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], ): + assert expert_tokens_meta is not None + expert_num_tokens = expert_tokens_meta.expert_num_tokens + import deep_gemm as dg assert hidden_states.ndim == 3 assert self.block_shape is not None @@ -287,7 +290,6 @@ def apply( masked_m=expert_num_tokens, expected_m=expected_m) - assert expert_num_tokens is not None a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1, expert_num_tokens) diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py index 0d67b4a4a6d6..76adfed9ca1c 100644 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -129,7 +129,7 @@ def apply( a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, - expert_num_tokens: Optional[torch.Tensor], + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], ): experts = (self.batched_deep_gemm_experts if self.allow_deep_gemm else self.batched_triton_experts) @@ -137,4 +137,4 @@ def apply( experts.apply(output, hidden_states, w1, w2, topk_ids, activation, global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13, - workspace2, expert_num_tokens) + workspace2, expert_tokens_meta) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index d771a7a54cfc..2a675eee2a6e 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -303,11 +303,17 @@ def apply( a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, - expert_num_tokens: Optional[torch.Tensor], + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], ): assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE" assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE" + + expert_num_tokens = None + if expert_tokens_meta is not None: + expert_num_tokens = expert_tokens_meta.expert_num_tokens + activation_callable = lambda o, i: self.activation(activation, o, i) + in_dtype = 
hidden_states.dtype run_cutlass_moe_fp8( output, hidden_states, w1, w2, topk_ids, activation_callable, diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 8ad57c237fed..c8c02497bb1f 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -119,7 +119,7 @@ def apply( a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, - expert_num_tokens: Optional[torch.Tensor], + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], ): import deep_gemm as dg assert self.block_shape is not None diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index b625c28d4070..8ed42975a32e 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -62,8 +62,9 @@ def _do_dispatch(self, tokens: torch.Tensor, has_scales = token_scales is not None - (num_tokens_per_rank, num_tokens_per_rdma_rank, expert_num_tokens, - is_token_in_rank, event) = self.buffer.get_dispatch_layout( + (num_tokens_per_rank, num_tokens_per_rdma_rank, + dispatch_expert_num_tokens, is_token_in_rank, + event) = self.buffer.get_dispatch_layout( topk_idx=rank_topk_ids, num_experts=num_experts, previous_event=None, @@ -83,7 +84,7 @@ def _do_dispatch(self, tokens: torch.Tensor, num_tokens_per_rank=num_tokens_per_rank, num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, is_token_in_rank=is_token_in_rank, - num_tokens_per_expert=expert_num_tokens, + num_tokens_per_expert=dispatch_expert_num_tokens, topk_idx=rank_topk_ids, topk_weights=rank_topk_weights, # expert_alignment rounds the number of tokens per expert @@ -115,7 +116,13 @@ def _do_dispatch(self, tokens: torch.Tensor, num_experts - 1 if self.rank_expert_offset == 0 else 0, expert_topk_ids + self.rank_expert_offset) - return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids, + # Makes a GPU-CPU copy. + # TODO (varun): Maybe it is better to re-compute the expert_num_tokens + # on GPU. + expert_tokens_meta = mk.ExpertTokensMetadata.make_from_list( + expert_num_tokens_per_expert_list, device=expert_x.device) + + return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, expert_topk_weights) def prepare( @@ -129,8 +136,9 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], - Optional[torch.Tensor], Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor], + Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], + Optional[torch.Tensor]]: if apply_router_weight_on_input: topk = topk_ids.size(1) @@ -149,7 +157,7 @@ def prepare( ) if a1q_scale is not None and a1q_scale.numel() == 1: a1q_scale = a1q_scale.view(1, 1) - (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids, + (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, expert_topk_weights) = self._do_dispatch( tokens=a1q, token_scales=a1q_scale, @@ -159,7 +167,7 @@ def prepare( else: # DeepEP kernels only support dispatching per-token-quant # quantization. dispatch in bfloat16. 
- (expert_x, _, expert_num_tokens, expert_topk_ids, + (expert_x, _, expert_tokens_meta, expert_topk_ids, expert_topk_weights) = self._do_dispatch( tokens=a1, token_scales=None, @@ -176,7 +184,7 @@ def prepare( per_act_token_quant=False, block_shape=quant_config.block_shape) - return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids, + return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, expert_topk_weights) def _apply_weights_and_reduce(self, num_tokens: int, diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index 78ac4acc495d..38c33203abfb 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -119,8 +119,9 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], - Optional[torch.Tensor], Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor], + Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], + Optional[torch.Tensor]]: hidden_size = a1.size(1) assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \ @@ -158,7 +159,10 @@ def prepare( expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype, quant_config.per_act_token_quant, quant_config.block_shape) - return (expert_x, expert_x_scale, expert_num_tokens, None, None) + expert_tokens_meta = mk.ExpertTokensMetadata( + expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None) + + return (expert_x, expert_x_scale, expert_tokens_meta, None, None) def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 0355abbf1d2b..8da048680d23 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -505,8 +505,9 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], - Optional[torch.Tensor], Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor], + Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], + Optional[torch.Tensor]]: assert a1.dim() == 2 assert topk_ids.dim() == 2 assert topk_ids.size(0) == a1.size(0) @@ -587,7 +588,10 @@ def prepare( assert b_a1_scale is None or b_a1_scale.ndim == 3 - return b_a1, b_a1_scale, tokens_per_expert, None, None + expert_tokens_meta = mk.ExpertTokensMetadata( + expert_num_tokens=tokens_per_expert, expert_num_tokens_cpu=None) + + return b_a1, b_a1_scale, expert_tokens_meta, None, None def finalize( self, @@ -691,28 +695,19 @@ def dequant(self, t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: else: return t.to(f32) * group_broadcast(scale, t.shape) - def apply( - self, - output: torch.Tensor, - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_ids: torch.Tensor, - activation: str, - global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: 
Optional[torch.Tensor], - workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_num_tokens: Optional[torch.Tensor], - ): + def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, + w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor, + activation: str, global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata]): assert hidden_states.dim() == 3 - assert expert_num_tokens is not None + assert expert_tokens_meta is not None + expert_num_tokens = expert_tokens_meta.expert_num_tokens num_local_experts = w1.size(0) assert num_local_experts == w1.size(0), ( @@ -895,26 +890,16 @@ def workspace_shapes( output = (num_experts, max_num_tokens * num_dp, K) return (workspace13, workspace2, output, a.dtype) - def apply( - self, - output: torch.Tensor, - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_ids: torch.Tensor, - activation: str, - global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_num_tokens: Optional[torch.Tensor], - ): + def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, + w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor, + activation: str, global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata]): # Check constraints. if self.use_int4_w4a16: assert hidden_states.size(-1) // 2 == w1.size(2), ( @@ -931,6 +916,9 @@ def apply( assert hidden_states.dtype in [ torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn ] + assert expert_tokens_meta is not None + + expert_num_tokens = expert_tokens_meta.expert_num_tokens E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size( hidden_states, w1, w2, topk_ids) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index fbbccbb34d90..409db3260bc2 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1598,7 +1598,7 @@ def apply( a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, - expert_num_tokens: Optional[torch.Tensor], + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], ): # Check constraints. 
if self.use_int4_w4a16: diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index f332b5168913..b7c5515c9099 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod from enum import Enum +from dataclasses import dataclass from math import prod from typing import Optional, final @@ -95,6 +96,25 @@ class FusedMoEActivationFormat(Enum): BatchedExperts = "batched_experts", +@dataclass +class ExpertTokensMetadata: + """ + Metadata regarding expert-token routing. + """ + expert_num_tokens: torch.Tensor + expert_num_tokens_cpu: Optional[torch.Tensor] + + @staticmethod + def make_from_list(expert_num_tokens_list: list[int], + device: str) -> "ExpertTokensMetadata": + expert_num_tokens_cpu = torch.tensor(expert_num_tokens_list, + device="cpu", + dtype=torch.int32) + return ExpertTokensMetadata( + expert_num_tokens=expert_num_tokens_cpu.to(device, + non_blocking=True), + expert_num_tokens_cpu=expert_num_tokens_cpu) + # TODO: pass FusedMoEParallelConfig in as ctor parameter? class FusedMoEPrepareAndFinalize(ABC): """ @@ -114,8 +134,9 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], - Optional[torch.Tensor], Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor], + Optional[ExpertTokensMetadata], Optional[torch.Tensor], + Optional[torch.Tensor]]: """ Perform any quantization (and/or) dispatching needed for this kernel. @@ -134,7 +155,8 @@ def prepare( Returns a tuple of: - quantized + dispatched a. - quantized + dispatched a1_scales. - - Optional tensor as big as number of local experts that contains the + - Optional ExpertTokensMetadata containing gpu/cpu tensors + as big as the number of local experts with the information about the number of tokens assigned to each local expert. - Optional dispatched expert topk IDs - Optional dispatched expert topk weight @@ -318,7 +340,7 @@ def apply( a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, - expert_num_tokens: Optional[torch.Tensor], + expert_tokens_meta: Optional[ExpertTokensMetadata], ): """ This function computes the intermediate result of a Mixture of Experts @@ -351,8 +373,10 @@ def apply( must be large enough to hold output of either MoE gemm. - workspace2 (torch.Tensor): A scratch tensor used for the activation function. - - expert_num_tokens: An optional tensor containing the number of tokens - assigned to each expert when using batched experts format input. + - expert_tokens_meta (Optional[ExpertTokensMetadata]) - An optional + ExpertTokensMetadata object containing gpu/cpu tensors + as big as the number of local experts with the information about the + number of tokens assigned to each local expert. 
""" raise NotImplementedError @@ -458,7 +482,7 @@ def forward( if global_num_experts == -1: global_num_experts = local_num_experts - (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids, + (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, _expert_topk_weights) = self.prepare_finalize.prepare( a1, a1_scale, @@ -542,7 +566,7 @@ def forward( a2_scale=a2_scale, workspace13=workspace13, workspace2=workspace2, - expert_num_tokens=expert_num_tokens, + expert_tokens_meta=expert_tokens_meta, ) else: # The leading output dimension may not be equal to M, so @@ -589,7 +613,7 @@ def forward( a2_scale=curr_a2_scale, workspace13=workspace13, workspace2=workspace2, - expert_num_tokens=expert_num_tokens, + expert_tokens_meta=expert_tokens_meta, ) self.prepare_finalize.finalize(output, fused_out, topk_weights, diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 112305a4f2d0..163b66b08abe 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -94,8 +94,9 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], - Optional[torch.Tensor], Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor], + Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], + Optional[torch.Tensor]]: num_tokens = a1.size(0) # M hidden_dim = a1.size(-1) # K @@ -198,7 +199,10 @@ def prepare( expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape] assert expert_x_scale.ndim == 3 - return expert_x, expert_x_scale, expert_num_tokens, None, None + expert_tokens_meta = mk.ExpertTokensMetadata( + expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None) + + return expert_x, expert_x_scale, expert_tokens_meta, None, None def finalize( self, diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index e1114efe5a3f..d413d2ce0e23 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -38,8 +38,9 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], - Optional[torch.Tensor], Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor], + Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], + Optional[torch.Tensor]]: if apply_router_weight_on_input: topk = topk_ids.size(1) diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index e660376ebe6b..da8aafe83789 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -107,7 +107,7 @@ def apply( a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, - expert_num_tokens: Optional[torch.Tensor], + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], ): use_deep_gemm = (self.allow_deep_gemm and _valid_deep_gemm(hidden_states, w1, w2)) @@ -132,5 +132,5 @@ def apply( a2_scale, workspace13, workspace2, - expert_num_tokens, + expert_tokens_meta, ) From cff78ce9aeb299df3a5a92ad76789e9d48b57c64 Mon Sep 17 00:00:00 2001 From: Varun 
Sundar Rabindranath
Date: Tue, 8 Jul 2025 20:27:07 +0000
Subject: [PATCH 2/2] lint fixes

Signed-off-by: Varun Sundar Rabindranath
---
 vllm/model_executor/layers/fused_moe/modular_kernel.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py
index b7c5515c9099..29c232afd65e 100644
--- a/vllm/model_executor/layers/fused_moe/modular_kernel.py
+++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -1,8 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from abc import ABC, abstractmethod
-from enum import Enum
 from dataclasses import dataclass
+from enum import Enum
 from math import prod
 from typing import Optional, final
 
@@ -115,6 +115,7 @@ def make_from_list(expert_num_tokens_list: list[int],
                                                        non_blocking=True),
             expert_num_tokens_cpu=expert_num_tokens_cpu)
 
+
 # TODO: pass FusedMoEParallelConfig in as ctor parameter?
 class FusedMoEPrepareAndFinalize(ABC):
     """
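
For reference, the core of this change is the new `ExpertTokensMetadata` dataclass added to `modular_kernel.py`: it carries the per-local-expert token counts as a device tensor, plus an optional host-side copy for prepare/finalize backends that produce the counts on the CPU. Below is a minimal, self-contained sketch of the class, reduced to plain PyTorch so it runs outside vLLM; the class body mirrors the hunks above, while the `__main__` demo and its example counts are purely illustrative.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ExpertTokensMetadata:
    """Metadata regarding expert-token routing (mirrors the patch hunk)."""
    # Per-local-expert token counts on the device the expert kernels run on.
    expert_num_tokens: torch.Tensor
    # Optional host-side copy for paths that already have the counts on the
    # CPU, so consumers can read them without a GPU -> CPU sync.
    expert_num_tokens_cpu: Optional[torch.Tensor]

    @staticmethod
    def make_from_list(expert_num_tokens_list: list[int],
                       device: str) -> "ExpertTokensMetadata":
        # Build the CPU tensor first, then copy it to the target device
        # asynchronously; both views are kept so callers can pick either.
        expert_num_tokens_cpu = torch.tensor(expert_num_tokens_list,
                                             device="cpu",
                                             dtype=torch.int32)
        return ExpertTokensMetadata(
            expert_num_tokens=expert_num_tokens_cpu.to(device,
                                                       non_blocking=True),
            expert_num_tokens_cpu=expert_num_tokens_cpu)


if __name__ == "__main__":
    # Hypothetical routing result: 4 local experts with these token counts.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    meta = ExpertTokensMetadata.make_from_list([5, 0, 12, 3], device=device)
    print(meta.expert_num_tokens.device, meta.expert_num_tokens_cpu.tolist())
```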
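
The `apply()` implementations updated by this patch unwrap the metadata at the top (`expert_num_tokens = expert_tokens_meta.expert_num_tokens`) and use the counts to ignore padding rows in the batched activation layout, where each local expert gets a fixed-size slab of `max_num_tokens` rows but only the first `expert_num_tokens[e]` rows hold real tokens. The toy loop below shows why the counts are needed; shapes and names are illustrative only and do not correspond to vLLM's actual kernels, which perform this masking on the GPU.

```python
import torch


def toy_batched_expert_apply(hidden_states: torch.Tensor,
                             w1: torch.Tensor,
                             expert_num_tokens: torch.Tensor) -> torch.Tensor:
    """Toy stand-in for the first GEMM of a batched-experts kernel.

    hidden_states:     (num_local_experts, max_tokens_per_expert, K)
    w1:                (num_local_experts, N, K)
    expert_num_tokens: (num_local_experts,) valid-token counts per expert;
                       rows past the count are padding and must not be used.
    """
    num_experts, max_tokens, _ = hidden_states.shape
    out = torch.zeros(num_experts, max_tokens, w1.size(1),
                      dtype=hidden_states.dtype,
                      device=hidden_states.device)
    for e in range(num_experts):
        n = int(expert_num_tokens[e])  # real kernels keep this on-device
        if n == 0:
            continue
        # Only the first n rows of this expert's batch are valid tokens.
        out[e, :n] = hidden_states[e, :n] @ w1[e].transpose(0, 1)
    return out
```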
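
The DeepEP high-throughput path builds the metadata from a Python list returned by dispatch and therefore pays a host-to-device copy (the `make_from_list` call above); the TODO added there suggests recomputing the counts on the GPU instead. One way that alternative could look, assuming slots routed to other ranks are marked with a negative expert id, is sketched below. This is only an illustration of the idea in the TODO, not part of the patch.

```python
import torch


def recompute_expert_num_tokens(expert_topk_ids: torch.Tensor,
                                num_local_experts: int) -> torch.Tensor:
    # expert_topk_ids: (num_tokens, topk) local expert ids after dispatch,
    # with negative values assumed to mark slots that belong to other ranks.
    flat = expert_topk_ids.reshape(-1)
    valid = flat[flat >= 0]
    # One count per local expert, computed entirely on the device, so no
    # host-to-device transfer is needed to build ExpertTokensMetadata.
    return torch.bincount(valid, minlength=num_local_experts).to(torch.int32)
```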