[MoE] Add node limited routing support #2111
base: main
Conversation
torchtitan/models/moe/moe.py
Outdated
```diff
- # token-choice
+ # token-choice with node limited routing support
+ num_groups: int | None = None  # must be a divisor of num_experts
+ top_k_group: int | None = None
```
Suggested change:
```diff
- top_k_group: int | None = None
+ top_k_groups: int | None = None
```
Set the default for the 671B model according to https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/configs/config_671B.json
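For reference, a minimal sketch of what those defaults could look like; the 671B values below are my reading of the linked config, and the `MoEArgs` field names follow this PR's hunk, so treat both as assumptions:

```python
from dataclasses import dataclass

@dataclass
class MoEArgs:
    # Defaults per the linked config_671B.json (values assumed from that file):
    num_experts: int = 256        # "n_routed_experts": 256
    top_k: int = 8                # "n_activated_experts": 8
    num_groups: int | None = 8    # "n_expert_groups": 8
    top_k_group: int | None = 4   # "n_limited_groups": 4
```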
torchtitan/models/moe/moe.py
Outdated
```python
selected_experts_indices = torch.topk(
    scores_for_choice, k=self.top_k, dim=-1, sorted=False
)[1]
# Get actual scores (without bias) for the selected experts
top_scores = scores.gather(1, selected_experts_indices)
```
Let's unify this part with the other two paths (even with no expert bias we can do topk + gather).
Great suggestion! Merged into a unified _node_limited_routing method.
Also changed to a smaller mask to save memory.
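For illustration, a minimal standalone sketch of what such a unified topk + gather path could look like; the function and argument names here are illustrative, not the PR's exact code:

```python
import torch

def get_routing_scores(
    scores: torch.Tensor,              # (num_tokens, num_experts), bias-free router scores
    expert_bias: torch.Tensor | None,  # optional (num_experts,) load-balancing bias
    top_k: int,
    node_limited_routing=None,         # optional fn that masks scores to the allowed groups
) -> tuple[torch.Tensor, torch.Tensor]:
    # The bias (if present) only affects which experts are chosen.
    scores_for_choice = scores if expert_bias is None else scores + expert_bias
    if node_limited_routing is not None:
        scores_for_choice = node_limited_routing(scores_for_choice)
    _, selected_experts_indices = torch.topk(
        scores_for_choice, k=top_k, dim=-1, sorted=False
    )
    # Gating values always come from the original (bias-free) scores.
    top_scores = scores.gather(dim=1, index=selected_experts_indices)
    return selected_experts_indices, top_scores

scores = torch.randn(4, 8).softmax(dim=-1)
indices, gates = get_routing_scores(scores, expert_bias=None, top_k=2)
print(indices.shape, gates.shape)  # torch.Size([4, 2]) torch.Size([4, 2])
```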
torchtitan/models/moe/moe.py
Outdated
```python
num_groups: int | None = None  # must be a divisor of num_experts
top_k_group: int | None = None
```
Are the names coming from some repo? I saw that the DeepSeek repo calls them `n_expert_groups` and `n_limited_groups`.
I think we can combine the conventions and call them `num_expert_groups` and `num_limited_groups`. WDYT?
nit: let's put these two fields below `top_k`, which is a more "important" arg.
Sounds great! I changed them to `num_expert_groups` and `num_limited_groups`, whose meaning is clearer from the naming. The previous names are from Hugging Face's implementation.
torchtitan/models/moe/moe.py
Outdated
```python
assert (
    self.top_k_group is not None
), "top_k_group must be set when num_groups is set"
assert (
    self.num_experts % self.num_groups == 0
), f"num_experts ({self.num_experts}) must be divisible by num_groups ({self.num_groups})"
```
Instead of doing assert, let's raise ValueError, since these are more like user errors.
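A sketch of the checks rewritten with ValueError as suggested; the minimal `MoEArgs` subset and the `__post_init__` placement are my assumptions, made for a runnable illustration:

```python
from dataclasses import dataclass

@dataclass
class MoEArgs:  # illustrative subset of the real args class
    num_experts: int = 8
    num_groups: int | None = None
    top_k_group: int | None = None

    def __post_init__(self):
        if self.num_groups is not None:
            if self.top_k_group is None:
                raise ValueError("top_k_group must be set when num_groups is set")
            if self.num_experts % self.num_groups != 0:
                raise ValueError(
                    f"num_experts ({self.num_experts}) must be divisible by "
                    f"num_groups ({self.num_groups})"
                )
```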
```python
selected_experts_indices = torch.topk(
    scores_for_choice, k=self.top_k, dim=-1, sorted=False
)[1]

# NOTE: The expert_bias is only used for routing. The gating value
# top_scores is still derived from the original scores.
top_scores = scores.gather(dim=1, index=selected_experts_indices)

return selected_experts_indices, top_scores
```
This should stay outside the node_limited_routing method. If you worry about naming, you could call it something like _get_node_limited_routing_scores.
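If I read the suggestion right, the helper would only mask the scores and the generic selection would stay in the caller; a rough structural sketch (the helper body is elided, and all of this is my paraphrase of the suggestion, not the PR's code):

```python
def _get_node_limited_routing_scores(self, scores_for_choice: torch.Tensor) -> torch.Tensor:
    """Mask scores_for_choice so only experts in the selected groups survive."""
    ...  # group scoring + masking, as in the hunks below

# Caller keeps the generic top-k + gather path:
scores_for_choice = self._get_node_limited_routing_scores(scores_for_choice)
_, selected_experts_indices = torch.topk(
    scores_for_choice, k=self.top_k, dim=-1, sorted=False
)
top_scores = scores.gather(dim=1, index=selected_experts_indices)
```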
```python
Optionally supports node-limited (group-limited) routing where experts are divided into groups
(e.g., by node), and only num_limited_groups groups are considered before selecting top_k experts.
This reduces cross-node communication in distributed settings.
```
This is not true for NCCL native a2a, and only true with DeepEP?
```python
group_idx = torch.topk(
    group_scores, k=self.num_limited_groups, dim=-1, sorted=False
)[1]
```
Suggested change:
```diff
- group_idx = torch.topk(
-     group_scores, k=self.num_limited_groups, dim=-1, sorted=False
- )[1]
+ _, group_idx = torch.topk(
+     group_scores, k=self.num_limited_groups, dim=-1, sorted=False
+ )
```
For readability: it's easy to forget why we need to slice topk()'s result.
```python
selected_experts_indices = torch.topk(
    scores_for_choice, k=self.top_k, dim=-1, sorted=False
)[1]
```
Suggested change:
```diff
- selected_experts_indices = torch.topk(
-     scores_for_choice, k=self.top_k, dim=-1, sorted=False
- )[1]
+ _, selected_experts_indices = torch.topk(
+     scores_for_choice, k=self.top_k, dim=-1, sorted=False
+ )
```
```python
scores_grouped = scores_for_choice.view(
    -1, self.num_expert_groups, experts_per_group
)
group_scores = scores_grouped.topk(2, dim=-1)[0].sum(dim=-1)
```
Suggested change:
```diff
- group_scores = scores_grouped.topk(2, dim=-1)[0].sum(dim=-1)
+ top2_scores_in_group, _ = scores_grouped.topk(2, dim=-1)
+ group_scores = top2_scores_in_group.sum(dim=-1)
```
wdyt?
As titled, added node-limited routing support via two-layer routing. First, group experts into `num_groups` groups; experts in the same group should reside on the same node to utilize fast intra-node communication. Second, pick the top `top_k_group` groups by the sum of the top-2 expert scores in each group. Third, pick `top_k` experts within the selected `top_k_group` groups.

Reference: https://github.com/huggingface/transformers/blob/4c9fde2a2a3aece0bcf1be93f696e88297da9397/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py#L212
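To make the three steps concrete, here is a small self-contained sketch of node-limited routing; the function name, the mask-with-`-inf` approach, and the shapes are illustrative assumptions, not the PR's exact code:

```python
import torch

def node_limited_routing(
    scores: torch.Tensor,  # (num_tokens, num_experts) router scores
    num_groups: int,       # experts split into this many node-aligned groups
    top_k_group: int,      # number of groups each token may route to
    top_k: int,            # number of experts selected per token
) -> tuple[torch.Tensor, torch.Tensor]:
    num_tokens, num_experts = scores.shape
    experts_per_group = num_experts // num_groups  # assumes >= 2 experts per group
    # Step 2: score each group by the sum of its top-2 expert scores.
    grouped = scores.view(num_tokens, num_groups, experts_per_group)
    top2_scores_in_group, _ = grouped.topk(2, dim=-1)
    group_scores = top2_scores_in_group.sum(dim=-1)
    _, group_idx = torch.topk(group_scores, k=top_k_group, dim=-1, sorted=False)
    # Mask out experts that fall outside the selected groups.
    group_mask = torch.zeros_like(group_scores, dtype=torch.bool)
    group_mask.scatter_(1, group_idx, True)
    keep = group_mask.unsqueeze(-1).expand(-1, -1, experts_per_group)
    masked = scores.masked_fill(
        ~keep.reshape(num_tokens, num_experts), float("-inf")
    )
    # Step 3: pick top_k experts within the selected groups.
    top_scores, selected_experts_indices = torch.topk(
        masked, k=top_k, dim=-1, sorted=False
    )
    return selected_experts_indices, top_scores

# Demo matching the test setup below: 8 experts, 4 groups, keep 2 groups, top-3.
scores = torch.randn(5, 8).softmax(dim=-1)
idx, gates = node_limited_routing(scores, num_groups=4, top_k_group=2, top_k=3)
print(idx.shape, gates.shape)  # torch.Size([5, 3]) torch.Size([5, 3])
```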
Tested on one node using the DeepSeek V3 debug model with the MoE arguments `num_experts=8, num_shared_experts=2, num_groups=4, top_k_group=2, top_k=3`:
