@@ -481,8 +481,16 @@ def forward_cpu(
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
-        **kwargs,
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
     ):
+        if enable_eplb is not False or expert_load_view is not None or \
+                logical_to_physical_map is not None or \
+                logical_replica_count is not None:
+            raise NotImplementedError("Expert load balancing is not supported "
+                                      "for CPU.")
         return layer.cpu_fused_moe(
             layer,
             x,
@@ -518,6 +526,10 @@ def forward_tpu(
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         assert not use_grouped_topk
         assert num_expert_group is None
@@ -531,6 +543,11 @@ def forward_tpu(
             raise NotImplementedError(
                 "Expert score correction bias is not supported for TPU.")
         assert activation == "silu", f"{activation} is not supported for TPU."
+        if enable_eplb is not False or expert_load_view is not None or \
+                logical_to_physical_map is not None or \
+                logical_replica_count is not None:
+            raise NotImplementedError("Expert load balancing is not supported "
+                                      "for TPU.")
         return fused_moe_pallas(hidden_states=x,
                                 w1=layer.w13_weight,
                                 w2=layer.w2_weight,
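
Both hunks follow the same pattern: the CPU and TPU forward paths accept the new EPLB (expert parallelism load balancing) keyword arguments for interface parity with the other backends, but reject any non-default value up front instead of silently ignoring it. Below is a minimal standalone sketch of that guard; the simplified signature, the name forward_cpu_sketch, and the example call are illustrative only, not the full methods from the diff.

from typing import Optional

import torch


def forward_cpu_sketch(
    x: torch.Tensor,
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Simplified stand-in for the patched forward_cpu: accept the EPLB
    # arguments so the signature matches the common interface, but fail
    # loudly if a caller actually tries to use them on this backend.
    if enable_eplb is not False or expert_load_view is not None or \
            logical_to_physical_map is not None or \
            logical_replica_count is not None:
        raise NotImplementedError(
            "Expert load balancing is not supported for CPU.")
    return x  # placeholder for the real fused-MoE computation


# Default arguments pass through; enabling EPLB raises immediately.
hidden = torch.randn(4, 8)
forward_cpu_sketch(hidden)
try:
    forward_cpu_sketch(hidden, enable_eplb=True)
except NotImplementedError as err:
    print(err)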