@@ -375,8 +375,10 @@ def apply(
         logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `UnquantizedFusedMoEMethod` yet.")
+            assert expert_load_view is not None
+            assert logical_to_physical_map is not None
+            assert logical_replica_count is not None
+            assert isinstance(layer, FusedMoE)
 
         return self.forward(
             x=x,
@@ -393,7 +395,12 @@ def apply(
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
             activation=activation,
-            apply_router_weight_on_input=apply_router_weight_on_input)
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            enable_eplb=enable_eplb,
+            expert_load_view=expert_load_view,
+            logical_to_physical_map=logical_to_physical_map,
+            logical_replica_count=logical_replica_count
+        )
 
     def forward_cuda(
         self,
@@ -412,7 +419,16 @@ def forward_cuda(
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        if enable_eplb:
+            assert expert_load_view is not None
+            assert logical_to_physical_map is not None
+            assert logical_replica_count is not None
+            assert isinstance(layer, FusedMoE)
 
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
@@ -425,7 +441,13 @@ def forward_cuda(
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype)
+            indices_type=self.topk_indices_dtype,
+            enable_eplb=enable_eplb,
+            expert_map=expert_map,
+            expert_load_view=expert_load_view,
+            logical_to_physical_map=logical_to_physical_map,
+            logical_replica_count=logical_replica_count
+        )
 
         if self.rocm_aiter_moe_enabled:
             return self.rocm_aiter_fused_experts(
@@ -730,16 +752,19 @@ def __init__(
         if self.enable_eplb:
             from vllm.model_executor.layers.quantization.fp8 import (
                 Fp8MoEMethod)
-            if not isinstance(quant_method, Fp8MoEMethod):
-                # TODO: Add support for additional quantization methods.
-                # The implementation for other quantization methods does not
-                # contain essential differences, but the current quant API
-                # design causes duplicated work when extending to new
-                # quantization methods, so I'm leaving it for now.
-                # If you plan to add support for more quantization methods,
-                # please refer to the implementation in `Fp8MoEMethod`.
-                raise NotImplementedError("EPLB is only supported for FP8 "
-                                          "quantization for now.")
+
+            # TODO: Add support for additional quantization methods.
+            SUPPORTED_MOE_QUANT_METHODS = {
+                UnquantizedFusedMoEMethod,
+                Fp8MoEMethod,
+            }
+            quant_method_type = type(quant_method)
+
+            if quant_method_type not in SUPPORTED_MOE_QUANT_METHODS:
+                raise NotImplementedError(
+                    "EPLB is only supported for the following quantization methods: "
+                    f"{[cls.__name__ for cls in SUPPORTED_MOE_QUANT_METHODS]}"
+                )
 
         moe_quant_params = {
             "num_experts": self.local_num_experts,
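
For reference, below is a minimal, self-contained sketch of the supported-method guard introduced in the last hunk. The stub classes and the helper name _check_eplb_quant_support are placeholders for illustration only, not part of the vLLM API.

# Sketch of the EPLB quantization-method guard, under the assumption that the
# stub classes below stand in for the real vLLM quantization method classes.
class UnquantizedFusedMoEMethod:
    pass

class Fp8MoEMethod:
    pass

SUPPORTED_MOE_QUANT_METHODS = {
    UnquantizedFusedMoEMethod,
    Fp8MoEMethod,
}

def _check_eplb_quant_support(quant_method: object) -> None:
    # Exact type membership is checked (not isinstance), so subclasses must be
    # added to the set explicitly, mirroring the check in the diff above.
    if type(quant_method) not in SUPPORTED_MOE_QUANT_METHODS:
        raise NotImplementedError(
            "EPLB is only supported for the following quantization methods: "
            f"{[cls.__name__ for cls in SUPPORTED_MOE_QUANT_METHODS]}")

_check_eplb_quant_support(Fp8MoEMethod())    # passes silently
# _check_eplb_quant_support(object())        # would raise NotImplementedError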