8 changes: 4 additions & 4 deletions torchtitan/models/deepseek_v3/__init__.py
@@ -112,13 +112,13 @@
num_experts=160,
num_shared_experts=2,
top_k=6,
num_expert_groups=8,
num_limited_groups=3,
score_func="softmax",
route_norm=False,
route_scale=16.0,
score_before_experts=False,
),
n_expert_groups=8,
n_limited_groups=3,
q_lora_rank=1536,
kv_lora_rank=512,
qk_nope_head_dim=128,
@@ -139,13 +139,13 @@
num_experts=256,
num_shared_experts=1,
top_k=8,
num_expert_groups=8,
num_limited_groups=4,
score_func="sigmoid",
route_norm=True,
route_scale=2.5,
score_before_experts=False,
),
n_expert_groups=8,
n_limited_groups=4,
q_lora_rank=1536,
kv_lora_rank=512,
qk_nope_head_dim=128,
5 changes: 0 additions & 5 deletions torchtitan/models/deepseek_v3/model/args.py
@@ -37,8 +37,6 @@ class DeepSeekV3ModelArgs(BaseModelArgs):
n_heads (int): Number of attention heads.
norm_eps (float): Epsilon value used for RMSNorm.
moe_args (MoEArgs): MoE configuration.
n_expert_groups (int): Number of expert groups.
n_limited_groups (int): Number of limited groups for MoE routing.
q_lora_rank (int): LoRA rank for query projections.
kv_lora_rank (int): LoRA rank for key-value projections.
qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings.
@@ -66,9 +64,6 @@ class DeepSeekV3ModelArgs(BaseModelArgs):

# MoE
moe_args: MoEArgs = field(default_factory=MoEArgs)
# TODO: node-limited routing is not supported yet
n_expert_groups: int = 1
n_limited_groups: int = 1

# Multi-Head Latent Attention (MLA)
q_lora_rank: int = 0
77 changes: 67 additions & 10 deletions torchtitan/models/moe/moe.py
@@ -26,8 +26,10 @@ class MoEArgs:
route_scale: float = 1.0
score_before_experts: bool = True

# token-choice
# token-choice with optional node limited routing
top_k: int = 1
num_expert_groups: int | None = None # must be a divisor of num_experts
num_limited_groups: int | None = None
use_grouped_mm: bool = True # grouped mm or for-loop for the experts computation
load_balance_coeff: float | None = 1e-3

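As a rough illustration of how the two new fields might be set, here is a minimal sketch mirroring the second DeepSeek-V3 config in the __init__.py diff above. The import path (torchtitan.models.moe.moe) is an assumption; the field names and values are taken from the diff.

    import torch
    from torchtitan.models.moe.moe import MoEArgs  # assumed import path

    # Node-limited routing: 256 experts split into 8 groups (32 experts each);
    # each token only considers experts from its top 4 groups before picking top_k=8.
    moe_args = MoEArgs(
        num_experts=256,
        num_shared_experts=1,
        top_k=8,
        num_expert_groups=8,   # must be a divisor of num_experts
        num_limited_groups=4,  # required whenever num_expert_groups is set
        score_func="sigmoid",
        route_norm=True,
        route_scale=2.5,
        score_before_experts=False,
    )
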
@@ -180,9 +182,17 @@ class TokenChoiceTopKRouter(nn.Module):
"""This class implements token-choice routing. In token-choice top-K routing, each token is
routed to top K experts based on the router scores.

Optionally supports node-limited (group-limited) routing where experts are divided into groups
(e.g., by node), and only num_limited_groups groups are considered before selecting top_k experts.
This reduces cross-node communication in distributed settings.
Contributor:
This is not true for NCCL native a2a, and only true with DeepEP?

Contributor Author:
I think this depends on how the communication is implemented. The node-limited routing itself works the same with either all-to-all implementation, i.e., it enforces communication across fewer nodes.

Args:
dim (int): Dimension of input tokens.
num_experts (int): Number of experts in each moe layer.
num_expert_groups (int | None): Number of expert groups for node-limited routing. If None, standard
top-k routing is used. Must be a divisor of num_experts.
num_limited_groups (int | None): Number of groups to select in node-limited routing. Required when
num_expert_groups is set.
top_k (int): Number of experts each token will be routed to in token-choice routing.
score_func (Literal["softmax", "sigmoid"]): Whether to use sigmoid or softmax for router scores.
route_norm (bool): Whether to normalize the routing scores when using sigmoid.
@@ -193,6 +203,8 @@ def __init__(
self,
dim: int,
num_experts: int,
num_expert_groups: int | None,
num_limited_groups: int | None,
top_k: int,
score_func: Literal["softmax", "sigmoid"],
route_norm: bool,
@@ -202,6 +214,8 @@ def __init__(
super().__init__()
self.gate = nn.Linear(dim, num_experts, bias=False)
self.num_experts = num_experts
self.num_expert_groups = num_expert_groups
self.num_limited_groups = num_limited_groups
self.top_k = top_k
self.score_func = score_func
self.route_norm = route_norm
@@ -225,6 +239,47 @@ def _debug_force_load_balance_routing(
top_scores = scores.gather(dim=1, index=selected_experts_indices) # [N,K]
return selected_experts_indices, top_scores

def _get_node_limited_routing_scores(
self,
scores_for_choice: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Select num_limited_groups groups based on group scores,
and set expert scores in non-selected groups as -inf

Args:
scores_for_choice: Router scores with expert_bias (if any), shape (bs*slen, num_experts)

Returns:
scores_for_choice: shape (bs*slen, num_experts)
"""
if self.num_limited_groups is None:
raise ValueError(
"num_limited_groups must be set when num_expert_groups is set"
)
if self.num_experts % self.num_expert_groups != 0:
raise ValueError(
f"num_experts ({self.num_experts}) must be divisible by num_expert_groups ({self.num_expert_groups})"
)
experts_per_group = self.num_experts // self.num_expert_groups
if experts_per_group < 2:
raise ValueError(f"experts_per_group ({experts_per_group}) must be >= 2")
scores_grouped = scores_for_choice.view(
-1, self.num_expert_groups, experts_per_group
)
top2_scores_in_group, _ = scores_grouped.topk(2, dim=-1)
group_scores = top2_scores_in_group.sum(dim=-1)
_, group_idx = torch.topk(
group_scores, k=self.num_limited_groups, dim=-1, sorted=False
)
group_mask = torch.ones_like(group_scores, dtype=torch.bool)
group_mask.scatter_(1, group_idx, False) # False = selected groups (keep)
# Mask out experts from non-selected groups
scores_for_choice = scores_grouped.masked_fill(
group_mask.unsqueeze(-1), float("-inf")
).view(-1, self.num_experts)

return scores_for_choice

def forward(
self, x: torch.Tensor, expert_bias: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
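To make the group-limited selection concrete, here is a self-contained sketch that re-creates the logic of _get_node_limited_routing_scores plus the subsequent top-k pick with plain torch ops. The sizes (6 tokens, 8 experts, 4 groups, 2 kept groups, top_k=2) are made up for illustration; this is not torchtitan code.

    import torch

    num_experts, num_expert_groups, num_limited_groups, top_k = 8, 4, 2, 2
    experts_per_group = num_experts // num_expert_groups          # 2, must be >= 2
    scores_for_choice = torch.rand(6, num_experts)                # (bs*slen, num_experts)

    # Group score = sum of the top-2 expert scores inside each group.
    grouped = scores_for_choice.view(-1, num_expert_groups, experts_per_group)
    group_scores = grouped.topk(2, dim=-1).values.sum(dim=-1)     # (tokens, groups)
    _, group_idx = group_scores.topk(num_limited_groups, dim=-1)  # indices of kept groups

    # Mask every expert in a non-selected group with -inf so top-k cannot pick it.
    group_mask = torch.ones_like(group_scores, dtype=torch.bool)
    group_mask.scatter_(1, group_idx, False)                      # False = keep this group
    masked = grouped.masked_fill(group_mask.unsqueeze(-1), float("-inf"))
    _, selected_experts_indices = masked.view(-1, num_experts).topk(top_k, dim=-1)
    # Every selected expert now falls inside one of the num_limited_groups kept groups.
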
@@ -254,18 +309,18 @@ def forward(
else:
raise NotImplementedError(f"Unknown score function {self.score_func}")

scores_for_choice = scores if expert_bias is None else scores + expert_bias
# Apply node-limited routing if configured
if self.num_expert_groups is not None:
scores_for_choice = self._get_node_limited_routing_scores(scores_for_choice)
_, selected_experts_indices = torch.topk(
scores_for_choice, k=self.top_k, dim=-1, sorted=False
)

# top scores shape (bs*slen, top_k)
# NOTE: The expert_bias is only used for routing. The gating value
# top_scores is still derived from the original scores.
if expert_bias is not None:
_, selected_experts_indices = torch.topk(
scores + expert_bias, k=self.top_k, dim=1
)
top_scores = scores.gather(dim=1, index=selected_experts_indices)
else:
top_scores, selected_experts_indices = torch.topk(
scores, k=self.top_k, dim=1
)
top_scores = scores.gather(dim=1, index=selected_experts_indices)

# debug override: balanced round-robin routing
if self._debug_force_load_balance:
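The NOTE above is the behavioral contract to keep in mind: expert_bias may change which experts are selected, but the returned gating weights are still taken from the unbiased scores. A tiny check with hypothetical numbers (plain torch, not torchtitan code):

    import torch

    scores = torch.tensor([[0.9, 0.5, 0.1, 0.1]])           # unbiased router scores, 1 token, 4 experts
    expert_bias = torch.tensor([0.0, 0.0, 1.0, 0.0])        # load-balancing bias favoring expert 2

    _, idx = torch.topk(scores + expert_bias, k=2, dim=-1)  # selection uses biased scores -> experts {2, 0}
    top_scores = scores.gather(dim=1, index=idx)            # gating uses original scores -> [[0.1, 0.9]]
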
@@ -367,6 +422,8 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
self.router = TokenChoiceTopKRouter(
dim=dim,
num_experts=num_experts,
num_expert_groups=moe_args.num_expert_groups,
num_limited_groups=moe_args.num_limited_groups,
top_k=moe_args.top_k,
score_func=moe_args.score_func,
route_norm=moe_args.route_norm,