Commit fe75274
[Feature][EPLB] Add support for unquantized models
Signed-off-by: hsliu <[email protected]>
1 parent 1bf6513 commit fe75274
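
In short: `UnquantizedFusedMoEMethod.apply` no longer raises NotImplementedError when EPLB (expert parallelism load balancing) is enabled; it now validates the EPLB arguments and threads them through `forward_cuda` into `FusedMoE.select_experts`, and the EPLB support check in `FusedMoE.__init__` accepts unquantized layers alongside FP8. As a toy illustration of what the three EPLB tensors carry (the shapes and values below are assumptions for illustration, not vLLM's actual EPLB state construction):

import torch

# Toy example: four logical experts spread over six physical slots.
logical_to_physical_map = torch.tensor([
    [0, 4],   # logical expert 0 is replicated at physical slots 0 and 4
    [1, -1],  # -1 pads logical experts that have fewer replicas
    [2, 5],
    [3, -1],
])
logical_replica_count = torch.tensor([2, 1, 2, 1])

# Per-physical-expert load counter that routing updates in place.
expert_load_view = torch.zeros(6, dtype=torch.int64)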

1 file changed: 39 additions & 14 deletions

vllm/model_executor/layers/fused_moe/layer.py
@@ -375,8 +375,10 @@ def apply(
         logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `UnquantizedFusedMoEMethod` yet.")
+            assert expert_load_view is not None
+            assert logical_to_physical_map is not None
+            assert logical_replica_count is not None
+            assert isinstance(layer, FusedMoE)
 
         return self.forward(
             x=x,
@@ -393,7 +395,12 @@ def apply(
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
             activation=activation,
-            apply_router_weight_on_input=apply_router_weight_on_input)
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            enable_eplb=enable_eplb,
+            expert_load_view=expert_load_view,
+            logical_to_physical_map=logical_to_physical_map,
+            logical_replica_count=logical_replica_count
+        )
 
     def forward_cuda(
         self,
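
The asserts do double duty: they fail fast if the engine enabled EPLB without wiring up the tensors, and they narrow the Optional types for static checkers before the values are used downstream. A minimal sketch of the pattern (hypothetical helper, not from this diff):

from typing import Optional

import torch

def consume_load_view(expert_load_view: Optional[torch.Tensor]) -> int:
    # Fails fast if EPLB was enabled without wiring up the tensor, and
    # narrows Optional[Tensor] to Tensor for static type checkers.
    assert expert_load_view is not None
    return int(expert_load_view.sum().item())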
@@ -412,7 +419,16 @@ def forward_cuda(
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        if enable_eplb:
+            assert expert_load_view is not None
+            assert logical_to_physical_map is not None
+            assert logical_replica_count is not None
+            assert isinstance(layer, FusedMoE)
 
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
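
Because the new `forward_cuda` parameters default to False/None, pre-existing call sites keep working unchanged; only EPLB-aware callers pass the extra tensors. A toy stand-in (hypothetical names) showing the same signature evolution:

from typing import Optional

import torch

def forward_cuda_like(x: torch.Tensor,
                      enable_eplb: bool = False,
                      expert_load_view: Optional[torch.Tensor] = None,
                      ) -> torch.Tensor:
    # Same guard pattern as the diff: only validate when the feature is on.
    if enable_eplb:
        assert expert_load_view is not None
    return x

forward_cuda_like(torch.ones(2))  # legacy call site, still valid
forward_cuda_like(torch.ones(2), enable_eplb=True,
                  expert_load_view=torch.zeros(4))  # EPLB-aware call site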
@@ -425,7 +441,13 @@ def forward_cuda(
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype)
+            indices_type=self.topk_indices_dtype,
+            enable_eplb=enable_eplb,
+            expert_map=expert_map,
+            expert_load_view=expert_load_view,
+            logical_to_physical_map=logical_to_physical_map,
+            logical_replica_count=logical_replica_count
+        )
 
         if self.rocm_aiter_moe_enabled:
             return self.rocm_aiter_fused_experts(
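
With the extra arguments, `select_experts` can account for expert replicas during routing. The sketch below is a rough, hypothetical picture of logical-to-physical redirection plus load accounting; it is not vLLM's actual `select_experts` logic, and the uniform-random replica choice is an assumption:

import torch

def route_to_physical(topk_ids: torch.Tensor,
                      logical_to_physical_map: torch.Tensor,
                      logical_replica_count: torch.Tensor,
                      expert_load_view: torch.Tensor) -> torch.Tensor:
    ids = topk_ids.long()
    # Choose one replica per routed token; a uniform random pick is an
    # assumption here, not the documented policy.
    replica = torch.randint(0, 2**31 - 1, ids.shape) % logical_replica_count[ids]
    physical_ids = logical_to_physical_map[ids, replica]
    # Tally how many tokens each physical expert received.
    expert_load_view.scatter_add_(
        0, physical_ids.flatten(),
        torch.ones(physical_ids.numel(), dtype=expert_load_view.dtype))
    return physical_ids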
@@ -730,16 +752,19 @@ def __init__(
         if self.enable_eplb:
             from vllm.model_executor.layers.quantization.fp8 import (
                 Fp8MoEMethod)
-            if not isinstance(quant_method, Fp8MoEMethod):
-                # TODO: Add support for additional quantization methods.
-                # The implementation for other quantization methods does not
-                # contain essential differences, but the current quant API
-                # design causes duplicated work when extending to new
-                # quantization methods, so I'm leaving it for now.
-                # If you plan to add support for more quantization methods,
-                # please refer to the implementation in `Fp8MoEMethod`.
-                raise NotImplementedError("EPLB is only supported for FP8 "
-                                          "quantization for now.")
+
+            # TODO: Add support for additional quantization methods.
+            SUPPORTED_MOE_QUANT_METHODS = {
+                UnquantizedFusedMoEMethod,
+                Fp8MoEMethod,
+            }
+            quant_method_type = type(quant_method)
+
+            if quant_method_type not in SUPPORTED_MOE_QUANT_METHODS:
+                raise NotImplementedError(
+                    "EPLB is only supported for the following quantization "
+                    f"methods: {[cls.__name__ for cls in SUPPORTED_MOE_QUANT_METHODS]}"
+                )
 
         moe_quant_params = {
             "num_experts": self.local_num_experts,
