[moe] brings batch/sequence-wise load balance loss #2061
```diff
@@ -78,3 +78,21 @@ def build_mse_loss(job_config: JobConfig, **kwargs):
         logger.info("Compiling the loss function with torch.compile")
         loss_fn = torch.compile(loss_fn, backend=job_config.compile.backend)
     return loss_fn
+
+
+def moe_loss(
+    pred: tuple[torch.Tensor, torch.Tensor] | torch.Tensor,
+    labels: torch.Tensor,
+    loss_fn: LossFunction,
+) -> torch.Tensor:
+    """Auxiliary load balance loss wrapper for MoE model training."""
+    if isinstance(pred, tuple):
+        pred, load_balance_loss = pred
+        loss = loss_fn(pred, labels)
+        # Straight-through trick: keep the reported loss value equal to the plain
+        # CE loss while still back-propagating the load-balance loss.
+        loss = loss + (load_balance_loss - load_balance_loss.detach())
+    else:
+        loss = loss_fn(pred, labels)
+    return loss
```

Review discussion on `loss = loss + (load_balance_loss - load_balance_loss.detach())`:

Contributor: This code looks too hacky. Curious why we don't want to log the full loss?

Contributor (Author): If one needs to log load_balance_loss, a helpful way is to log it via MoE's optimizer pre-step hook (where we log everything about MoE, e.g. bias, expert usage, entropy, lb loss, etc.). Then we don't need to hack the return of "loss" (which is a single value/tensor for dense/MoE/diffusion model training), and for people who want to run an ablation study, the "loss" stays the "clean" CE loss for training/validation.
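For readers unfamiliar with the straight-through addition discussed above, here is a minimal standalone sketch (not from the PR) of its effect: the extra term contributes zero to the value but still carries gradient.

```python
# Minimal sketch: adding (aux - aux.detach()) contributes zero to the loss value,
# so logging shows the plain CE loss, but the aux term still receives gradient.
import torch

ce_loss = torch.tensor(2.0, requires_grad=True)
aux_loss = torch.tensor(0.5, requires_grad=True)

total = ce_loss + (aux_loss - aux_loss.detach())
print(total.item())  # 2.0 -- the reported value is just the CE loss

total.backward()
print(ce_loss.grad.item())   # 1.0
print(aux_loss.grad.item())  # 1.0 -- the load-balance term still drives updates
```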
```diff
@@ -97,6 +97,18 @@ class Metrics:
     """Whether to log metrics to Weights & Biases"""


+@dataclass
+class ExtraLosses:
+    load_balance_loss_type: Literal["sequence_wise", "batch_wise"] = "sequence_wise"
+    """Type of load balance loss to use"""
+
+    load_balance_loss_weight: float = 0
+    """Weight of load balance loss"""
+
+    load_balance_coeff: float | None = 1e-3
+    """Coefficient of bias update for aux-loss-free load balancing"""
+
+
 @dataclass
 class Model:
     name: str = "llama3"
@@ -130,6 +142,9 @@ class Model:
     converters have been applied.
     """

+    extra_losses: ExtraLosses = field(default_factory=ExtraLosses)
+    """Extra losses to use"""
+

 @dataclass
 class Optimizer:
```
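A rough sketch of how these fields might be set in Python (only the field names and defaults come from the diff above; the import path is an assumption, and torchtitan jobs would normally set these through their config files):

```python
# Hypothetical usage sketch; only the field names/defaults come from the diff above.
from torchtitan.config.job_config import ExtraLosses  # import path is an assumption

extra = ExtraLosses(
    load_balance_loss_type="sequence_wise",
    load_balance_loss_weight=1e-2,   # the default of 0 keeps the aux loss disabled
    load_balance_coeff=1e-3,         # bias-update coefficient for aux-loss-free balancing
)
```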
```diff
@@ -309,6 +309,7 @@ def forward(
         self,
         x: torch.Tensor,
         freqs_cis: torch.Tensor,
+        accumulated_load_balance_loss: torch.Tensor,
         attention_masks: AttentionMasksType | None,
     ):
         """
@@ -323,10 +324,15 @@
         """
         x = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks)
         if self.moe_enabled:
-            x = x + self.moe(self.ffn_norm(x))
+            ffn_moe_output, load_balance_loss = self.moe(self.ffn_norm(x))
+            accumulated_load_balance_loss = (
+                accumulated_load_balance_loss + load_balance_loss
+            )
         else:
-            x = x + self.feed_forward(self.ffn_norm(x))
-        return x
+            ffn_moe_output = self.feed_forward(self.ffn_norm(x))
+
+        x = x + ffn_moe_output
+        return x, accumulated_load_balance_loss

     def init_weights(self, buffer_device: torch.device):
         for norm in (self.attention_norm, self.ffn_norm):
@@ -410,6 +416,7 @@ def get_attention_masks(
     def forward(
         self,
         tokens: torch.Tensor,
+        accumulated_load_balance_loss: torch.Tensor | None = None,
         attention_masks: AttentionMasksType | None = None,
     ):
         """
@@ -427,8 +434,16 @@
         h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens

+        accumulated_load_balance_loss = (
+            torch.zeros((), device=h.device, dtype=torch.float32)
+            if accumulated_load_balance_loss is None
+            else accumulated_load_balance_loss
+        )
+
         for layer in self.layers.values():
-            h = layer(h, self.freqs_cis, attention_masks)
+            h, accumulated_load_balance_loss = layer(
+                h, self.freqs_cis, accumulated_load_balance_loss, attention_masks
+            )

         h = self.norm(h) if self.norm is not None else h
         output = self.output(h) if self.output is not None else h
-        return output
+        return output, accumulated_load_balance_loss
```

Review discussion on lines +444 to +446 (the per-layer accumulation loop in the outer forward):

Contributor: Hmm, passing this per-layer loss along all the way to the final output sounds unnecessary. It sounds correct but is causing quite intrusive changes to the entire model code. Putting PP aside, is it true that we can also achieve this via a buffer in each MoE module, similar to the expert bias? Specifically, putting the per-layer loss in a buffer, and in the loss function fetching the value and adding it to the final loss. Is this similar to the idea in #1979 when you say

In what way does it break compile? cc @bdhirsh @xmfan. With PP, I'm not sure what's the best way to do it. @H-Huang any suggestions?

Contributor: I'd be interested in what about caching the auxiliary loss breaks compile as well. This is probably not representative, but I have a basic working example of using torch.compile to compile each layer of a transformer model, where I also compute and cache an extra loss onto each layer: https://gist.github.com/bdhirsh/2b59611d3070354af3f6364d9becaa08

Contributor (Author): Thanks both. If we register a buffer named "aux_loss", use it to save the aux loss value, and then access it and sum it into the final loss, the compiler breaks. Here

Contributor (Author): OK, to clarify: the purpose of passing the (accumulated) aux loss through forward is to deal with PP. Without PP, we have lots of clean solutions, including those you proposed. With PP, if we keep each block's aux loss with the block (via a buffer or any other method), at backward time the last stage will not be able to capture the aux loss of block i. For that we would need to either add a backward hook that manually hacks the backward gradient (which does not work well with compile, as I have tested), or manually add a communication to gather all stages' aux losses into the last stage (and we would also need to handle micro-batching, with some queue for the aux loss of block i).

Contributor: Agreed that this is non-trivial. Let me think about it and get back to you.
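For reference, the buffer/attribute-based alternative discussed in this thread could look roughly like the sketch below. This is illustrative only (ToyMoE and total_loss are made-up names), and, as the author notes, it does not by itself address PP or the compile issues.

```python
import torch
import torch.nn as nn

class ToyMoE(nn.Module):
    """Stand-in MoE layer that stashes its aux loss on the module instead of returning it."""

    def __init__(self, dim: int):
        super().__init__()
        self.ffn = nn.Linear(dim, dim)
        self.aux_loss = torch.zeros(())  # refreshed on every forward

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # A real layer would compute the load-balance loss here; this is a placeholder.
        self.aux_loss = x.pow(2).mean() * 1e-2
        return self.ffn(x)

def total_loss(model: nn.Module, ce_loss: torch.Tensor) -> torch.Tensor:
    # The loss function pulls each layer's cached aux loss and adds it to the CE loss.
    aux = sum(m.aux_loss for m in model.modules() if isinstance(m, ToyMoE))
    return ce_loss + aux
```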
```diff
@@ -30,7 +30,8 @@ class MoEArgs:
     top_k: int = 1
     use_grouped_mm: bool = True  # grouped mm or for-loop for the experts computation
     load_balance_coeff: float | None = 1e-3
+    load_balance_loss_weight: float = 0
+    load_balance_loss_type: Literal["sequence_wise", "batch_wise"] = "sequence_wise"
     _debug_force_load_balance: bool = False
     # if True, we force each experts get same amount of token via round-robin
@@ -287,7 +288,7 @@ def forward(
             max=self.num_experts,
         )

-        return top_scores, selected_experts_indices, num_tokens_per_expert
+        return top_scores, scores, selected_experts_indices, num_tokens_per_expert

     def init_weights(self, init_std: float):
         nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std)
@@ -359,6 +360,8 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
         super().__init__()

         num_experts = moe_args.num_experts
+        self.topk = moe_args.top_k
+        self.num_experts = num_experts
         self.experts = GroupedExperts(
             dim=dim,
             hidden_dim=hidden_dim,
@@ -386,6 +389,8 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
         # NOTE: tokens_per_expert is accumulated in the model forward pass.
         # expert_bias is updated outside the model in an optimizer step pre hook
         # to work with gradient accumulation.
+        self.load_balance_loss_weight = moe_args.load_balance_loss_weight
+        self.load_balance_loss_type = moe_args.load_balance_loss_type
         self.load_balance_coeff = moe_args.load_balance_coeff
         if self.load_balance_coeff is not None:
             assert self.load_balance_coeff > 0.0
@@ -418,6 +423,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # num_tokens_per_expert shape (num_experts,)
         (
             top_scores,
+            scores,
             selected_experts_indices,
             num_tokens_per_expert,
         ) = self.router(x, self.expert_bias)
@@ -430,6 +436,26 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         with torch.no_grad():
             self.tokens_per_expert.add_(num_tokens_per_expert)

+        if self.training:
+            if self.load_balance_loss_type == "sequence_wise":
+                load_balance_loss = MoE.sequence_wise_aux_loss(
+                    scores,
+                    selected_experts_indices.long(),
+                    bs,
+                    slen,
+                    self.topk,
+                    self.load_balance_loss_weight,
+                )
+            elif self.load_balance_loss_type == "batch_wise":
+                load_balance_loss = MoE.batch_wise_aux_loss(
+                    scores,
+                    num_tokens_per_expert,
+                    self.topk,
+                    self.load_balance_loss_weight,
+                )
+        else:
+            load_balance_loss = torch.tensor(0.0, device=out.device, dtype=out.dtype)
+
         # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,)
         # num_tokens_per_expert shape (num_experts,)
         # NOTE: the reason we need to compute num_tokens_per_expert again is:
@@ -479,7 +505,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             dim=0, index=token_indices_experts_sorted, src=routed_output
         )
         out = out.reshape(bs, slen, dim)
-        return out
+
+        return out, load_balance_loss

     def init_weights(
         self,
@@ -499,3 +526,94 @@ def init_weights(
             self.expert_bias = torch.zeros(
                 self.experts.num_experts, dtype=torch.float32
             )
```

Review comment on the new `sequence_wise_aux_loss` (continued below):

Contributor: IIUC this returns a loss independent of DP, similar to the cross-entropy loss, in that each DP rank computes its own aux loss and does backward, which eventually gets gradients reduced across DP ranks by FSDP. In general, is there a way to verify the correctness of the implementation?
```diff
+    @staticmethod
+    @torch.compile(fullgraph=True)
+    def sequence_wise_aux_loss(
+        scores: torch.Tensor,   # Shape: (B*S, N) - raw sigmoid affinities (s_{i,t})
+        indices: torch.Tensor,  # Shape: (B*S, K) - selected expert indices
+        B: int,                 # batch size
+        S: int,                 # sequence length (T in the paper)
+        top_k: int,             # K_r
+        aux_loss_alpha: float,  # alpha
+    ) -> torch.Tensor:
+        """
+        Computes the sequence-wise auxiliary loss (DeepSeek-V3 Equations 17-20).
+
+        Args:
+            scores: The dense affinity scores (s_{i,t}) for routed experts.
+                Should be the output of Sigmoid, shape (B*S, N).
+            indices: The top-k selected expert indices. Shape (B*S, K).
+        """
+        if aux_loss_alpha <= 0:
+            return torch.tensor(0.0, device=scores.device, dtype=scores.dtype)
+
+        # N_r: total number of routed experts
+        N = scores.size(-1)
+
+        # 1. Reshape inputs to handle each sequence separately: (B, S, N).
+        #    This ensures we calculate P_i and f_i per sequence (Eq 20 & 18).
+        scores_per_seq = scores.view(B, S, N)
+        indices_per_seq = indices.view(B, S, top_k)
+
+        # 2. Eq 19: normalize affinity scores s_{i,t} to get s'_{i,t}.
+        #    DeepSeek-V3 uses Sigmoid, so scores don't sum to 1;
+        #    Eq 19 explicitly requires dividing by the sum of all affinities.
+        #    denominator shape: (B, S, 1)
+        denominator = scores_per_seq.sum(dim=-1, keepdim=True) + 1e-20
+        probs_per_seq = scores_per_seq / denominator  # this is s'_{i,t}
+
+        # 3. Eq 20: calculate P_i (average probability per expert for each sequence).
+        #    P_i = (1/T) * sum_{t=1}^T s'_{i,t}
+        #    We average over the sequence dimension (dim=1). P_i shape: (B, N)
+        P_i = probs_per_seq.mean(dim=1)
+
+        # 4. Eq 18: calculate f_i (fraction of tokens selecting expert i per sequence).
+        #    f_i = (N / (K * T)) * count_i
+        #    Flatten the top-k dimension to count hits per sequence: (B, S*K)
+        flat_indices_per_seq = indices_per_seq.view(B, -1)
+        selection_counts = torch.zeros((B, N), device=scores.device, dtype=scores.dtype)
+        src = torch.ones_like(flat_indices_per_seq, dtype=scores.dtype)
+        selection_counts.scatter_add_(1, flat_indices_per_seq, src)
+
+        # Calculate f_i for each sequence; T (tokens per sequence) is S here.
+        f_i = selection_counts * (N / (top_k * S))
+
+        # 5. Eq 17: calculate the balance loss.
+        loss_per_seq = (f_i * P_i).sum(dim=1) * aux_loss_alpha
+
+        return loss_per_seq.mean()
```
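On the reviewer's question about verification: one cheap sanity check is that under perfectly uniform routing the loss should reduce to exactly `aux_loss_alpha`, since every `f_i` is 1 and the `P_i` sum to 1 per sequence. Below is a standalone sketch that mirrors the equations above rather than importing the PR code; the shapes and values are illustrative.

```python
# Sanity-check sketch (mirrors Eq 17-20 as implemented above, not an import from the PR):
# with perfectly uniform routing, f_i == 1 and sum_i P_i == 1, so the
# sequence-wise loss reduces to exactly aux_loss_alpha.
import torch

B, S, N, K, alpha = 2, 8, 4, 2, 1e-2
scores = torch.full((B * S, N), 0.25)  # uniform sigmoid affinities
# Round-robin selection -> every expert is chosen equally often within each sequence.
indices = torch.arange(B * S * K).remainder(N).view(B * S, K)

probs = (scores / scores.sum(-1, keepdim=True)).view(B, S, N)
P_i = probs.mean(dim=1)                         # (B, N), each entry 1/N
counts = torch.zeros(B, N).scatter_add_(
    1, indices.view(B, -1), torch.ones(B, S * K)
)
f_i = counts * (N / (K * S))                    # all ones when perfectly balanced
loss = (f_i * P_i).sum(dim=1).mean() * alpha
print(loss.item())  # ~= alpha (0.01); skewed routing pushes this above alpha
```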
Review discussion on the new `batch_wise_aux_loss` (continued below):

Contributor: Do we need this batch-wise one? For DSv3 it seems unnecessary. For Qwen3 it seems insufficient, since Qwen3 adopts global-batch load balancing, while this looks local / per-microbatch (see e.g. https://qwenlm.github.io/blog/global-load-balance/). For simplicity let's start with seq_wise?

Contributor (Author): Sure, ideally we can make the type of aux loss configurable. I am curious how people deal with the special case of document segmentation (for models trained with document-mask attention, it is literally not "one" sequence on the attention side).

```diff
+    @staticmethod
+    @torch.compile(fullgraph=True)
+    def batch_wise_aux_loss(
+        scores: torch.Tensor,
+        num_tokens_per_expert: torch.Tensor,
+        top_k: int,
+        aux_loss_alpha: float,
+    ) -> torch.Tensor:
+        """
+        Computes the batch-wise auxiliary loss.
+
+        Args:
+            scores: Dense probabilities (BS, N).
+            num_tokens_per_expert: Token counts (N,).
+            top_k: Number of experts selected per token.
+            aux_loss_alpha: Scaling factor for the loss.
+        """
+        if aux_loss_alpha <= 0:
+            return torch.tensor(0.0, device=scores.device, dtype=scores.dtype)
+
+        # Total number of routed experts (N)
+        N = scores.size(1)
+        # Total number of tokens in this (micro)batch
+        T = scores.size(0)
+
+        P_i = scores.mean(dim=0)
+
+        f_i = num_tokens_per_expert.to(scores.dtype) * (N / (top_k * T))
+
+        loss = (f_i * P_i).sum() * aux_loss_alpha
+
+        return loss
```
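Following the Qwen3-style global-batch balancing mentioned in the discussion above, a rough and untested sketch of how the batch-wise statistics could be aggregated across data-parallel ranks before forming `f_i`. This is not part of the PR; `global_batch_aux_loss` is a made-up name, and a faithful version would also need to aggregate `P_i` globally.

```python
import torch
import torch.distributed as dist

def global_batch_aux_loss(
    scores: torch.Tensor,                 # (local_tokens, num_experts) routing probabilities
    num_tokens_per_expert: torch.Tensor,  # (num_experts,) local per-expert token counts
    top_k: int,
    aux_loss_alpha: float,
) -> torch.Tensor:
    N = scores.size(1)
    counts = num_tokens_per_expert.to(scores.dtype).clone()
    total_tokens = torch.tensor(float(scores.size(0)), device=scores.device)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(counts)        # sum per-expert counts over the data-parallel group
        dist.all_reduce(total_tokens)  # sum token counts as well
    P_i = scores.mean(dim=0)           # still local; a faithful version would average globally
    f_i = counts * (N / (top_k * total_tokens))
    return (f_i * P_i).sum() * aux_loss_alpha
```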
```diff
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import functools
 import importlib
 import os
 import time
@@ -18,7 +19,7 @@
 from torchtitan.components.checkpoint import CheckpointManager
 from torchtitan.components.dataloader import DataloaderExhaustedError
 from torchtitan.components.ft import FTManager, maybe_semi_sync_training
-from torchtitan.components.loss import rescale_accumulated_loss
+from torchtitan.components.loss import moe_loss, rescale_accumulated_loss
 from torchtitan.components.metrics import (
     build_metrics_processor,
     ensure_pp_loss_visible,
@@ -184,6 +185,11 @@ def __init__(self, job_config: JobConfig):
             job_config, parallel_dims=parallel_dims, ft_manager=self.ft_manager
         )

+        self.loss_fn = functools.partial(
+            moe_loss,
+            loss_fn=self.loss_fn,
+        )
+
         # verify batch sizes
         global_batch_size = job_config.training.global_batch_size
         if global_batch_size < 0:
```
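To illustrate the wiring above: once wrapped with `functools.partial`, the trainer's loss function accepts either a plain logits tensor (dense models) or a `(logits, load_balance_loss)` tuple (MoE models). A rough sketch, assuming this PR's `moe_loss` is importable; `cross_entropy_loss` here is just an illustrative stand-in for the trainer's real loss function.

```python
import functools

import torch
import torch.nn.functional as F

from torchtitan.components.loss import moe_loss  # added by this PR

def cross_entropy_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    return F.cross_entropy(pred.flatten(0, 1).float(), labels.flatten())

loss_fn = functools.partial(moe_loss, loss_fn=cross_entropy_loss)

logits = torch.randn(2, 8, 128)
labels = torch.randint(0, 128, (2, 8))
aux = torch.tensor(0.02, requires_grad=True)

dense_loss = loss_fn(logits, labels)             # plain-tensor path (dense models)
moe_path_loss = loss_fn((logits, aux), labels)   # tuple path; reported value matches dense CE
```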