From fe752742904345040de1cbb56489203c9ae2f771 Mon Sep 17 00:00:00 2001
From: hsliu
Date: Sat, 19 Jul 2025 11:31:19 +0800
Subject: [PATCH 1/3] [Feature][EPLB] Add support for unquantized models

Signed-off-by: hsliu
---
 vllm/model_executor/layers/fused_moe/layer.py | 53 ++++++++++++++-----
 1 file changed, 39 insertions(+), 14 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 4b8a37fcc738..06b6d3c02c24 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -375,8 +375,10 @@ def apply(
         logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `UnquantizedFusedMoEMethod` yet.")
+            assert expert_load_view is not None
+            assert logical_to_physical_map is not None
+            assert logical_replica_count is not None
+            assert isinstance(layer, FusedMoE)
 
         return self.forward(
             x=x,
@@ -393,7 +395,12 @@ def apply(
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
             activation=activation,
-            apply_router_weight_on_input=apply_router_weight_on_input)
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            enable_eplb=enable_eplb,
+            expert_load_view=expert_load_view,
+            logical_to_physical_map=logical_to_physical_map,
+            logical_replica_count=logical_replica_count
+        )
 
     def forward_cuda(
         self,
@@ -412,7 +419,16 @@ def forward_cuda(
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        if enable_eplb:
+            assert expert_load_view is not None
+            assert logical_to_physical_map is not None
+            assert logical_replica_count is not None
+            assert isinstance(layer, FusedMoE)
 
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
@@ -425,7 +441,13 @@ def forward_cuda(
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype)
+            indices_type=self.topk_indices_dtype,
+            enable_eplb=enable_eplb,
+            expert_map=expert_map,
+            expert_load_view=expert_load_view,
+            logical_to_physical_map=logical_to_physical_map,
+            logical_replica_count=logical_replica_count
+        )
 
         if self.rocm_aiter_moe_enabled:
             return self.rocm_aiter_fused_experts(
@@ -730,16 +752,19 @@ def __init__(
         if self.enable_eplb:
             from vllm.model_executor.layers.quantization.fp8 import (
                 Fp8MoEMethod)
-            if not isinstance(quant_method, Fp8MoEMethod):
-                # TODO: Add support for additional quantization methods.
-                # The implementation for other quantization methods does not
-                # contain essential differences, but the current quant API
-                # design causes duplicated work when extending to new
-                # quantization methods, so I'm leaving it for now.
-                # If you plan to add support for more quantization methods,
-                # please refer to the implementation in `Fp8MoEMethod`.
-                raise NotImplementedError("EPLB is only supported for FP8 "
-                                          "quantization for now.")
+
+            # TODO: Add support for additional quantization methods.
+            SUPPORTED_MOE_QUANT_METHODS = {
+                UnquantizedFusedMoEMethod,
+                Fp8MoEMethod,
+            }
+            quant_method_type = type(quant_method)
+
+            if quant_method_type not in SUPPORTED_MOE_QUANT_METHODS:
+                raise NotImplementedError(
+                    "EPLB is only supported for the following quantization methods: "
+                    f"{[cls.__name__ for cls in SUPPORTED_MOE_QUANT_METHODS]}"
+                )
 
         moe_quant_params = {
             "num_experts": self.local_num_experts,

From a1c1f3b5da1fa5a4a4d15a423ee1abee13f8b47e Mon Sep 17 00:00:00 2001
From: hsliu
Date: Sat, 19 Jul 2025 11:37:49 +0800
Subject: [PATCH 2/3] Remove assertions

Signed-off-by: hsliu
---
 vllm/model_executor/layers/fused_moe/layer.py | 24 ++++++++++++-------
 1 file changed, 16 insertions(+), 8 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 07ac8a985dc7..8a4cb411b578 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -372,10 +372,14 @@ def apply(
         logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if enable_eplb:
-            assert expert_load_view is not None
-            assert logical_to_physical_map is not None
-            assert logical_replica_count is not None
-            assert isinstance(layer, FusedMoE)
+            if expert_load_view is None:
+                raise ValueError("expert_load_view must be provided when enable_eplb is True")
+            if logical_to_physical_map is None:
+                raise ValueError("logical_to_physical_map must be provided when enable_eplb is True")
+            if logical_replica_count is None:
+                raise ValueError("logical_replica_count must be provided when enable_eplb is True")
+            if not isinstance(layer, FusedMoE):
+                raise TypeError(f"Expected layer to be FusedMoE, but got {type(layer)}")
 
         return self.forward(
             x=x,
@@ -422,10 +426,14 @@ def forward_cuda(
         logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if enable_eplb:
-            assert expert_load_view is not None
-            assert logical_to_physical_map is not None
-            assert logical_replica_count is not None
-            assert isinstance(layer, FusedMoE)
+            if expert_load_view is None:
+                raise ValueError("expert_load_view must be provided when enable_eplb is True")
+            if logical_to_physical_map is None:
+                raise ValueError("logical_to_physical_map must be provided when enable_eplb is True")
+            if logical_replica_count is None:
+                raise ValueError("logical_replica_count must be provided when enable_eplb is True")
+            if not isinstance(layer, FusedMoE):
+                raise TypeError(f"Expected layer to be FusedMoE, but got {type(layer)}")
 
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,

From 7d56ab5f4c88374f441ec3c70de2b5844487aae6 Mon Sep 17 00:00:00 2001
From: hsliu
Date: Mon, 21 Jul 2025 13:36:39 +0800
Subject: [PATCH 3/3] Fix line length issues in fused_moe layer.py

Signed-off-by: hsliu
---
 vllm/model_executor/layers/fused_moe/layer.py | 32 ++++++++++++-------
 1 file changed, 20 insertions(+), 12 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 8a4cb411b578..ee3d041464d7 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -373,13 +373,17 @@ def apply(
     ) -> torch.Tensor:
         if enable_eplb:
             if expert_load_view is None:
-                raise ValueError("expert_load_view must be provided when enable_eplb is True")
+                raise ValueError(
+                    "expert_load_view must be provided when enable_eplb is True")
             if logical_to_physical_map is None:
-                raise ValueError("logical_to_physical_map must be provided when enable_eplb is True")
+                raise ValueError(
+                    "logical_to_physical_map must be provided when enable_eplb is True")
             if logical_replica_count is None:
-                raise ValueError("logical_replica_count must be provided when enable_eplb is True")
+                raise ValueError(
+                    "logical_replica_count must be provided when enable_eplb is True")
             if not isinstance(layer, FusedMoE):
-                raise TypeError(f"Expected layer to be FusedMoE, but got {type(layer)}")
+                raise TypeError(
+                    f"Expected layer to be FusedMoE, but got {type(layer)}")
 
         return self.forward(
             x=x,
@@ -427,13 +431,17 @@ def forward_cuda(
     ) -> torch.Tensor:
         if enable_eplb:
             if expert_load_view is None:
-                raise ValueError("expert_load_view must be provided when enable_eplb is True")
+                raise ValueError(
+                    "expert_load_view must be provided when enable_eplb is True")
             if logical_to_physical_map is None:
-                raise ValueError("logical_to_physical_map must be provided when enable_eplb is True")
+                raise ValueError(
+                    "logical_to_physical_map must be provided when enable_eplb is True")
             if logical_replica_count is None:
-                raise ValueError("logical_replica_count must be provided when enable_eplb is True")
+                raise ValueError(
+                    "logical_replica_count must be provided when enable_eplb is True")
             if not isinstance(layer, FusedMoE):
-                raise TypeError(f"Expected layer to be FusedMoE, but got {type(layer)}")
+                raise TypeError(
+                    f"Expected layer to be FusedMoE, but got {type(layer)}")
 
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
@@ -1424,10 +1432,10 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
                     >= chunk_size)
             assert (self.batched_router_logits.size(0)  # type: ignore
                     >= chunk_size)
-            staged_hidden_states = self.batched_hidden_states[:
-                                                              chunk_size, :]  # type: ignore
-            staged_router_logits = self.batched_router_logits[:
-                                                              chunk_size, :]  # type: ignore
+            staged_hidden_states = self.batched_hidden_states[
+                :chunk_size, :]  # type: ignore
+            staged_router_logits = self.batched_router_logits[
+                :chunk_size, :]  # type: ignore
             staged_hidden_states.copy_(hidden_states, non_blocking=True)
             staged_router_logits.copy_(router_logits, non_blocking=True)
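
A standalone illustration of the quant-method check introduced in PATCH 1 follows. This is not vLLM code: the stub classes and the helper check_eplb_support are hypothetical stand-ins so the snippet runs without vLLM installed, while the set name, the membership test, and the error message mirror the diff above.

class UnquantizedFusedMoEMethod:
    """Stub standing in for the class in vllm/model_executor/layers/fused_moe/layer.py."""


class Fp8MoEMethod:
    """Stub standing in for the class in vllm/model_executor/layers/quantization/fp8.py."""


class GPTQMoEMethod:
    """Hypothetical quant method with no EPLB support (illustration only)."""


# EPLB-capable quant methods, as listed in PATCH 1.
SUPPORTED_MOE_QUANT_METHODS = {
    UnquantizedFusedMoEMethod,
    Fp8MoEMethod,
}


def check_eplb_support(quant_method: object) -> None:
    # Exact-type membership test, same shape as the check added in the patch.
    if type(quant_method) not in SUPPORTED_MOE_QUANT_METHODS:
        raise NotImplementedError(
            "EPLB is only supported for the following quantization methods: "
            f"{[cls.__name__ for cls in SUPPORTED_MOE_QUANT_METHODS]}")


check_eplb_support(UnquantizedFusedMoEMethod())  # passes after this series
try:
    check_eplb_support(GPTQMoEMethod())
except NotImplementedError as exc:
    print(exc)  # prints the "EPLB is only supported for..." message

Note on the design choice: the series replaces the old isinstance(quant_method, Fp8MoEMethod) check with an exact-type membership test, so a subclass of Fp8MoEMethod or UnquantizedFusedMoEMethod is rejected until it is added to SUPPORTED_MOE_QUANT_METHODS explicitly.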