-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Open
Labels
bug (Something isn't working) · needs triage (Waiting to be triaged by maintainers) · ver: 2.5.x
Description
Bug description
I get an AccumulateGrad warning when training with Fabric which doesn't happen with plain PyTorch. Not sure what is happening exactly, the current workaround is to disable the warning with torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(False).
This is on Lightning 2.6.1
What version are you seeing the problem on?
master
Reproduced in studio
No response
How to reproduce the bug
Fabric
"""Repro: AccumulateGrad stream-mismatch warning when training with Fabric.

Escalates the warning to an error so the resulting traceback shows exactly
where it is triggered during ``fabric.backward``.
"""
from __future__ import annotations

import warnings

import torch
from lightning_fabric import Fabric
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
from torch.utils.data import DataLoader

# Turn the stream-mismatch warning into a hard error for the repro.
warnings.filterwarnings(message=".*AccumulateGrad.*", action="error")


def main() -> None:
    # Run configuration.
    batch_size = 32
    accumulation_steps = 2
    max_steps = 20

    # Launch Fabric on whatever accelerator/devices are available.
    fabric = Fabric(accelerator="auto", devices="auto")
    fabric.launch()

    # Build the model/optimizer, then hand both over to Fabric.
    model = BoringModel()
    optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
    model, optimizer = fabric.setup(model, optimizer)

    # Synthetic dataset and Fabric-wrapped dataloader.
    dataset = RandomDataset(size=32, length=100)
    train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)
    train_loader = fabric.setup_dataloaders(train_loader)

    # Training loop with gradient accumulation.
    model.train()
    global_step = 0
    while global_step < max_steps:
        for batch in train_loader:
            if global_step >= max_steps:
                break
            print(f"Step: {global_step}")
            is_accumulating = (global_step + 1) % accumulation_steps != 0
            # Skip the gradient all-reduce while accumulating.
            with fabric.no_backward_sync(model, enabled=is_accumulating):
                loss = model(batch).sum() / accumulation_steps
                fabric.backward(loss)
            if not is_accumulating:
                optimizer.step()
                optimizer.zero_grad()
            global_step += 1


if __name__ == "__main__":
    main()
PyTorch
"""Repro: equivalent training loop in plain PyTorch DDP (warning does not fire).

Launch one process per GPU (e.g. via torchrun); escalates the AccumulateGrad
warning to an error, same as the Fabric variant.
"""
from __future__ import annotations

import warnings

import torch
import torch.distributed as dist
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler

# Turn the stream-mismatch warning into a hard error for the repro.
warnings.filterwarnings(message=".*AccumulateGrad.*", action="error")


