KeyError: "A parameter in the optimizer couldn't be switched to its sharded version" occurs under specific Accelerate configuration #3821

@sdjasj

Description

System Info

Output of `accelerate env`:

- `Accelerate` version: 1.12.0.dev0
- Platform: Linux-6.5.0-18-generic-x86_64-with-glibc2.39
- `accelerate` bash location: /home/yanzhen/miniconda3/envs/accelerate_test/bin/accelerate
- Python version: 3.10.19
- Numpy version: 2.2.6
- PyTorch version: 2.9.0+cu128
- PyTorch accelerator: CUDA
- System RAM: 1007.66 GB
- GPU type: NVIDIA GeForce RTX 4090

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

When I run my training code using my customized configuration, I encounter the following error:
KeyError: "A parameter in the optimizer couldn't be switched to its sharded version. This breaks the training. Please raise an issue on GitHub."

The problematic configuration file is failed_config.yaml:

compute_environment: LOCAL_MACHINE
distributed_type: FSDP
num_machines: 1
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: bf16
num_processes: 4
use_cpu: false
debug: false
enable_cpu_affinity: false
downcast_bf16: fp16
fp8_config: null
rdzv_backend: static
same_network: true
fsdp_config:
  fsdp_version: 2
  fsdp_offload_params: false
  fsdp_cpu_ram_efficient_loading: false
  fsdp_activation_checkpointing: false
  fsdp_reshard_after_forward: false
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_auto_wrap_policy: NO_WRAP
parallelism_config:
  parallelism_config_dp_replicate_size: 1
  parallelism_config_dp_shard_size: 2
  parallelism_config_tp_size: 2
  parallelism_config_cp_size: 1
  parallelism_config_cp_comm_strategy: alltoall
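
(For reference, dp_replicate_size × dp_shard_size × tp_size × cp_size = 1 × 2 × 2 × 1 = 4 here, which matches num_processes: 4, so the requested topology at least adds up to the launched world size.)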

The training script main.py is as follows:

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from accelerate import Accelerator

class RandomDataset(Dataset):
    def __init__(self, num_samples=100, input_dim=128, num_classes=10):
        self.num_samples = num_samples
        self.input_dim = input_dim
        self.num_classes = num_classes

    def __getitem__(self, idx):
        x = torch.randn(self.input_dim)
        y = torch.randint(0, self.num_classes, (1,)).item()
        return x, y

    def __len__(self):
        return self.num_samples


class MLP10(nn.Module):
    def __init__(self, input_dim=128, hidden_dim=512, num_classes=10):
        super().__init__()
        layers = []
        in_dim = input_dim
        for _ in range(10):
            layers.append(nn.Linear(in_dim, hidden_dim))
            layers.append(nn.ReLU())
            in_dim = hidden_dim
        layers.append(nn.Linear(hidden_dim, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


def main():
    accelerator = Accelerator()

    input_dim = 128
    hidden_dim = 512
    num_classes = 10
    batch_size = 64
    num_epochs = 1
    lr = 1e-3

    accelerator.print("🚀 Starting Accelerate MLP training")
    accelerator.print(f"Using device: {accelerator.device}")

    dataset = RandomDataset(num_samples=64 * 4 * 20, input_dim=input_dim, num_classes=num_classes)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

    model = MLP10(input_dim, hidden_dim, num_classes)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        for step, (inputs, targets) in enumerate(dataloader):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            accelerator.backward(loss)
            optimizer.step()

            total_loss += loss.item()
            accelerator.print(
                f"[Epoch {epoch+1}/{num_epochs}] Step {step+1}/{len(dataloader)} "
                f"Loss: {total_loss / (step+1):.4f}"
            )


        accelerator.print(f"✅ Epoch {epoch+1} finished. Avg Loss: {total_loss / len(dataloader):.4f}")

    accelerator.print("🎉 Training completed successfully!")

if __name__ == "__main__":
    main()

You can reproduce the issue by running the following command:

accelerate launch --config_file /home/yanzhen/distributed_test/accelerate/test/bug1/failed_config.yaml main.py

The following error log will then appear:

/home/yanzhen/distributed_test/accelerate/src/accelerate/utils/launch.py:238: UserWarning: Port `29500` is already in use. Accelerate will attempt to launch in a standalone-like mode by finding an open port automatically for this session. If this current attempt fails, or for more control in future runs, please specify a different port (e.g., `--main_process_port <your_chosen_port>`) or use `--main_process_port 0` for automatic selection in your launch command or Accelerate config file.
  warnings.warn(
[Gloo] Rank 0 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 1 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 2 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 3 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
(the same two-rank Gloo connection messages repeat several more times for the remaining process groups)
🚀 Starting Accelerate MLP training
Using device: cuda:0
[rank2]: Traceback (most recent call last):
[rank2]:   File "/home/yanzhen/distributed_test/accelerate/src/accelerate/utils/fsdp_utils.py", line 560, in fsdp2_switch_optimizer_parameters
[rank2]:     param_group["params"] = [mapping[p.data_ptr] for p in param_group["params"]]
[rank2]:   File "/home/yanzhen/distributed_test/accelerate/src/accelerate/utils/fsdp_utils.py", line 560, in <listcomp>
[rank2]:     param_group["params"] = [mapping[p.data_ptr] for p in param_group["params"]]
[rank2]: KeyError: 139264640

[rank2]: During handling of the above exception, another exception occurred:

[rank2]: Traceback (most recent call last):
[rank2]:   File "/home/yanzhen/distributed_test/accelerate/test/bug1/main.py", line 84, in <module>
[rank2]:     main()
[rank2]:   File "/home/yanzhen/distributed_test/accelerate/test/bug1/main.py", line 59, in main
[rank2]:     model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
[rank2]:   File "/home/yanzhen/distributed_test/accelerate/src/accelerate/accelerator.py", line 1555, in prepare
[rank2]:     result = self._prepare_fsdp2(*args)
[rank2]:   File "/home/yanzhen/distributed_test/accelerate/src/accelerate/accelerator.py", line 1703, in _prepare_fsdp2
[rank2]:     fsdp2_switch_optimizer_parameters(obj, mapping)
[rank2]:   File "/home/yanzhen/distributed_test/accelerate/src/accelerate/utils/fsdp_utils.py", line 564, in fsdp2_switch_optimizer_parameters
[rank2]:     raise KeyError(
[rank2]: KeyError: "A parameter in the optimizer couldn't be switched to its sharded version. This breaks the training. Please raise an issue on GitHub."
[rank3]: Traceback (most recent call last):
[rank3]:   File "/home/yanzhen/distributed_test/accelerate/src/accelerate/utils/fsdp_utils.py", line 560, in fsdp2_switch_optimizer_parameters
[rank3]:     param_group["params"] = [mapping[p.data_ptr] for p in param_group["params"]]
[rank3]:   File "/home/yanzhen/distributed_test/accelerate/src/accelerate/utils/fsdp_utils.py", line 560, in <listcomp>
[rank3]:     param_group["params"] = [mapping[p.data_ptr] for p in param_group["params"]]
[rank3]: KeyError: 147864128

[rank3]: During handling of the above exception, another exception occurred:

[rank3]: Traceback (most recent call last):
[rank3]:   File "/home/yanzhen/distributed_test/accelerate/test/bug1/main.py", line 84, in <module>
[rank3]:     main()
[rank3]:   File "/home/yanzhen/distributed_test/accelerate/test/bug1/main.py", line 59, in main
[rank3]:     model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
[rank3]:   File "/home/yanzhen/distributed_test/accelerate/src/accelerate/accelerator.py", line 1555, in prepare
[rank3]:     result = self._prepare_fsdp2(*args)
[rank3]:   File "/home/yanzhen/distributed_test/accelerate/src/accelerate/accelerator.py", line 1703, in _prepare_fsdp2
[rank3]:     fsdp2_switch_optimizer_parameters(obj, mapping)
[rank3]:   File "/home/yanzhen/distributed_test/accelerate/src/accelerate/utils/fsdp_utils.py", line 564, in fsdp2_switch_optimizer_parameters
[rank3]:     raise KeyError(
[rank3]: KeyError: "A parameter in the optimizer couldn't be switched to its sharded version. This breaks the training. Please raise an issue on GitHub."
[rank1]: Traceback (most recent call last):
[rank1]:   File "/home/yanzhen/distributed_test/accelerate/src/accelerate/utils/fsdp_utils.py", line 560, in fsdp2_switch_optimizer_parameters
[rank1]:     param_group["params"] = [mapping[p.data_ptr] for p in param_group["params"]]
[rank1]:   File "/home/yanzhen/distributed_test/accelerate/src/accelerate/utils/fsdp_utils.py", line 560, in <listcomp>
[rank1]:     param_group["params"] = [mapping[p.data_ptr] for p in param_group["params"]]
[rank1]: KeyError: 165715456

[rank1]: During handling of the above exception, another exception occurred:

[rank1]: Traceback (most recent call last):
[rank1]:   File "/home/yanzhen/distributed_test/accelerate/test/bug1/main.py", line 84, in <module>
[rank1]:     main()
[rank1]:   File "/home/yanzhen/distributed_test/accelerate/test/bug1/main.py", line 59, in main
[rank1]:     model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
[rank1]:   File "/home/yanzhen/distributed_test/accelerate/src/accelerate/accelerator.py", line 1555, in prepare
[rank1]:     result = self._prepare_fsdp2(*args)
[rank1]:   File "/home/yanzhen/distributed_test/accelerate/src/accelerate/accelerator.py", line 1703, in _prepare_fsdp2
[rank1]:     fsdp2_switch_optimizer_parameters(obj, mapping)
[rank1]:   File "/home/yanzhen/distributed_test/accelerate/src/accelerate/utils/fsdp_utils.py", line 564, in fsdp2_switch_optimizer_parameters
[rank1]:     raise KeyError(
[rank1]: KeyError: "A parameter in the optimizer couldn't be switched to its sharded version. This breaks the training. Please raise an issue on GitHub."
/home/yanzhen/distributed_test/accelerate/src/accelerate/utils/fsdp_utils.py:710: UserWarning: FSDP upcast of low precision parameters to fp32 (since mixed_precision != 'no') may affect the precision of model checkpoints.
  warnings.warn(
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/yanzhen/distributed_test/accelerate/src/accelerate/utils/fsdp_utils.py", line 560, in fsdp2_switch_optimizer_parameters
[rank0]:     param_group["params"] = [mapping[p.data_ptr] for p in param_group["params"]]
[rank0]:   File "/home/yanzhen/distributed_test/accelerate/src/accelerate/utils/fsdp_utils.py", line 560, in <listcomp>
[rank0]:     param_group["params"] = [mapping[p.data_ptr] for p in param_group["params"]]
[rank0]: KeyError: 162314368

[rank0]: During handling of the above exception, another exception occurred:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/yanzhen/distributed_test/accelerate/test/bug1/main.py", line 84, in <module>
[rank0]:     main()
[rank0]:   File "/home/yanzhen/distributed_test/accelerate/test/bug1/main.py", line 59, in main
[rank0]:     model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
[rank0]:   File "/home/yanzhen/distributed_test/accelerate/src/accelerate/accelerator.py", line 1555, in prepare
[rank0]:     result = self._prepare_fsdp2(*args)
[rank0]:   File "/home/yanzhen/distributed_test/accelerate/src/accelerate/accelerator.py", line 1703, in _prepare_fsdp2
[rank0]:     fsdp2_switch_optimizer_parameters(obj, mapping)
[rank0]:   File "/home/yanzhen/distributed_test/accelerate/src/accelerate/utils/fsdp_utils.py", line 564, in fsdp2_switch_optimizer_parameters
[rank0]:     raise KeyError(
[rank0]: KeyError: "A parameter in the optimizer couldn't be switched to its sharded version. This breaks the training. Please raise an issue on GitHub."
[rank2]:[W1030 16:08:14.055374353 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank0]:[W1030 16:08:14.085499382 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank1]:[W1030 16:08:14.095634764 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
W1030 16:08:15.096000 3812514 site-packages/torch/distributed/elastic/multiprocessing/api.py:908] Sending process 3812798 closing signal SIGTERM
W1030 16:08:15.096000 3812514 site-packages/torch/distributed/elastic/multiprocessing/api.py:908] Sending process 3812799 closing signal SIGTERM
W1030 16:08:15.097000 3812514 site-packages/torch/distributed/elastic/multiprocessing/api.py:908] Sending process 3812801 closing signal SIGTERM
E1030 16:08:15.411000 3812514 site-packages/torch/distributed/elastic/multiprocessing/api.py:882] failed (exitcode: 1) local_rank: 2 (pid: 3812800) of binary: /home/yanzhen/miniconda3/envs/accelerate_test/bin/python3.10
Traceback (most recent call last):
  File "/home/yanzhen/miniconda3/envs/accelerate_test/bin/accelerate", line 7, in <module>
    sys.exit(main())
  File "/home/yanzhen/distributed_test/accelerate/src/accelerate/commands/accelerate_cli.py", line 50, in main
    args.func(args)
  File "/home/yanzhen/distributed_test/accelerate/src/accelerate/commands/launch.py", line 1222, in launch_command
    multi_gpu_launcher(args)
  File "/home/yanzhen/distributed_test/accelerate/src/accelerate/commands/launch.py", line 853, in multi_gpu_launcher
    distrib_run.run(args)
  File "/home/yanzhen/miniconda3/envs/accelerate_test/lib/python3.10/site-packages/torch/distributed/run.py", line 927, in run
    elastic_launch(
  File "/home/yanzhen/miniconda3/envs/accelerate_test/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 156, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/yanzhen/miniconda3/envs/accelerate_test/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 293, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
main.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2025-10-30_16:08:15
  host      : ubuntu
  rank      : 2 (local_rank: 2)
  exitcode  : 1 (pid: 3812800)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
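
For context on the failing line itself: according to the traceback, fsdp2_switch_optimizer_parameters swaps each optimizer parameter for its sharded version by looking it up in a mapping keyed by the old parameter's data pointer. The self-contained sketch below is not the Accelerate source, and the failure scenario it simulates is only my guess: if the model's parameters are replaced by new tensor objects before that mapping is consulted, the optimizer still holds the old tensors and the lookup fails with a bare integer key, which matches the per-rank "KeyError: <data_ptr>" lines above.

import torch.nn as nn
import torch.optim as optim


def switch_optimizer_parameters(optimizer, mapping):
    # Illustrative stand-in for the lookup shown in the traceback:
    # mapping is {old_parameter_data_ptr: sharded_parameter}.
    for param_group in optimizer.param_groups:
        param_group["params"] = [mapping[p.data_ptr()] for p in param_group["params"]]


model = nn.Linear(4, 4)
optimizer = optim.Adam(model.parameters(), lr=1e-3)  # still references the original tensors

# Simulate a parallelization step that replaces the module's parameters with
# new tensor objects (new storage) without updating the optimizer.
model.weight = nn.Parameter(model.weight.detach().clone())
model.bias = nn.Parameter(model.bias.detach().clone())

# A mapping built from the module's *current* parameters does not contain the
# data pointers of the tensors the optimizer still references ...
mapping = {p.data_ptr(): p for p in model.parameters()}

try:
    switch_optimizer_parameters(optimizer, mapping)
except KeyError as missing_ptr:
    # ... so the lookup fails with a bare integer key, analogous to
    # "KeyError: 139264640" in the log above.
    print("KeyError:", missing_ptr)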

However, when I remove the following configuration section from failed_config.yaml, the error no longer occurs:

parallelism_config:
  parallelism_config_dp_replicate_size: 1
  parallelism_config_dp_shard_size: 2
  parallelism_config_tp_size: 2
  parallelism_config_cp_size: 1
  parallelism_config_cp_comm_strategy: alltoall
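
Since the error disappears without parallelism_config, another variant I plan to try (untested with the failing config, so treat this as an assumption on my part rather than a confirmed workaround) is to prepare the model first and only build the optimizer afterwards from the prepared parameters, along these lines:

# Sketch only; reuses RandomDataset and MLP10 from main.py above.
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from accelerate import Accelerator


def main_prepare_model_first():
    accelerator = Accelerator()

    dataset = RandomDataset(num_samples=64 * 4 * 20, input_dim=128, num_classes=10)
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=2)

    # Prepare (and shard) the model before the optimizer exists, so the
    # optimizer is constructed directly from the prepared parameters.
    model = accelerator.prepare(MLP10(input_dim=128, hidden_dim=512, num_classes=10))
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    optimizer, dataloader = accelerator.prepare(optimizer, dataloader)

    criterion = nn.CrossEntropyLoss()
    model.train()
    for inputs, targets in dataloader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        accelerator.backward(loss)
        optimizer.step()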

Expected behavior

Training should complete successfully with the configuration above; the error KeyError: "A parameter in the optimizer couldn't be switched to its sharded version. This breaks the training. Please raise an issue on GitHub." should not occur.
