Skip to content

AccumulateGrad Warning #21567

@guarin

Description

@guarin

Bug description

I get an AccumulateGrad warning when training with Fabric which doesn't happen with plain PyTorch. Not sure what is happening exactly, the current workaround is to disable the warning with torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(False).

This is on Lightning 2.6.1

What version are you seeing the problem on?

master

Reproduced in studio

No response

How to reproduce the bug

Fabric

from __future__ import annotations

import torch
from lightning_fabric import Fabric
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
from torch.utils.data import DataLoader

import warnings
# Escalate the AccumulateGrad stream-mismatch UserWarning to an error so the
# backward call that triggers it produces a full traceback (used to capture
# the logs below).
warnings.filterwarnings(message=".*AccumulateGrad.*", action="error")


def main() -> None:
    """Reproduce the AccumulateGrad stream-mismatch warning using Fabric.

    Runs a short DDP training loop with gradient accumulation; the warning
    (escalated to an error above) fires on the first ``fabric.backward`` call.
    """
    # Hyperparameters for the repro loop
    samples_per_batch = 32
    grad_accum = 2
    step_limit = 20

    # Bring up Fabric (spawns/joins the distributed processes)
    fabric = Fabric(accelerator="auto", devices="auto")
    fabric.launch()

    # Build model + optimizer, then hand both to Fabric for wrapping
    model = BoringModel()
    optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
    model, optimizer = fabric.setup(model, optimizer)

    # Data pipeline, sharded by Fabric
    dataset = RandomDataset(size=32, length=100)
    loader = DataLoader(dataset=dataset, batch_size=samples_per_batch, shuffle=True)
    loader = fabric.setup_dataloaders(loader)

    model.train()
    step = 0

    while step < step_limit:
        for batch in loader:
            if step >= step_limit:
                break

            print(f"Step: {step}")
            accumulating = (step + 1) % grad_accum != 0
            # Skip the DDP gradient all-reduce on accumulation-only steps
            with fabric.no_backward_sync(model, enabled=accumulating):
                output = model(batch)
                loss = output.sum() / grad_accum
                fabric.backward(loss)

            # Apply the optimizer only on the last micro-batch of each window
            if not accumulating:
                optimizer.step()
                optimizer.zero_grad()

            step += 1


if __name__ == "__main__":
    main()


PyTorch

from __future__ import annotations

import torch
import torch.distributed as dist
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
from torch.utils.data import DataLoader, DistributedSampler
from torch.nn.parallel import DistributedDataParallel

import warnings
# Escalate the AccumulateGrad stream-mismatch UserWarning to an error; in this
# plain-PyTorch baseline it never fires, demonstrating the difference vs Fabric.
warnings.filterwarnings(message=".*AccumulateGrad.*", action="error")


