Hi everyone,
I'm using 🤗 Accelerate with DeepSpeed (ZeRO Stage 2) and I'm implementing a custom checkpointing mechanism for my optimizer state. My primary goal is to reliably resume training even if the number of GPUs (world_size) changes between runs.
Important Context: Why I'm Not Using accelerator.save_state
I must use a manual torch.save/torch.load approach because the DeepSpeed optimizer state dictionary (self.optimizer.state_dict()) contains non-tensor objects, specifically DeepSpeed-internal classes.
This makes the state incompatible with accelerator.save_state, which relies on safetensors and requires a purely tensor-based state dictionary. It is also why I must explicitly pass weights_only=False to torch.load: with weights_only=True (the default since PyTorch 2.6), unpickling the optimizer state would fail.
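For reference, a quick way to see the offending entries (a minimal sketch; optimizer here is the object returned by accelerator.prepare):
# Sketch: list the entries that make the state dict unsuitable for safetensors.
# Assumes `optimizer` is the DeepSpeed-wrapped optimizer returned by accelerator.prepare(...).
import torch

state = optimizer.state_dict()
non_tensor_entries = {
    key: type(value).__name__
    for key, value in state.items()
    if not isinstance(value, torch.Tensor)
}
print(non_tensor_entries)  # ints, lists, dicts, and DeepSpeed-internal objects show up here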
The Problem
My current manual implementation saves and loads the optimizer state by handling each shard individually. However, this logic breaks whenever the number of GPUs differs between the save and the load run.
Scenario 1: Decreasing GPUs (e.g., save with 8 GPUs, resume with 4)
The new run with 4 processes will only load the first 4 optimizer shards. The state corresponding to parameters managed by the original ranks 4-7 is completely lost, leading to an incorrect optimizer state.

Scenario 2: Increasing GPUs (e.g., save with 4 GPUs, resume with 8)
The new processes with rank >= 4 will fail to find their corresponding shard files (optimizer.04.pt, etc.), causing the loading process to fail entirely.
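In both cases, the root cause is that the number of shard files on disk no longer matches accelerator.num_processes. A minimal check that surfaces the mismatch early (my own sketch; resume_dir and accelerator are the same objects used in the code below):
# Sketch: detect a world-size mismatch before attempting any load (my own illustration).
import glob
import os

shard_paths = sorted(glob.glob(os.path.join(resume_dir, "optimizer.*.pt")))
if len(shard_paths) != accelerator.num_processes:
    raise RuntimeError(
        f"Found {len(shard_paths)} optimizer shards but the current run has "
        f"{accelerator.num_processes} processes; the shards would need to be redistributed."
    )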
Current Implementation
Here is my current save and load logic.
Saving Logic
Each process saves its own optimizer state shard with torch.save.
# During saving...
if accelerator.distributed_type == DistributedType.DEEPSPEED:
    # For DeepSpeed Stage 2, each process saves its own optimizer shard
    model_save_dir = "my_checkpoint_directory"
    optimizer_shard_path = os.path.join(model_save_dir, f"optimizer.{accelerator.process_index:02d}.pt")
    torch.save(self.optimizer.state_dict(), optimizer_shard_path)

Loading Logic
This logic fails when num_processes does not match the number of saved shards.
# During resuming...
if getattr(self.args, "resume_from_checkpoint", None):
    if os.path.isdir(self.args.resume_from_checkpoint):
        logger.info(f"Resuming from checkpoint {self.args.resume_from_checkpoint}")
        resume_dir = self.args.resume_from_checkpoint
        if self.state.accelerator.distributed_type == DistributedType.DEEPSPEED:
            num_processes = self.state.accelerator.num_processes
            all_shards_found = True
            optimizer_state_list = []
            # This loop is problematic as it assumes the number of shards on disk
            # matches the current number of processes.
            for i in range(num_processes):
                optimizer_shard_path = os.path.join(resume_dir, f"optimizer.{i:02d}.pt")
                if os.path.isfile(optimizer_shard_path):
                    # `weights_only=False` is critical for optimizer states containing non-tensor objects.
                    optimizer_state_list.append(
                        torch.load(optimizer_shard_path, map_location=self.state.accelerator.device, weights_only=False)
                    )
                else:
                    logger.warning(f"Optimizer shard for rank {i} not found. Aborting optimizer load.")
                    all_shards_found = False
                    break
            if all_shards_found:
                logger.info(f"Loading {len(optimizer_state_list)} optimizer state shards.")
                # This self.optimizer.load_state_dict() call is also likely incorrect for sharded states.
                self.optimizer.load_state_dict(optimizer_state_list)

My Question
What is the canonical way to solve this problem while staying with a torch.save-based approach?
Consolidation & Re-sharding: Is the recommended path to first consolidate all optimizer shards into a single, complete state dictionary on CPU, and then have the new set of processes re-shard and load that consolidated state? If so, is there an established pattern or utility for this? (A rough sketch of what I have in mind follows below.)

DeepSpeed API: Does the underlying DeepSpeed engine object offer a more direct API for this? For example, a method that takes a path to a sharded checkpoint directory and handles the redistribution of state internally, even with a different world size? (The kind of call I am imagining is also sketched below.)

Best Practice: Given that I cannot use accelerator.save_state, what is the community's best practice for creating robust, reshardable optimizer checkpoints for DeepSpeed Stage 2?
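To make the first question concrete, here is roughly the consolidation half I have in mind. This is only a sketch under my own assumptions: it runs on the main process, reads the shard files written by the saving code above, and naively merges torch-style state dicts; whether the DeepSpeed Stage 2 shard layout actually merges this way, and how to re-shard the result, is exactly what I am asking.
# Rough sketch of the "consolidate on CPU" idea (my own guess, not an established utility).
import glob
import os

import torch

if accelerator.is_main_process:
    shard_paths = sorted(glob.glob(os.path.join(resume_dir, "optimizer.*.pt")))
    merged_state = {"state": {}, "param_groups": None}
    for path in shard_paths:
        # weights_only=False because the shards contain non-tensor, DeepSpeed-internal objects.
        shard = torch.load(path, map_location="cpu", weights_only=False)
        # Assumption: each shard looks like a torch-style {"state": ..., "param_groups": ...} dict.
        merged_state["state"].update(shard.get("state", {}))
        if merged_state["param_groups"] is None:
            merged_state["param_groups"] = shard.get("param_groups")
    # Open question: how to re-shard `merged_state` across the new accelerator.num_processes
    # and hand it back to the DeepSpeed optimizer.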
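For the second question, this is the kind of call I am imagining, assuming the object returned by accelerator.prepare is the DeepSpeed engine and that its save_checkpoint/load_checkpoint pair is the right entry point; whether it redistributes Stage 2 optimizer state across a different world size is precisely what I do not know. The directory and tag names are placeholders.
# What I am imagining for a more direct DeepSpeed path (unclear to me whether it is reshard-safe).
# Assumes `model` returned by accelerator.prepare(...) is the DeepSpeed engine.

# Save side: let the engine write its own sharded checkpoint layout.
model.save_checkpoint("my_checkpoint_directory", tag="step_1000")

# Resume side: ask the engine to restore optimizer state from that directory.
model.load_checkpoint(
    "my_checkpoint_directory",
    tag="step_1000",
    load_optimizer_states=True,
    load_lr_scheduler_states=False,
)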
Any code examples or pointers to the correct API usage would be extremely helpful.
Thank you for your time and assistance!