def main() -> None:
    # One process per GPU; this process's rank selects its device.
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device(device=f"cuda:{rank}")
    torch.cuda.set_device(device=device)

    # Run configuration.
    batch_size = 32
    accumulation_steps = 2
    max_steps = 20

    # Model wrapped in DDP on this rank's device.
    model = BoringModel().to(device=device)
    model = DistributedDataParallel(model, device_ids=[rank])
    optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

    # Synthetic dataset, sharded across ranks.
    dataset = RandomDataset(size=32, length=100)
    sampler = DistributedSampler(
        dataset=dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,
    )
    train_loader = DataLoader(dataset=dataset, batch_size=batch_size, sampler=sampler)

    # Training loop with gradient accumulation.
    model.train()
    global_step = 0
    while global_step < max_steps:
        # Reshuffle shards on every full pass over the loader.
        sampler.set_epoch(epoch=global_step // len(train_loader))
        for batch in train_loader:
            if global_step >= max_steps:
                break
            print(f"Step: {global_step}")
            batch = batch.to(device=device)
            loss = model(batch).sum() / accumulation_steps
            loss.backward()
            # Only step/zero on accumulation boundaries.
            if (global_step + 1) % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
            global_step += 1

    # Tear down the process group before exiting.
    dist.destroy_process_group()
if __name__ == "__main__":
    main()

Error messages and logs
# Error messages and logs here please
Environment
You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------
Step: 0
Step: 0
[rank1]: Traceback (most recent call last):
[rank1]: File "run_test_grad_acc.py", line 58, in <module>
[rank1]: main()
[rank1]: File "run_test_grad_acc.py", line 48, in main
[rank1]: fabric.backward(loss)
[rank1]: File ".venv/lib/python3.10/site-packages/lightning_fabric/fabric.py", line 523, in backward
[rank1]: self._strategy.backward(tensor, module, *args, **kwargs)
[rank1]: File ".venv/lib/python3.10/site-packages/lightning_fabric/strategies/strategy.py", line 192, in backward
[rank1]: self.precision.backward(tensor, module, *args, **kwargs)
[rank1]: File ".venv/lib/python3.10/site-packages/lightning_fabric/plugins/precision/precision.py", line 107, in backward
[rank1]: tensor.backward(*args, **kwargs)
[rank1]: File ".venv/lib/python3.10/site-packages/torch/_tensor.py", line 630, in backward
[rank1]: torch.autograd.backward(
[rank1]: File ".venv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 364, in backward
[rank1]: _engine_run_backward(
[rank1]: File ".venv/lib/python3.10/site-packages/torch/autograd/graph.py", line 865, in _engine_run_backward
[rank1]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
[rank1]: UserWarning: The AccumulateGrad node's stream does not match the stream of the node that produced the incoming gradient. This may incur unnecessary synchronization and break CUDA graph capture if the AccumulateGrad node's stream is the default stream. This mismatch is caused by an AccumulateGrad node created prior to the current iteration being kept alive. This can happen if the autograd graph is still being kept alive by tensors such as the loss, or if you are using DDP, which will stash a reference to the node. To resolve the mismatch, delete all references to the autograd graph or ensure that DDP initialization is performed under the same stream as subsequent forwards. If the mismatch is intentional, you can use torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(False) to suppress this warning. (Triggered internally at /pytorch/torch/csrc/autograd/input_buffer.cpp:240.)
[rank0]: Traceback (most recent call last):
[rank0]: File "run_test_grad_acc.py", line 58, in <module>
[rank0]: main()
[rank0]: File "run_test_grad_acc.py", line 48, in main
[rank0]: fabric.backward(loss)
[rank0]: File ".venv/lib/python3.10/site-packages/lightning_fabric/fabric.py", line 523, in backward
[rank0]: self._strategy.backward(tensor, module, *args, **kwargs)
[rank0]: File ".venv/lib/python3.10/site-packages/lightning_fabric/strategies/strategy.py", line 192, in backward
[rank0]: self.precision.backward(tensor, module, *args, **kwargs)
[rank0]: File ".venv/lib/python3.10/site-packages/lightning_fabric/plugins/precision/precision.py", line 107, in backward
[rank0]: tensor.backward(*args, **kwargs)
[rank0]: File ".venv/lib/python3.10/site-packages/torch/_tensor.py", line 630, in backward
[rank0]: torch.autograd.backward(
[rank0]: File ".venv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 364, in backward
[rank0]: _engine_run_backward(
[rank0]: File ".venv/lib/python3.10/site-packages/torch/autograd/graph.py", line 865, in _engine_run_backward
[rank0]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
[rank0]: UserWarning: The AccumulateGrad node's stream does not match the stream of the node that produced the incoming gradient. This may incur unnecessary synchronization and break CUDA graph capture if the AccumulateGrad node's stream is the default stream. This mismatch is caused by an AccumulateGrad node created prior to the current iteration being kept alive. This can happen if the autograd graph is still being kept alive by tensors such as the loss, or if you are using DDP, which will stash a reference to the node. To resolve the mismatch, delete all references to the autograd graph or ensure that DDP initialization is performed under the same stream as subsequent forwards. If the mismatch is intentional, you can use torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(False) to suppress this warning. (Triggered internally at /pytorch/torch/csrc/autograd/input_buffer.cpp:240.)
More info
No response
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
bug (Something isn't working) · needs triage (Waiting to be triaged by maintainers) · ver: 2.5.x