def main() -> None:
    """Plain-PyTorch DDP baseline: same accumulation loop, no warning raised.

    Mirrors the Fabric script above but wires up NCCL/DDP by hand, to show the
    AccumulateGrad warning does not appear outside Fabric.
    """
    # Bring up the NCCL process group and pin this rank's GPU
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device(device=f"cuda:{rank}")
    torch.cuda.set_device(device=device)

    # Hyperparameters for the repro loop
    samples_per_batch = 32
    grad_accum = 2
    step_limit = 20

    # Model wrapped in DDP, plus its optimizer
    model = DistributedDataParallel(BoringModel().to(device=device), device_ids=[rank])
    optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

    # Data pipeline sharded across ranks
    dataset = RandomDataset(size=32, length=100)
    sampler = DistributedSampler(
        dataset=dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,
    )
    loader = DataLoader(dataset=dataset, batch_size=samples_per_batch, sampler=sampler)

    model.train()
    step = 0

    while step < step_limit:
        # Reshuffle per "epoch" so each pass over the loader sees a new order
        sampler.set_epoch(epoch=step // len(loader))
        for batch in loader:
            if step >= step_limit:
                break

            print(f"Step: {step}")
            batch = batch.to(device=device)

            # Forward + scaled loss for gradient accumulation
            output = model(batch)
            loss = output.sum() / grad_accum

            loss.backward()

            # Apply the optimizer only on the last micro-batch of each window
            if (step + 1) % grad_accum == 0:
                optimizer.step()
                optimizer.zero_grad()

            step += 1

    # Tear down the process group
    dist.destroy_process_group()


if __name__ == "__main__":
    main()

Error messages and logs

# Error messages and logs here please

Environment

You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------

Step: 0
Step: 0
[rank1]: Traceback (most recent call last):
[rank1]:   File "run_test_grad_acc.py", line 58, in <module>
[rank1]:     main()
[rank1]:   File "run_test_grad_acc.py", line 48, in main
[rank1]:     fabric.backward(loss)
[rank1]:   File ".venv/lib/python3.10/site-packages/lightning_fabric/fabric.py", line 523, in backward
[rank1]:     self._strategy.backward(tensor, module, *args, **kwargs)
[rank1]:   File ".venv/lib/python3.10/site-packages/lightning_fabric/strategies/strategy.py", line 192, in backward
[rank1]:     self.precision.backward(tensor, module, *args, **kwargs)
[rank1]:   File ".venv/lib/python3.10/site-packages/lightning_fabric/plugins/precision/precision.py", line 107, in backward
[rank1]:     tensor.backward(*args, **kwargs)
[rank1]:   File ".venv/lib/python3.10/site-packages/torch/_tensor.py", line 630, in backward
[rank1]:     torch.autograd.backward(
[rank1]:   File ".venv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 364, in backward
[rank1]:     _engine_run_backward(
[rank1]:   File ".venv/lib/python3.10/site-packages/torch/autograd/graph.py", line 865, in _engine_run_backward
[rank1]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank1]: UserWarning: The AccumulateGrad node's stream does not match the stream of the node that produced the incoming gradient. This may incur unnecessary synchronization and break CUDA graph capture if the AccumulateGrad node's stream is the default stream. This mismatch is caused by an AccumulateGrad node created prior to the current iteration being kept alive. This can happen if the autograd graph is still being kept alive by tensors such as the loss, or if you are using DDP, which will stash a reference to the node. To resolve the mismatch, delete all references to the autograd graph or ensure that DDP initialization is performed under the same stream as subsequent forwards. If the mismatch is intentional, you can use torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(False) to suppress this warning. (Triggered internally at /pytorch/torch/csrc/autograd/input_buffer.cpp:240.)
[rank0]: Traceback (most recent call last):
[rank0]:   File "run_test_grad_acc.py", line 58, in <module>
[rank0]:     main()
[rank0]:   File "run_test_grad_acc.py", line 48, in main
[rank0]:     fabric.backward(loss)
[rank0]:   File ".venv/lib/python3.10/site-packages/lightning_fabric/fabric.py", line 523, in backward
[rank0]:     self._strategy.backward(tensor, module, *args, **kwargs)
[rank0]:   File ".venv/lib/python3.10/site-packages/lightning_fabric/strategies/strategy.py", line 192, in backward
[rank0]:     self.precision.backward(tensor, module, *args, **kwargs)
[rank0]:   File ".venv/lib/python3.10/site-packages/lightning_fabric/plugins/precision/precision.py", line 107, in backward
[rank0]:     tensor.backward(*args, **kwargs)
[rank0]:   File ".venv/lib/python3.10/site-packages/torch/_tensor.py", line 630, in backward
[rank0]:     torch.autograd.backward(
[rank0]:   File ".venv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 364, in backward
[rank0]:     _engine_run_backward(
[rank0]:   File ".venv/lib/python3.10/site-packages/torch/autograd/graph.py", line 865, in _engine_run_backward
[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]: UserWarning: The AccumulateGrad node's stream does not match the stream of the node that produced the incoming gradient. This may incur unnecessary synchronization and break CUDA graph capture if the AccumulateGrad node's stream is the default stream. This mismatch is caused by an AccumulateGrad node created prior to the current iteration being kept alive. This can happen if the autograd graph is still being kept alive by tensors such as the loss, or if you are using DDP, which will stash a reference to the node. To resolve the mismatch, delete all references to the autograd graph or ensure that DDP initialization is performed under the same stream as subsequent forwards. If the mismatch is intentional, you can use torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(False) to suppress this warning. (Triggered internally at /pytorch/torch/csrc/autograd/input_buffer.cpp:240.)

More info

No response

cc @ethanwharris

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug — Something isn't working; needs triage — Waiting to be triaged by maintainers; ver: 2.5.x

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions