Merged
Changes from 2 commits
8 changes: 4 additions & 4 deletions torchtitan/models/deepseek_v3/__init__.py
@@ -112,13 +112,13 @@
num_experts=160,
num_shared_experts=2,
top_k=6,
+num_expert_groups=8,
+num_limited_groups=3,
score_func="softmax",
route_norm=False,
route_scale=16.0,
score_before_experts=False,
),
-n_expert_groups=8,
-n_limited_groups=3,
q_lora_rank=1536,
kv_lora_rank=512,
qk_nope_head_dim=128,
@@ -139,13 +139,13 @@
num_experts=256,
num_shared_experts=1,
top_k=8,
+num_expert_groups=8,
+num_limited_groups=4,
score_func="sigmoid",
route_norm=True,
route_scale=2.5,
score_before_experts=False,
),
-n_expert_groups=8,
-n_limited_groups=4,
q_lora_rank=1536,
kv_lora_rank=512,
qk_nope_head_dim=128,
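For context, a minimal sketch (not part of the diff) of how a model config would construct MoEArgs with the new grouping fields; the import path is an assumption, and only the field names and values are taken from the hunks above.

# Sketch only: import path assumed; fields follow the diff above.
from torchtitan.models.moe import MoEArgs

moe_args = MoEArgs(
    num_experts=160,
    num_shared_experts=2,
    top_k=6,
    num_expert_groups=8,   # experts split into 8 groups (e.g., one group per node)
    num_limited_groups=3,  # each token may only draw experts from 3 of those groups
    score_func="softmax",
    route_norm=False,
    route_scale=16.0,
    score_before_experts=False,
)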
5 changes: 0 additions & 5 deletions torchtitan/models/deepseek_v3/model/args.py
@@ -37,8 +37,6 @@ class DeepSeekV3ModelArgs(BaseModelArgs):
n_heads (int): Number of attention heads.
norm_eps (float): Epsilon value used for RMSNorm.
moe_args (MoEArgs): MoE configuration.
-n_expert_groups (int): Number of expert groups.
-n_limited_groups (int): Number of limited groups for MoE routing.
q_lora_rank (int): LoRA rank for query projections.
kv_lora_rank (int): LoRA rank for key-value projections.
qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings.
@@ -66,9 +64,6 @@ class DeepSeekV3ModelArgs(BaseModelArgs):

# MoE
moe_args: MoEArgs = field(default_factory=MoEArgs)
-# TODO: node-limited routing is not supported yet
-n_expert_groups: int = 1
-n_limited_groups: int = 1

# Multi-Head Latent Attention (MLA)
q_lora_rank: int = 0
96 changes: 84 additions & 12 deletions torchtitan/models/moe/moe.py
@@ -26,8 +26,10 @@ class MoEArgs:
route_scale: float = 1.0
score_before_experts: bool = True

-# token-choice
+# token-choice with optional node limited routing
top_k: int = 1
+num_expert_groups: int | None = None # must be a divisor of num_experts
+num_limited_groups: int | None = None
use_grouped_mm: bool = True # grouped mm or for-loop for the experts computation
load_balance_coeff: float | None = 1e-3

@@ -180,9 +182,17 @@ class TokenChoiceTopKRouter(nn.Module):
"""This class implements token-choice routing. In token-choice top-K routing, each token is
routed to top K experts based on the router scores.

Optionally supports node-limited (group-limited) routing where experts are divided into groups
(e.g., by node), and only num_limited_groups groups are considered before selecting top_k experts.
This reduces cross-node communication in distributed settings.
Contributor: This is not true for NCCL native a2a, and only true with DeepEP?

Contributor Author: I think this depends on how the communication is implemented. Node-limited routing itself works the same for either all-to-all implementation, i.e., it enforces communication across fewer nodes.

Args:
dim (int): Dimension of input tokens.
num_experts (int): Number of experts in each moe layer.
num_expert_groups (int | None): Number of expert groups for node-limited routing. If None, standard
top-k routing is used. Must be a divisor of num_experts.
num_limited_groups (int | None): Number of groups to select in node-limited routing. Required when
num_expert_groups is set.
top_k (int): Number of experts each token will be routed to in token-choice routing.
score_func (Literal["softmax", "sigmoid"]): Whether to use sigmoid or softmax for router scores.
route_norm (bool): Whether to normalize the routing scores when using sigmoid.
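To make the group-limited selection described above concrete, here is a small self-contained sketch (standalone function and toy shapes are illustrative, not the PR's API): score each group by the sum of its top-2 expert scores, keep the best num_limited_groups groups, mask the rest to -inf, then take the usual top_k.

import torch

def group_limited_topk(scores, num_expert_groups, num_limited_groups, top_k):
    # scores: (num_tokens, num_experts), already softmax/sigmoid-activated
    num_tokens, num_experts = scores.shape
    experts_per_group = num_experts // num_expert_groups

    # Score each group by the sum of its two best expert scores
    grouped = scores.view(num_tokens, num_expert_groups, experts_per_group)
    group_scores = grouped.topk(2, dim=-1).values.sum(dim=-1)  # (num_tokens, num_groups)

    # Keep only the best num_limited_groups groups; mask out the rest
    _, group_idx = group_scores.topk(num_limited_groups, dim=-1)
    group_mask = torch.ones_like(group_scores, dtype=torch.bool)
    group_mask.scatter_(1, group_idx, False)  # False = keep this group
    masked = grouped.masked_fill(group_mask.unsqueeze(-1), float("-inf"))

    # Standard top-k over the surviving experts
    top_scores, selected = masked.view(num_tokens, num_experts).topk(top_k, dim=-1)
    return selected, top_scores

# Toy example: 16 experts in 4 groups; route each token to 4 experts from at most 2 groups
scores = torch.rand(8, 16).softmax(dim=-1)
selected, top_scores = group_limited_topk(scores, num_expert_groups=4, num_limited_groups=2, top_k=4)

Using the sum of each group's two best scores (rather than its single best) mirrors the DeepSeek-V3 reference implementation and makes the group choice less sensitive to a single outlier expert.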
@@ -193,6 +203,8 @@ def __init__(
self,
dim: int,
num_experts: int,
num_expert_groups: int | None,
num_limited_groups: int | None,
top_k: int,
score_func: Literal["softmax", "sigmoid"],
route_norm: bool,
@@ -202,6 +214,8 @@ def __init__(
super().__init__()
self.gate = nn.Linear(dim, num_experts, bias=False)
self.num_experts = num_experts
self.num_expert_groups = num_expert_groups
self.num_limited_groups = num_limited_groups
self.top_k = top_k
self.score_func = score_func
self.route_norm = route_norm
@@ -225,6 +239,70 @@ def _debug_force_load_balance_routing(
top_scores = scores.gather(dim=1, index=selected_experts_indices) # [N,K]
return selected_experts_indices, top_scores

def _node_limited_routing(
self,
scores: torch.Tensor,
expert_bias: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Select top_k experts, optionally limiting to a subset of expert groups.

If num_expert_groups is set, applies node-limited routing:
1. Select num_limited_groups groups based on group scores
2. Select top_k experts only from those groups

If expert_bias is provided, it is added to scores for selection, but
the returned top_scores are always from the original (unbiased) scores.

Args:
scores: Router scores after sigmoid or softmax, shape (bs*slen, num_experts)
expert_bias: Optional bias for load balancing, shape (num_experts,)

Returns:
tuple of (selected_experts_indices, top_scores)
- selected_experts_indices: shape (bs*slen, top_k)
- top_scores: shape (bs*slen, top_k)
"""
scores_for_choice = scores if expert_bias is None else scores + expert_bias

# Apply node-limited routing mask if configured
if self.num_expert_groups is not None:
Contributor: This should happen outside, i.e.

if node_limited:
  _get_node_limited_routing_scores()
else:
  ...

Contributor Author: Nice suggestion! Moved the condition check to the forward method, so this method now processes scores_for_choice only if node-limited routing is configured, which better aligns with its name.

if self.num_limited_groups is None:
raise ValueError(
"num_limited_groups must be set when num_expert_groups is set"
)
if self.num_experts % self.num_expert_groups != 0:
raise ValueError(
f"num_experts ({self.num_experts}) must be divisible by num_expert_groups ({self.num_expert_groups})"
)
experts_per_group = self.num_experts // self.num_expert_groups
if experts_per_group < 2:
raise ValueError(
f"experts_per_group ({experts_per_group}) must be >= 2"
)
scores_grouped = scores_for_choice.view(
-1, self.num_expert_groups, experts_per_group
)
group_scores = scores_grouped.topk(2, dim=-1)[0].sum(dim=-1)
Contributor:

Suggested change
-group_scores = scores_grouped.topk(2, dim=-1)[0].sum(dim=-1)
+top2_scores_in_group, _ = scores_grouped.topk(2, dim=-1)
+group_scores = top2_scores_in_group.sum(dim=-1)

wdyt?

group_idx = torch.topk(
group_scores, k=self.num_limited_groups, dim=-1, sorted=False
)[1]
Contributor:

Suggested change
-group_idx = torch.topk(
-group_scores, k=self.num_limited_groups, dim=-1, sorted=False
-)[1]
+_, group_idx = torch.topk(
+group_scores, k=self.num_limited_groups, dim=-1, sorted=False
+)

For readability: it's easy to forget why we need to slice topk()'s result.

Contributor Author: Incorporated, thx!

group_mask = torch.ones_like(group_scores, dtype=torch.bool)
group_mask.scatter_(1, group_idx, False) # False = selected groups (keep)
# Mask out experts from non-selected groups
scores_for_choice = scores_grouped.masked_fill(
group_mask.unsqueeze(-1), float("-inf")
).view(-1, self.num_experts)

selected_experts_indices = torch.topk(
scores_for_choice, k=self.top_k, dim=-1, sorted=False
)[1]
Contributor:

Suggested change
-selected_experts_indices = torch.topk(
-scores_for_choice, k=self.top_k, dim=-1, sorted=False
-)[1]
+_, selected_experts_indices = torch.topk(
+scores_for_choice, k=self.top_k, dim=-1, sorted=False
+)

# NOTE: The expert_bias is only used for routing. The gating value
# top_scores is still derived from the original scores.
top_scores = scores.gather(dim=1, index=selected_experts_indices)

return selected_experts_indices, top_scores
Contributor: This should stay outside the node_limited_routing method. If you are worried about naming, you could call it something like _get_node_limited_routing_scores.

Contributor Author: Makes sense. Moved these lines into the forward method.

def forward(
self, x: torch.Tensor, expert_bias: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -255,17 +333,9 @@ def forward(
raise NotImplementedError(f"Unknown score function {self.score_func}")

# top scores shape (bs*slen, top_k)
-# NOTE: The expert_bias is only used for routing. The gating value
-# top_scores is still derived from the original scores.
-if expert_bias is not None:
-_, selected_experts_indices = torch.topk(
-scores + expert_bias, k=self.top_k, dim=1
-)
-top_scores = scores.gather(dim=1, index=selected_experts_indices)
-else:
-top_scores, selected_experts_indices = torch.topk(
-scores, k=self.top_k, dim=1
-)
+selected_experts_indices, top_scores = self._node_limited_routing(
+scores, expert_bias
+)

# debug override: balanced round-robin routing
if self._debug_force_load_balance:
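The NOTE above (expert_bias influences only which experts are selected, never the gating values) in a tiny sketch; shapes and numbers here are made up:

import torch

scores = torch.rand(4, 8).softmax(dim=-1)  # (tokens, experts), unbiased router scores
expert_bias = 0.1 * torch.randn(8)         # per-expert load-balancing bias

# The bias only influences which experts get picked ...
_, selected = torch.topk(scores + expert_bias, k=2, dim=1)
# ... while the gating weights applied to expert outputs come from the unbiased scores.
top_scores = scores.gather(dim=1, index=selected)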
@@ -367,6 +437,8 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
self.router = TokenChoiceTopKRouter(
dim=dim,
num_experts=num_experts,
num_expert_groups=moe_args.num_expert_groups,
num_limited_groups=moe_args.num_limited_groups,
top_k=moe_args.top_k,
score_func=moe_args.score_func,
route_norm=moe_args.route_norm,