File tree Expand file tree Collapse file tree 1 file changed +6
-0
lines changed
Expand file tree Collapse file tree 1 file changed +6
-0
lines changed Original file line number Diff line number Diff line change @@ -623,6 +623,7 @@ def apply(
623623 scoring_func : str = "softmax" ,
624624 e_score_correction_bias : Optional [torch .Tensor ] = None ,
625625 is_prefill : bool = False ,
626+ enable_force_load_balance : bool = False ,
626627 ** kwargs ,
627628 ):
628629 # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
@@ -655,6 +656,11 @@ def apply(
655656 )
656657
657658 topk_weights = topk_weights .to (x .dtype )
659+ # this is a naive implementation of expert load balancing, so as
660+ # to avoid accumulating too many tokens on a single rank.
661+ # currently it is only activated when doing profile runs.
662+ if enable_force_load_balance:
663+ topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
658664
659665 if VLLM_ENABLE_MC2 and not is_prefill :
660666 return fused_experts_with_mc2 (
You can’t perform that action at this time.
0 commit comments