System Info
Copy-and-paste the text below in your GitHub issue
- `Accelerate` version: 1.12.0.dev0
- Platform: Linux-6.5.0-18-generic-x86_64-with-glibc2.39
- `accelerate` bash location: /home/yanzhen/miniconda3/envs/accelerate_test/bin/accelerate
- Python version: 3.10.19
- Numpy version: 2.2.6
- PyTorch version: 2.9.0+cu128
- PyTorch accelerator: CUDA
- System RAM: 1007.66 GB
- GPU type: NVIDIA GeForce RTX 4090
Information
- The official example scripts
- My own modified scripts
Tasks
- One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
- My own task or dataset (give details below)
Reproduction
When I run my training code using my customized configuration, I encounter the following error:
KeyError: "A parameter in the optimizer couldn't be switched to its sharded version. This breaks the training. Please raise an issue on GitHub."
The problematic configuration file is failed_config.yaml:
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
num_machines: 1
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: bf16
num_processes: 4
use_cpu: false
debug: false
enable_cpu_affinity: false
downcast_bf16: fp16
fp8_config: null
rdzv_backend: static
same_network: true
fsdp_config:
  fsdp_version: 2
  fsdp_offload_params: false
  fsdp_cpu_ram_efficient_loading: false
  fsdp_activation_checkpointing: false
  fsdp_reshard_after_forward: false
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_auto_wrap_policy: NO_WRAP
parallelism_config:
  parallelism_config_dp_replicate_size: 1
  parallelism_config_dp_shard_size: 2
  parallelism_config_tp_size: 2
  parallelism_config_cp_size: 1
  parallelism_config_cp_comm_strategy: alltoall
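For completeness, the parallelism sizes in this config multiply out to the four launched processes (1 × 2 × 2 × 1 = 4), so the mesh factorization itself matches num_processes. A minimal sanity check, assuming the sizes simply combine multiplicatively into the world size:

# Quick consistency check (my assumption: the parallel sizes combine
# multiplicatively into the world size / num_processes).
sizes = {"dp_replicate": 1, "dp_shard": 2, "tp": 2, "cp": 1}

world_size = 1
for name, value in sizes.items():
    world_size *= value

# 1 * 2 * 2 * 1 == 4, matching num_processes: 4 in the config above.
assert world_size == 4, f"mesh covers {world_size} ranks, but 4 processes are launched"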
The training script main.py is as follows:
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from accelerate import Accelerator


class RandomDataset(Dataset):
    def __init__(self, num_samples=100, input_dim=128, num_classes=10):
        self.num_samples = num_samples
        self.input_dim = input_dim
        self.num_classes = num_classes

    def __getitem__(self, idx):
        x = torch.randn(self.input_dim)
        y = torch.randint(0, self.num_classes, (1,)).item()
        return x, y

    def __len__(self):
        return self.num_samples


class MLP10(nn.Module):
    def __init__(self, input_dim=128, hidden_dim=512, num_classes=10):
        super().__init__()
        layers = []
        in_dim = input_dim
        for _ in range(10):
            layers.append(nn.Linear(in_dim, hidden_dim))
            layers.append(nn.ReLU())
            in_dim = hidden_dim
        layers.append(nn.Linear(hidden_dim, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


def main():
    accelerator = Accelerator()

    input_dim = 128
    hidden_dim = 512
    num_classes = 10
    batch_size = 64
    num_epochs = 1
    lr = 1e-3

    accelerator.print("🚀 Starting Accelerate MLP training")
    accelerator.print(f"Using device: {accelerator.device}")

    dataset = RandomDataset(num_samples=64 * 4 * 20, input_dim=input_dim, num_classes=num_classes)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

    model = MLP10(input_dim, hidden_dim, num_classes)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        for step, (inputs, targets) in enumerate(dataloader):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            accelerator.backward(loss)
            optimizer.step()
            total_loss += loss.item()
            accelerator.print(
                f"[Epoch {epoch+1}/{num_epochs}] Step {step+1}/{len(dataloader)} "
                f"Loss: {total_loss / (step+1):.4f}"
            )
        accelerator.print(f"✅ Epoch {epoch+1} finished. Avg Loss: {total_loss / len(dataloader):.4f}")

    accelerator.print("🎉 Training completed successfully!")


if __name__ == "__main__":
    main()

You can reproduce the issue by running the following command:
accelerate launch --config_file /home/yanzhen/distributed_test/accelerate/test/bug1/failed_config.yaml main.py
The following error log will then appear:
/home/yanzhen/distributed_test/accelerate/src/accelerate/utils/launch.py:238: UserWarning: Port `29500` is already in use. Accelerate will attempt to launch in a standalone-like mode by finding an open port automatically for this session. If this current attempt fails, or for more control in future runs, please specify a different port (e.g., `--main_process_port <your_chosen_port>`) or use `--main_process_port 0` for automatic selection in your launch command or Accelerate config file.
warnings.warn(
[Gloo] Rank 0 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 2 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 1 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 3 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
🚀 Starting Accelerate MLP training
Using device: cuda:0
[rank2]: Traceback (most recent call last):
[rank2]: File "/home/yanzhen/distributed_test/accelerate/src/accelerate/utils/fsdp_utils.py", line 560, in fsdp2_switch_optimizer_parameters
[rank2]: param_group["params"] = [mapping[p.data_ptr] for p in param_group["params"]]
[rank2]: File "/home/yanzhen/distributed_test/accelerate/src/accelerate/utils/fsdp_utils.py", line 560, in <listcomp>
[rank2]: param_group["params"] = [mapping[p.data_ptr] for p in param_group["params"]]
[rank2]: KeyError: 139264640
[rank2]: During handling of the above exception, another exception occurred:
[rank2]: Traceback (most recent call last):
[rank2]: File "/home/yanzhen/distributed_test/accelerate/test/bug1/main.py", line 84, in <module>
[rank2]: main()
[rank2]: File "/home/yanzhen/distributed_test/accelerate/test/bug1/main.py", line 59, in main
[rank2]: model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
[rank2]: File "/home/yanzhen/distributed_test/accelerate/src/accelerate/accelerator.py", line 1555, in prepare
[rank2]: result = self._prepare_fsdp2(*args)
[rank2]: File "/home/yanzhen/distributed_test/accelerate/src/accelerate/accelerator.py", line 1703, in _prepare_fsdp2
[rank2]: fsdp2_switch_optimizer_parameters(obj, mapping)
[rank2]: File "/home/yanzhen/distributed_test/accelerate/src/accelerate/utils/fsdp_utils.py", line 564, in fsdp2_switch_optimizer_parameters
[rank2]: raise KeyError(
[rank2]: KeyError: "A parameter in the optimizer couldn't be switched to its sharded version. This breaks the training. Please raise an issue on GitHub."
[rank3]: Traceback (most recent call last):
[rank3]: File "/home/yanzhen/distributed_test/accelerate/src/accelerate/utils/fsdp_utils.py", line 560, in fsdp2_switch_optimizer_parameters
[rank3]: param_group["params"] = [mapping[p.data_ptr] for p in param_group["params"]]
[rank3]: File "/home/yanzhen/distributed_test/accelerate/src/accelerate/utils/fsdp_utils.py", line 560, in <listcomp>
[rank3]: param_group["params"] = [mapping[p.data_ptr] for p in param_group["params"]]
[rank3]: KeyError: 147864128
[rank3]: During handling of the above exception, another exception occurred:
[rank3]: Traceback (most recent call last):
[rank3]: File "/home/yanzhen/distributed_test/accelerate/test/bug1/main.py", line 84, in <module>
[rank3]: main()
[rank3]: File "/home/yanzhen/distributed_test/accelerate/test/bug1/main.py", line 59, in main
[rank3]: model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
[rank3]: File "/home/yanzhen/distributed_test/accelerate/src/accelerate/accelerator.py", line 1555, in prepare
[rank3]: result = self._prepare_fsdp2(*args)
[rank3]: File "/home/yanzhen/distributed_test/accelerate/src/accelerate/accelerator.py", line 1703, in _prepare_fsdp2
[rank3]: fsdp2_switch_optimizer_parameters(obj, mapping)
[rank3]: File "/home/yanzhen/distributed_test/accelerate/src/accelerate/utils/fsdp_utils.py", line 564, in fsdp2_switch_optimizer_parameters
[rank3]: raise KeyError(
[rank3]: KeyError: "A parameter in the optimizer couldn't be switched to its sharded version. This breaks the training. Please raise an issue on GitHub."
[rank1]: Traceback (most recent call last):
[rank1]: File "/home/yanzhen/distributed_test/accelerate/src/accelerate/utils/fsdp_utils.py", line 560, in fsdp2_switch_optimizer_parameters
[rank1]: param_group["params"] = [mapping[p.data_ptr] for p in param_group["params"]]
[rank1]: File "/home/yanzhen/distributed_test/accelerate/src/accelerate/utils/fsdp_utils.py", line 560, in <listcomp>
[rank1]: param_group["params"] = [mapping[p.data_ptr] for p in param_group["params"]]
[rank1]: KeyError: 165715456
[rank1]: During handling of the above exception, another exception occurred:
[rank1]: Traceback (most recent call last):
[rank1]: File "/home/yanzhen/distributed_test/accelerate/test/bug1/main.py", line 84, in <module>
[rank1]: main()
[rank1]: File "/home/yanzhen/distributed_test/accelerate/test/bug1/main.py", line 59, in main
[rank1]: model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
[rank1]: File "/home/yanzhen/distributed_test/accelerate/src/accelerate/accelerator.py", line 1555, in prepare
[rank1]: result = self._prepare_fsdp2(*args)
[rank1]: File "/home/yanzhen/distributed_test/accelerate/src/accelerate/accelerator.py", line 1703, in _prepare_fsdp2
[rank1]: fsdp2_switch_optimizer_parameters(obj, mapping)
[rank1]: File "/home/yanzhen/distributed_test/accelerate/src/accelerate/utils/fsdp_utils.py", line 564, in fsdp2_switch_optimizer_parameters
[rank1]: raise KeyError(
[rank1]: KeyError: "A parameter in the optimizer couldn't be switched to its sharded version. This breaks the training. Please raise an issue on GitHub."
/home/yanzhen/distributed_test/accelerate/src/accelerate/utils/fsdp_utils.py:710: UserWarning: FSDP upcast of low precision parameters to fp32 (since mixed_precision != 'no') may affect the precision of model checkpoints.
warnings.warn(
[rank0]: Traceback (most recent call last):
[rank0]: File "/home/yanzhen/distributed_test/accelerate/src/accelerate/utils/fsdp_utils.py", line 560, in fsdp2_switch_optimizer_parameters
[rank0]: param_group["params"] = [mapping[p.data_ptr] for p in param_group["params"]]
[rank0]: File "/home/yanzhen/distributed_test/accelerate/src/accelerate/utils/fsdp_utils.py", line 560, in <listcomp>
[rank0]: param_group["params"] = [mapping[p.data_ptr] for p in param_group["params"]]
[rank0]: KeyError: 162314368
[rank0]: During handling of the above exception, another exception occurred:
[rank0]: Traceback (most recent call last):
[rank0]: File "/home/yanzhen/distributed_test/accelerate/test/bug1/main.py", line 84, in <module>
[rank0]: main()
[rank0]: File "/home/yanzhen/distributed_test/accelerate/test/bug1/main.py", line 59, in main
[rank0]: model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
[rank0]: File "/home/yanzhen/distributed_test/accelerate/src/accelerate/accelerator.py", line 1555, in prepare
[rank0]: result = self._prepare_fsdp2(*args)
[rank0]: File "/home/yanzhen/distributed_test/accelerate/src/accelerate/accelerator.py", line 1703, in _prepare_fsdp2
[rank0]: fsdp2_switch_optimizer_parameters(obj, mapping)
[rank0]: File "/home/yanzhen/distributed_test/accelerate/src/accelerate/utils/fsdp_utils.py", line 564, in fsdp2_switch_optimizer_parameters
[rank0]: raise KeyError(
[rank0]: KeyError: "A parameter in the optimizer couldn't be switched to its sharded version. This breaks the training. Please raise an issue on GitHub."
[rank2]:[W1030 16:08:14.055374353 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank0]:[W1030 16:08:14.085499382 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank1]:[W1030 16:08:14.095634764 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
W1030 16:08:15.096000 3812514 site-packages/torch/distributed/elastic/multiprocessing/api.py:908] Sending process 3812798 closing signal SIGTERM
W1030 16:08:15.096000 3812514 site-packages/torch/distributed/elastic/multiprocessing/api.py:908] Sending process 3812799 closing signal SIGTERM
W1030 16:08:15.097000 3812514 site-packages/torch/distributed/elastic/multiprocessing/api.py:908] Sending process 3812801 closing signal SIGTERM
E1030 16:08:15.411000 3812514 site-packages/torch/distributed/elastic/multiprocessing/api.py:882] failed (exitcode: 1) local_rank: 2 (pid: 3812800) of binary: /home/yanzhen/miniconda3/envs/accelerate_test/bin/python3.10
Traceback (most recent call last):
File "/home/yanzhen/miniconda3/envs/accelerate_test/bin/accelerate", line 7, in <module>
sys.exit(main())
File "/home/yanzhen/distributed_test/accelerate/src/accelerate/commands/accelerate_cli.py", line 50, in main
args.func(args)
File "/home/yanzhen/distributed_test/accelerate/src/accelerate/commands/launch.py", line 1222, in launch_command
multi_gpu_launcher(args)
File "/home/yanzhen/distributed_test/accelerate/src/accelerate/commands/launch.py", line 853, in multi_gpu_launcher
distrib_run.run(args)
File "/home/yanzhen/miniconda3/envs/accelerate_test/lib/python3.10/site-packages/torch/distributed/run.py", line 927, in run
elastic_launch(
File "/home/yanzhen/miniconda3/envs/accelerate_test/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 156, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
File "/home/yanzhen/miniconda3/envs/accelerate_test/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 293, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
main.py FAILED
------------------------------------------------------------
Failures:
<NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
time : 2025-10-30_16:08:15
host : ubuntu
rank : 2 (local_rank: 2)
exitcode : 1 (pid: 3812800)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
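For reference, the failing line in fsdp_utils.py looks up each optimizer parameter in a mapping keyed by data_ptr. Below is a minimal, self-contained illustration of how that kind of lookup fails when the optimizer still holds references to parameters that have since been replaced. This is only my reading of the traceback, not the actual Accelerate implementation:

import torch
import torch.nn as nn

# Toy illustration (NOT the Accelerate code): the optimizer is built from the
# original parameters, the module's parameters are then replaced with new
# tensors (as sharding / TP preparation might do), and the old->sharded mapping
# is keyed by the *current* parameters' data_ptr(), so the optimizer's tensors
# are no longer found and the same KeyError pattern appears.
model = nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters())  # references the original tensors

# Simulate preparation swapping the module's parameters for new storage.
model.weight = nn.Parameter(model.weight.detach().clone())
model.bias = nn.Parameter(model.bias.detach().clone())

mapping = {p.data_ptr(): p for p in model.parameters()}  # keyed on the new tensors

try:
    for group in optimizer.param_groups:
        group["params"] = [mapping[p.data_ptr()] for p in group["params"]]
except KeyError as missing_ptr:
    print(f"Optimizer parameter with data_ptr {missing_ptr} not found in mapping")

This matches the shape of the log above: the first KeyError carries a raw data_ptr integer, and Accelerate then re-raises it with the explanatory message.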
However, when I remove the following configuration section from failed_config.yaml, the error no longer occurs:
parallelism_config:
  parallelism_config_dp_replicate_size: 1
  parallelism_config_dp_shard_size: 2
  parallelism_config_tp_size: 2
  parallelism_config_cp_size: 1
  parallelism_config_cp_comm_strategy: alltoall
Expected behavior
KeyError: "A parameter in the optimizer couldn't be switched to its sharded version. This breaks the training. Please raise an issue on GitHub." should not occur.