Commit 2a85428

small fix

Signed-off-by: MengqingCao <[email protected]>
1 parent 988ab44 commit 2a85428

File tree

1 file changed: +6 −0 lines changed


vllm_ascend/ops/fused_moe.py

Lines changed: 6 additions & 0 deletions
@@ -623,6 +623,7 @@ def apply(
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
         is_prefill: bool = False,
+        enable_force_load_balance: bool = False,
         **kwargs,
     ):
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
@@ -655,6 +656,11 @@ def apply(
         )

         topk_weights = topk_weights.to(x.dtype)
+        # This is a naive implementation of expert load balancing, used to
+        # avoid accumulating too many tokens on a single rank.
+        # Currently it is only activated during profile runs.
+        if enable_force_load_balance:
+            topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)

         if VLLM_ENABLE_MC2 and not is_prefill:
             return fused_experts_with_mc2(
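
For context, a minimal standalone sketch of the force-load-balance behaviour introduced here: when the flag is set, the router's expert choices are replaced with uniformly random expert ids so that no single rank accumulates a disproportionate share of tokens during a profile run. The helper name and the example shapes below are illustrative assumptions, not part of the vllm-ascend API.

import torch

def force_balance_topk_ids(topk_ids: torch.Tensor,
                           global_num_experts: int) -> torch.Tensor:
    # Illustrative helper (not from vllm-ascend): replace the router's expert
    # choices with uniformly random expert ids, mirroring the
    # torch.randint_like call added in this commit.
    return torch.randint_like(topk_ids, 0, global_num_experts)

# Example: 4 tokens routed top-2 over 8 experts; here the router happens to
# pick experts 0 and 1 for every token, which would overload their rank.
topk_ids = torch.tensor([[0, 1], [0, 1], [0, 1], [0, 1]])
balanced = force_balance_topk_ids(topk_ids, global_num_experts=8)
print(balanced.shape)  # torch.Size([4, 2]); ids are now uniform over 0..7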
