[MoE] Add node limited routing support #2111
```diff
@@ -26,8 +26,10 @@ class MoEArgs:
     route_scale: float = 1.0
     score_before_experts: bool = True
 
-    # token-choice
+    # token-choice with optional node limited routing
     top_k: int = 1
+    num_expert_groups: int | None = None  # must be a divisor of num_experts
+    num_limited_groups: int | None = None
     use_grouped_mm: bool = True  # grouped mm or for-loop for the experts computation
     load_balance_coeff: float | None = 1e-3
```
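To show how the new knobs compose (a hypothetical usage sketch; the trimmed dataclass below is a stand-in for MoEArgs, and only its three routing fields mirror the diff):

```python
# Hypothetical sketch of the new routing knobs -- a trimmed stand-in for
# MoEArgs, keeping only the fields relevant to node-limited routing.
from dataclasses import dataclass

@dataclass
class RoutingArgsSketch:
    top_k: int = 1
    num_expert_groups: int | None = None   # must be a divisor of num_experts
    num_limited_groups: int | None = None  # required if num_expert_groups is set

# E.g., 64 experts split into 8 groups (one per node): each token picks its
# top-8 experts from only its 4 best-scoring groups, i.e., at most 4 nodes.
args = RoutingArgsSketch(top_k=8, num_expert_groups=8, num_limited_groups=4)
```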
```diff
@@ -180,9 +182,17 @@ class TokenChoiceTopKRouter(nn.Module):
     """This class implements token-choice routing. In token-choice top-K routing, each token is
     routed to top K experts based on the router scores.
 
+    Optionally supports node-limited (group-limited) routing where experts are divided into groups
+    (e.g., by node), and only num_limited_groups groups are considered before selecting top_k experts.
+    This reduces cross-node communication in distributed settings.
+
     Args:
         dim (int): Dimension of input tokens.
         num_experts (int): Number of experts in each moe layer.
+        num_expert_groups (int | None): Number of expert groups for node-limited routing. If None,
+            standard top-k routing is used. Must be a divisor of num_experts.
+        num_limited_groups (int | None): Number of groups to select in node-limited routing.
+            Required when num_expert_groups is set.
         top_k (int): Number of experts each token will be routed to in token-choice routing.
         score_func (Literal["softmax", "sigmoid"]): Whether to use sigmoid or softmax for router scores.
         route_norm (bool): Whether to normalize the routing scores when using sigmoid.
```
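To make the two-stage selection concrete, here is a minimal, runnable numerical sketch of group-limited top-k (illustrative only; the PR's method additionally handles expert_bias and input validation): 8 experts in 4 groups of 2, keep the best 2 groups, then take the top-2 experts from them.

```python
import torch

# One token, 8 experts arranged as 4 groups of 2 (group i = experts 2i, 2i+1).
scores = torch.tensor([[0.1, 0.9, 0.2, 0.3, 0.8, 0.7, 0.0, 0.4]])
grouped = scores.view(-1, 4, 2)

# Stage 1: score each group by the sum of its top-2 experts; keep the best 2 groups.
group_scores = grouped.topk(2, dim=-1)[0].sum(dim=-1)  # [[1.0, 0.5, 1.5, 0.4]]
_, group_idx = group_scores.topk(2, dim=-1)            # groups 2 and 0

# Stage 2: mask experts outside the kept groups, then top-2 over what remains.
keep = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
masked = grouped.masked_fill(keep.unsqueeze(-1) == 0, float("-inf")).view(-1, 8)
top_scores, expert_idx = masked.topk(2, dim=-1)
print(expert_idx)  # tensor([[1, 4]]) -- both from the two kept groups
```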
```diff
@@ -193,6 +203,8 @@ def __init__(
         self,
         dim: int,
         num_experts: int,
+        num_expert_groups: int | None,
+        num_limited_groups: int | None,
         top_k: int,
         score_func: Literal["softmax", "sigmoid"],
         route_norm: bool,
```
```diff
@@ -202,6 +214,8 @@ def __init__(
         super().__init__()
         self.gate = nn.Linear(dim, num_experts, bias=False)
         self.num_experts = num_experts
+        self.num_expert_groups = num_expert_groups
+        self.num_limited_groups = num_limited_groups
         self.top_k = top_k
         self.score_func = score_func
         self.route_norm = route_norm
```
```diff
@@ -225,6 +239,70 @@ def _debug_force_load_balance_routing(
         top_scores = scores.gather(dim=1, index=selected_experts_indices)  # [N,K]
         return selected_experts_indices, top_scores
 
+    def _node_limited_routing(
+        self,
+        scores: torch.Tensor,
+        expert_bias: torch.Tensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Select top_k experts, optionally limiting to a subset of expert groups.
+
+        If num_expert_groups is set, applies node-limited routing:
+        1. Select num_limited_groups groups based on group scores
+        2. Select top_k experts only from those groups
+
+        If expert_bias is provided, it is added to scores for selection, but
+        the returned top_scores are always from the original (unbiased) scores.
+
+        Args:
+            scores: Router scores after sigmoid or softmax, shape (bs*slen, num_experts)
+            expert_bias: Optional bias for load balancing, shape (num_experts,)
+
+        Returns:
+            tuple of (selected_experts_indices, top_scores)
+            - selected_experts_indices: shape (bs*slen, top_k)
+            - top_scores: shape (bs*slen, top_k)
+        """
+        scores_for_choice = scores if expert_bias is None else scores + expert_bias
+
+        # Apply node-limited routing mask if configured
+        if self.num_expert_groups is not None:
+            if self.num_limited_groups is None:
+                raise ValueError(
+                    "num_limited_groups must be set when num_expert_groups is set"
+                )
+            if self.num_experts % self.num_expert_groups != 0:
+                raise ValueError(
+                    f"num_experts ({self.num_experts}) must be divisible by num_expert_groups ({self.num_expert_groups})"
+                )
+            experts_per_group = self.num_experts // self.num_expert_groups
+            if experts_per_group < 2:
+                raise ValueError(
+                    f"experts_per_group ({experts_per_group}) must be >= 2"
+                )
+            scores_grouped = scores_for_choice.view(
+                -1, self.num_expert_groups, experts_per_group
+            )
+            group_scores = scores_grouped.topk(2, dim=-1)[0].sum(dim=-1)
```
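The hunk is cut off at this point in the view above. For completeness, here is a hedged sketch of how the method plausibly continues: the two topk calls are reconstructed from the review comments quoted below, while the group-masking step between them is an assumption in the spirit of DeepSeek-V3-style group-limited routing, not code copied from the PR.

```python
# Assumed continuation of _node_limited_routing (not shown in the truncated
# hunk above). The topk calls mirror lines quoted in the review comments
# below; the masking step is an assumption.

            # ... still inside `if self.num_expert_groups is not None:`
            _, group_idx = torch.topk(
                group_scores, k=self.num_limited_groups, dim=-1, sorted=False
            )
            # Assumption: exclude experts outside the selected groups before
            # the final per-token topk.
            group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
            scores_for_choice = scores_grouped.masked_fill(
                group_mask.unsqueeze(-1) == 0, float("-inf")
            ).view(-1, self.num_experts)

        # Final selection: indices come from the (possibly biased, group-masked)
        # scores; top_scores are gathered from the original unbiased scores, as
        # the docstring promises.
        _, selected_experts_indices = torch.topk(
            scores_for_choice, k=self.top_k, dim=-1, sorted=False
        )
        top_scores = scores.gather(dim=1, index=selected_experts_indices)
        return selected_experts_indices, top_scores
```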
Reviewer (suggested change):

```diff
-group_scores = scores_grouped.topk(2, dim=-1)[0].sum(dim=-1)
+top2_scores_in_group, _ = scores_grouped.topk(2, dim=-1)
+group_scores = top2_scores_in_group.sum(dim=-1)
```

wdyt?
Reviewer (suggested change):

```diff
-group_idx = torch.topk(
-    group_scores, k=self.num_limited_groups, dim=-1, sorted=False
-)[1]
+_, group_idx = torch.topk(
+    group_scores, k=self.num_limited_groups, dim=-1, sorted=False
+)
```

For readability: with the bare `[1]` slice, it's easy to forget which element of topk()'s result we need.
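(As an aside, a generic illustration of the point, not code from this PR: torch.topk returns a (values, indices) named tuple, so unpacking makes the intent explicit.)

```python
import torch

x = torch.tensor([[3.0, 1.0, 2.0]])
out = torch.topk(x, k=2, dim=-1)  # named tuple: (values, indices)
print(out.values)                 # tensor([[3., 2.]])
print(out.indices)                # tensor([[0, 2]])

# Equivalent, but the trailing [1] hides that we want the indices:
idx = torch.topk(x, k=2, dim=-1)[1]
```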
Author: Incorporated, thx!
Reviewer (suggested change):

```diff
-selected_experts_indices = torch.topk(
-    scores_for_choice, k=self.top_k, dim=-1, sorted=False
-)[1]
+_, selected_experts_indices = torch.topk(
+    scores_for_choice, k=self.top_k, dim=-1, sorted=False
+)
```
Reviewer: This should stay outside the _node_limited_routing method. If you're worried about naming, you could call it something like _get_node_limited_routing_scores.
Author: Makes sense. Moved these lines into the forward method.
Reviewer: Is this not true for NCCL-native a2a, and only true with DeepEP?
Author: I think this depends on how the communication is implemented. The node-limited routing itself works the same for either all-to-all implementation, i.e., it enforces communication across fewer nodes.
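Either way, the effect is easy to measure independently of the a2a backend. A small illustrative sketch (assuming experts are laid out contiguously, one group per node): count the distinct destination nodes each token's selected experts land on; node-limited routing caps this count at num_limited_groups.

```python
import torch

def nodes_per_token(expert_idx: torch.Tensor, experts_per_node: int) -> torch.Tensor:
    """Distinct destination nodes per token, assuming experts
    [i * experts_per_node, (i + 1) * experts_per_node) live on node i."""
    node_ids = expert_idx // experts_per_node  # (num_tokens, top_k)
    return torch.tensor([row.unique().numel() for row in node_ids])

# 64 experts on 8 nodes, top_k = 8: unconstrained routing may touch up to 8
# nodes per token, while num_limited_groups = 4 caps the count at 4.
idx = torch.randint(0, 64, (4, 8))
print(nodes_per_token(idx, experts_per_node=8))
```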