
Execute proper gradient-tokens normalization in SFT Unit #1049

Open
diego-urgell wants to merge 1 commit into meta-pytorch:master from diego-urgell:export-D92223501

Conversation

@diego-urgell
Contributor


Summary:
## The problem

Currently we compute cross_entropy with sum reduction and don't apply any normalization across ranks. For regular fixed-length training this is fine, because every sample contributes the same number of tokens, and therefore the same weight, to the loss.

However, with variable-length sequences, longer sequences contribute more to the gradient than shorter ones. For us, this meant:
- Across ranks, gradients are averaged with equal weight even when one rank processed more tokens than another (illustrated below).
- With gradient accumulation, on top of the inter-rank issue, every batch is given the same weight regardless of how many tokens it actually contributes.
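
As a toy illustration (the token counts and two-rank setup below are made up for this description, not taken from the SFT unit), averaging per-rank normalized losses weights tokens on the rank with fewer tokens disproportionately, compared with a global per-token average:

```python
import torch

# Hypothetical setup: two DP ranks with different numbers of unmasked tokens.
# rank 0 sees 100 tokens, rank 1 sees only 10.
tok_loss_r0 = torch.rand(100)  # per-token losses on rank 0
tok_loss_r1 = torch.rand(10)   # per-token losses on rank 1

# Per-rank normalization followed by equal averaging across ranks:
# each token on the smaller rank ends up with ~10x the weight.
per_rank_then_average = (tok_loss_r0.mean() + tok_loss_r1.mean()) / 2

# Global per-token normalization (what we want): every token weighted equally.
global_token_mean = (tok_loss_r0.sum() + tok_loss_r1.sum()) / (100 + 10)

print(per_rank_then_average.item(), global_token_mean.item())  # generally differ
```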

This is a common problem that surfaced across multiple frameworks last year:
- [HuggingFace](https://huggingface.co/blog/gradient_accumulation)
- [Unsloth](https://unsloth.ai/blog/gradient)

## Why is this important

Some parallelization techniques, like Tensor Parallelism, trade off the data-parallel (DP) dimension in favor of model parallelism, so we receive fewer samples per step. However, users will typically want to keep the effective batch size constant instead of doing an optimizer step with, say, a quarter of the samples. This is where gradient accumulation becomes necessary.

But if we see widely different final losses when using gradient accumulation, then we can't have confidence that things are working correctly.

## Solution

The correct normalization is to sum per-token losses and divide by the total token count -- across all samples, batches, and DP ranks. This matches the approach used by torchtune's finetuning recipe: https://fburl.com/code/5wgp0vr4
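
Spelled out (notation mine, not from the diff): with $\ell_{r,b,t}$ the loss of token $t$ in micro-batch $b$ on DP rank $r$, and $N$ the total number of unmasked tokens across all ranks and micro-batches, every rank should effectively optimize

$$
\mathcal{L} = \frac{1}{N} \sum_{r} \sum_{b} \sum_{t} \ell_{r,b,t}
$$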

We will (a code sketch of this flow follows the list):
  1. Skip loss normalization before backward
  2. Track total tokens and raw loss across all batches
  3. After accumulation, reduce token counts across DP ranks to get the global total
  4. Scale gradients by 1/total_tokens before the optimizer step
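
A minimal sketch of that flow in PyTorch. The helper name, the batch layout, the `ignore_index` convention, and the distributed handling are illustrative assumptions, not the actual SFT unit code:

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

def accumulation_step(model, optimizer, micro_batches, ignore_index=-100):
    """Gradient accumulation with global per-token normalization (sketch)."""
    device = next(model.parameters()).device
    total_tokens = torch.zeros((), device=device)
    total_loss = torch.zeros((), device=device)

    for batch in micro_batches:
        logits = model(batch["input_ids"])  # assumed shape (B, T, V)
        # 1. Sum per-token losses; no normalization before backward.
        loss = F.cross_entropy(
            logits.flatten(0, 1),
            batch["labels"].flatten(),
            ignore_index=ignore_index,
            reduction="sum",
        )
        loss.backward()
        # 2. Track raw loss and unmasked-token counts across all batches.
        total_tokens += (batch["labels"] != ignore_index).sum()
        total_loss += loss.detach()

    # 3. Reduce token counts (and the loss, for logging) across DP ranks.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(total_tokens, op=dist.ReduceOp.SUM)
        dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)

    # 4. Scale gradients by 1 / total_tokens before the optimizer step.
    # NOTE: if the DP wrapper already averages gradients across ranks
    # (DDP default), an extra factor of the DP world size is needed here.
    for p in model.parameters():
        if p.grad is not None:
            p.grad.div_(total_tokens)

    optimizer.step()
    optimizer.zero_grad()
    return (total_loss / total_tokens).item()  # globally normalized loss for logging
```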

Differential Revision: D92223501
@meta-codesync

meta-codesync bot commented Feb 5, 2026

@diego-urgell has exported this pull request. If you are a Meta employee, you can view the originating Diff in D92223501.

