[MoE] Add node limited routing support #2111
```diff
@@ -26,7 +26,9 @@ class MoEArgs:
     route_scale: float = 1.0
     score_before_experts: bool = True
 
-    # token-choice
+    # token-choice with node limited routing support
+    num_groups: int | None = None  # must be a divisor of num_experts
+    top_k_group: int | None = None
     top_k: int = 1
     use_grouped_mm: bool = True  # grouped mm or for-loop for the experts computation
     load_balance_coeff: float | None = 1e-3
```
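For illustration only (hypothetical values, and assuming MoEArgs also carries a num_experts field not shown in this hunk), enabling node-limited routing could look like the sketch below; num_groups must divide num_experts, and top_k_group must be set alongside it:

```python
# Hypothetical configuration: 64 experts split into 8 groups (e.g., one group per node).
# Each token first keeps 2 groups, then routes to its top-4 experts within them.
moe_args = MoEArgs(
    num_experts=64,   # assumed field; only the node-limited fields below are from this diff
    num_groups=8,     # must be a divisor of num_experts
    top_k_group=2,    # number of groups kept per token
    top_k=4,          # experts selected per token from the kept groups
)
```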
```diff
@@ -180,9 +182,17 @@ class TokenChoiceTopKRouter(nn.Module):
     """This class implements token-choice routing. In token-choice top-K routing, each token is
     routed to top K experts based on the router scores.
 
+    Optionally supports node-limited (group-limited) routing where experts are divided into groups
+    (e.g., by node), and only top_k_group groups are considered before selecting top_k experts.
+    This reduces cross-node communication in distributed settings.
+
```
Contributor: This is not true for NCCL native a2a, and only true with DeepEP?

Contributor (author): I think this depends on how the communication is implemented. The node-limited routing itself works the same for either all-to-all implementation, i.e., it enforces communication across fewer nodes.
```diff
     Args:
         dim (int): Dimension of input tokens.
         num_experts (int): Number of experts in each moe layer.
+        num_groups (int | None): Number of expert groups for node-limited routing. If None, standard
+            top-k routing is used. Must be a divisor of num_experts.
+        top_k_group (int | None): Number of groups to select in node-limited routing. Required when
+            num_groups is set.
         top_k (int): Number of experts each token will be routed to in token-choice routing.
         score_func (Literal["softmax", "sigmoid"]): Whether to use sigmoid or softmax for router scores.
         route_norm (bool): Whether to normalize the routing scores when using sigmoid.
```
```diff
@@ -193,6 +203,8 @@ def __init__(
         self,
         dim: int,
         num_experts: int,
+        num_groups: int | None,
+        top_k_group: int | None,
         top_k: int,
         score_func: Literal["softmax", "sigmoid"],
         route_norm: bool,
```
```diff
@@ -202,6 +214,8 @@ def __init__(
         super().__init__()
         self.gate = nn.Linear(dim, num_experts, bias=False)
         self.num_experts = num_experts
+        self.num_groups = num_groups
+        self.top_k_group = top_k_group
         self.top_k = top_k
         self.score_func = score_func
         self.route_norm = route_norm
```
```diff
@@ -225,6 +239,60 @@ def _debug_force_load_balance_routing(
         top_scores = scores.gather(dim=1, index=selected_experts_indices)  # [N,K]
         return selected_experts_indices, top_scores
 
+    def _node_limited_routing(
+        self,
+        scores: torch.Tensor,
+        expert_bias: torch.Tensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Node-limited (group-limited) routing.
+
+        This method first selects top_k_group groups based on group scores,
+        then selects top_k experts from only those selected groups.
+
+        Args:
+            scores: Router scores after sigmoid or softmax, shape (bs*slen, num_experts)
+            expert_bias: Optional bias for load balancing, shape (num_experts,)
+
+        Returns:
+            tuple of (selected_experts_indices, top_scores)
+            - selected_experts_indices: shape (bs*slen, top_k)
+            - top_scores: shape (bs*slen, top_k)
+        """
+        if expert_bias is not None:
+            scores_for_choice = scores + expert_bias
+        else:
+            scores_for_choice = scores
+
+        # Calculate group scores by taking top-2 experts per group and summing
+        experts_per_group = self.num_experts // self.num_groups
+        group_scores = (
+            scores_for_choice.view(-1, self.num_groups, experts_per_group)
+            .topk(2, dim=-1)[0]
+            .sum(dim=-1)
+        )
+        group_idx = torch.topk(group_scores, k=self.top_k_group, dim=-1, sorted=False)[
+            1
+        ]
+        group_mask = torch.zeros_like(group_scores)
+        group_mask.scatter_(1, group_idx, 1)
+        score_mask = (
+            group_mask.unsqueeze(-1)
+            .expand(-1, self.num_groups, experts_per_group)
+            .reshape(-1, self.num_experts)
+        )  # shape (bs*slen, num_experts)
+        # Mask out experts from non-selected groups
+        scores_for_choice = scores_for_choice.masked_fill(
+            ~score_mask.bool(), float("-inf")
+        )
+        # Select top_k experts from masked scores
+        selected_experts_indices = torch.topk(
+            scores_for_choice, k=self.top_k, dim=-1, sorted=False
+        )[1]
```
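For intuition, here is a self-contained sketch of the same group-then-expert selection on toy tensors (a standalone function mirroring the method above, without the bias term; names and shapes are illustrative):

```python
import torch

def node_limited_topk(
    scores: torch.Tensor, num_groups: int, top_k_group: int, top_k: int
) -> torch.Tensor:
    """Pick top_k experts per token, restricted to the top_k_group best groups."""
    num_experts = scores.shape[-1]
    experts_per_group = num_experts // num_groups
    # Score each group by the sum of its top-2 expert scores.
    group_scores = (
        scores.view(-1, num_groups, experts_per_group).topk(2, dim=-1).values.sum(dim=-1)
    )
    group_idx = torch.topk(group_scores, k=top_k_group, dim=-1, sorted=False).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
    # Broadcast the group decision back to per-expert granularity.
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(-1, num_groups, experts_per_group)
        .reshape(-1, num_experts)
    )
    masked = scores.masked_fill(~score_mask.bool(), float("-inf"))
    return torch.topk(masked, k=top_k, dim=-1, sorted=False).indices

# Toy check: 2 tokens, 8 experts in 4 groups; keep 2 groups, then route to 2 experts.
scores = torch.rand(2, 8).softmax(dim=-1)
idx = node_limited_topk(scores, num_groups=4, top_k_group=2, top_k=2)
print(idx, idx // 2)  # each token's experts come from at most 2 distinct groups
```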
Suggested change:

```diff
-        selected_experts_indices = torch.topk(
-            scores_for_choice, k=self.top_k, dim=-1, sorted=False
-        )[1]
+        _, selected_experts_indices = torch.topk(
+            scores_for_choice, k=self.top_k, dim=-1, sorted=False
+        )
```
Contributor (outdated): Let's unify this part with the other two paths (even with no expert bias we can do topk + gather).

Author: Great suggestion! Merged into a unified _node_limited_routing method.

Author: Also changed to a smaller mask to save memory.
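One way the unified selection path the reviewer describes might look, sketched as a free function (illustrative names; the actual merged code in the PR may differ): run top-k on the biased or group-masked scores to choose indices, then gather the routing weights from the original scores so the bias only influences selection.

```python
import torch

def select_and_gather(
    scores: torch.Tensor, scores_for_choice: torch.Tensor, top_k: int
) -> tuple[torch.Tensor, torch.Tensor]:
    # Indices come from the (optionally biased / group-masked) scores...
    indices = torch.topk(scores_for_choice, k=top_k, dim=-1, sorted=False).indices
    # ...while the weights applied to expert outputs are read back from the raw scores.
    weights = scores.gather(dim=1, index=indices)
    return indices, weights
```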
Contributor (outdated): Instead of doing assert, let's raise ValueError, since these are more like errors from users.
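The kind of constructor-time validation being asked for might look like the following sketch (hypothetical helper and messages, not the PR's actual code):

```python
def validate_node_limited_args(
    num_experts: int, num_groups: int | None, top_k_group: int | None
) -> None:
    # Misconfiguration is a user error, so raise ValueError instead of asserting.
    if (num_groups is None) != (top_k_group is None):
        raise ValueError("num_groups and top_k_group must be set together")
    if num_groups is not None:
        if num_experts % num_groups != 0:
            raise ValueError(
                f"num_groups ({num_groups}) must be a divisor of num_experts ({num_experts})"
            )
        if top_k_group > num_groups:
            raise ValueError(
                f"top_k_group ({top_k_group}) cannot exceed num_groups ({num_groups})"
            )
```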
Contributor: Set the defaults for the 671B model according to https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/configs/config_671B.json
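If memory serves, the linked config maps onto these fields roughly as sketched below; the numbers are quoted from memory and should be verified against the JSON before being used as defaults:

```python
# Assumed mapping from DeepSeek-V3 config keys to this PR's MoEArgs fields:
#   n_routed_experts (256)   -> num_experts
#   n_activated_experts (8)  -> top_k
#   n_expert_groups (8)      -> num_groups
#   n_limited_groups (4)     -> top_k_group
deepseek_v3_671b_moe = MoEArgs(
    num_experts=256,
    top_k=8,
    num_groups=8,
    top_k_group=4,
)
```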