torchtitan/models/deepseek_v3/model/args.py (1 change: 0 additions & 1 deletion)
@@ -66,7 +66,6 @@ class DeepSeekV3ModelArgs(BaseModelArgs):

    # MoE
    moe_args: MoEArgs = field(default_factory=MoEArgs)
-    # TODO: node-limited routing is not supported yet
    n_expert_groups: int = 1
    n_limited_groups: int = 1

torchtitan/models/moe/moe.py (84 changes: 82 additions & 2 deletions)
@@ -26,7 +26,9 @@ class MoEArgs:
    route_scale: float = 1.0
    score_before_experts: bool = True

-    # token-choice
+    # token-choice with node limited routing support
+    num_groups: int | None = None  # must be a divisor of num_experts
+    top_k_group: int | None = None
Contributor (suggested change):
-    top_k_group: int | None = None
+    top_k_groups: int | None = None

Contributor: Are the names coming from some repo? I saw that the deepseek repo calls them n_expert_groups and n_limited_groups.

I think we can combine the conventions and call them num_expert_groups and num_limited_groups. WDYT?

Contributor: nit: let's put these two fields below top_k, which is a more "important" arg.

Contributor (author): Sounds great! I changed them to num_expert_groups and num_limited_groups, which have clearer meanings. The previous names came from Hugging Face's implementation.

    top_k: int = 1
    use_grouped_mm: bool = True  # grouped mm or for-loop for the experts computation
    load_balance_coeff: float | None = 1e-3
@@ -180,9 +182,17 @@ class TokenChoiceTopKRouter(nn.Module):
"""This class implements token-choice routing. In token-choice top-K routing, each token is
routed to top K experts based on the router scores.

Optionally supports node-limited (group-limited) routing where experts are divided into groups
(e.g., by node), and only top_k_group groups are considered before selecting top_k experts.
This reduces cross-node communication in distributed settings.
Contributor: This is not true for NCCL native a2a, and only true with DeepEP?

Contributor (author): I think this depends on how the communication is implemented. Node-limited routing itself works the same for either all-to-all implementation, i.e., it enforces communication across fewer nodes.

    Args:
        dim (int): Dimension of input tokens.
        num_experts (int): Number of experts in each moe layer.
        num_groups (int | None): Number of expert groups for node-limited routing. If None, standard
            top-k routing is used. Must be a divisor of num_experts.
        top_k_group (int | None): Number of groups to select in node-limited routing. Required when
            num_groups is set.
        top_k (int): Number of experts each token will be routed to in token-choice routing.
        score_func (Literal["softmax", "sigmoid"]): Whether to use sigmoid or softmax for router scores.
        route_norm (bool): Whether to normalize the routing scores when using sigmoid.
@@ -193,6 +203,8 @@ def __init__(
        self,
        dim: int,
        num_experts: int,
        num_groups: int | None,
        top_k_group: int | None,
        top_k: int,
        score_func: Literal["softmax", "sigmoid"],
        route_norm: bool,
@@ -202,6 +214,8 @@ def __init__(
        super().__init__()
        self.gate = nn.Linear(dim, num_experts, bias=False)
        self.num_experts = num_experts
        self.num_groups = num_groups
        self.top_k_group = top_k_group
        self.top_k = top_k
        self.score_func = score_func
        self.route_norm = route_norm
@@ -225,6 +239,60 @@ def _debug_force_load_balance_routing(
        top_scores = scores.gather(dim=1, index=selected_experts_indices)  # [N,K]
        return selected_experts_indices, top_scores

    def _node_limited_routing(
        self,
        scores: torch.Tensor,
        expert_bias: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Node-limited (group-limited) routing.

        This method first selects top_k_group groups based on group scores,
        then selects top_k experts from only those selected groups.

        Args:
            scores: Router scores after sigmoid or softmax, shape (bs*slen, num_experts)
            expert_bias: Optional bias for load balancing, shape (num_experts,)

        Returns:
            tuple of (selected_experts_indices, top_scores)
            - selected_experts_indices: shape (bs*slen, top_k)
            - top_scores: shape (bs*slen, top_k)
        """
        if expert_bias is not None:
            scores_for_choice = scores + expert_bias
        else:
            scores_for_choice = scores

        # Calculate group scores by taking top-2 experts per group and summing
        experts_per_group = self.num_experts // self.num_groups
        group_scores = (
            scores_for_choice.view(-1, self.num_groups, experts_per_group)
            .topk(2, dim=-1)[0]
            .sum(dim=-1)
        )
        group_idx = torch.topk(group_scores, k=self.top_k_group, dim=-1, sorted=False)[
            1
        ]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(-1, self.num_groups, experts_per_group)
            .reshape(-1, self.num_experts)
        )  # shape (bs*slen, num_experts)
        # Mask out experts from non-selected groups
        scores_for_choice = scores_for_choice.masked_fill(
            ~score_mask.bool(), float("-inf")
        )
        # Select top_k experts from masked scores
        selected_experts_indices = torch.topk(
            scores_for_choice, k=self.top_k, dim=-1, sorted=False
        )[1]
Contributor (suggested change):
-        selected_experts_indices = torch.topk(
-            scores_for_choice, k=self.top_k, dim=-1, sorted=False
-        )[1]
+        _, selected_experts_indices = torch.topk(
+            scores_for_choice, k=self.top_k, dim=-1, sorted=False
+        )

        # Get actual scores (without bias) for the selected experts
        top_scores = scores.gather(1, selected_experts_indices)
Contributor: let's unify this part with the other two paths (even with no expert bias we can do topk + gather)

Contributor (author): Great suggestion! Merged into a unified _node_limited_routing method.

Contributor (author): Also changed to a smaller mask to save memory.

        return selected_experts_indices, top_scores
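For intuition, here is a minimal standalone sketch of the same group-limited selection on a toy tensor (sizes and variable names are illustrative, not part of this PR):

    import torch

    num_experts, num_groups, top_k_group, top_k = 8, 4, 2, 2
    experts_per_group = num_experts // num_groups  # 2 experts per group
    scores = torch.rand(3, num_experts)  # 3 tokens

    # Score each group by the sum of its top-2 expert scores
    group_scores = (
        scores.view(-1, num_groups, experts_per_group).topk(2, dim=-1)[0].sum(dim=-1)
    )
    group_idx = torch.topk(group_scores, k=top_k_group, dim=-1, sorted=False)[1]

    # Expand the selected-group mask to expert granularity
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(-1, num_groups, experts_per_group)
        .reshape(-1, num_experts)
    )

    # Top-k experts restricted to the selected groups
    masked = scores.masked_fill(~score_mask.bool(), float("-inf"))
    top_scores, selected_experts_indices = masked.topk(top_k, dim=-1)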

    def forward(
        self, x: torch.Tensor, expert_bias: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -257,7 +325,17 @@ def forward(
        # top scores shape (bs*slen, top_k)
        # NOTE: The expert_bias is only used for routing. The gating value
        # top_scores is still derived from the original scores.
-        if expert_bias is not None:
+        if self.num_groups is not None:
+            assert (
+                self.top_k_group is not None
+            ), "top_k_group must be set when num_groups is set"
+            assert (
+                self.num_experts % self.num_groups == 0
+            ), f"num_experts ({self.num_experts}) must be divisible by num_groups ({self.num_groups})"
Contributor: Instead of doing assert, let's raise ValueError, since these are more like errors from users.
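A minimal sketch of that suggestion (illustrative, not part of the diff):

    if self.top_k_group is None:
        raise ValueError("top_k_group must be set when num_groups is set")
    if self.num_experts % self.num_groups != 0:
        raise ValueError(
            f"num_experts ({self.num_experts}) must be divisible "
            f"by num_groups ({self.num_groups})"
        )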

+            selected_experts_indices, top_scores = self._node_limited_routing(
+                scores, expert_bias
+            )
+        elif expert_bias is not None:
            _, selected_experts_indices = torch.topk(
                scores + expert_bias, k=self.top_k, dim=1
            )
@@ -367,6 +445,8 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
        self.router = TokenChoiceTopKRouter(
            dim=dim,
            num_experts=num_experts,
            num_groups=moe_args.num_groups,
            top_k_group=moe_args.top_k_group,
            top_k=moe_args.top_k,
            score_func=moe_args.score_func,
            route_norm=moe_args.route_norm,
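For reference, enabling the new path from a config could look roughly like this (a sketch; the values are illustrative and MoEArgs defaults are assumed to cover the remaining fields):

    from torchtitan.models.moe.moe import MoEArgs  # module shown in this diff

    # Hypothetical setup: experts split into 8 groups (e.g., one per node);
    # each token considers only the 4 best-scoring groups before picking its top-8 experts.
    moe_args = MoEArgs(num_groups=8, top_k_group=4, top_k=8)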