
[Question] How to Resume DeepSpeed ZeRO-2 Training with a Different Number of GPUs? #7628

@xiao10ma

Hi everyone,

I'm using 🤗 Accelerate with DeepSpeed (ZeRO Stage 2) and I'm implementing a custom checkpointing mechanism for my optimizer state. My primary goal is to reliably resume training even if the number of GPUs (world_size) changes between runs.


Important Context: Why I'm Not Using accelerator.save_state

I must use a manual torch.save/torch.load approach because the DeepSpeed optimizer state dictionary (self.optimizer.state_dict()) contains non-tensor objects, specifically DeepSpeed-internal classes.

This makes the state incompatible with accelerator.save_state, which relies on safetensors and requires a purely tensor-based state dictionary. It is also why I must explicitly pass weights_only=False when calling torch.load: as I understand it, the default will change to True in future PyTorch versions, and loading with weights_only=True would fail to unpickle the non-tensor objects in the optimizer state.
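
For reference, this is the round trip I rely on, as a minimal sketch; `optimizer` stands in for the DeepSpeed-wrapped optimizer returned by accelerator.prepare, and the path is a placeholder.

# Minimal sketch of the save/load round trip (placeholder names).
import torch

state = optimizer.state_dict()        # contains non-tensor, DeepSpeed-internal objects
torch.save(state, "optimizer.00.pt")  # pickling these objects works fine

# Loading must opt out of weights-only mode, otherwise the non-tensor
# objects in the state cannot be unpickled.
state = torch.load("optimizer.00.pt", weights_only=False)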


The Problem

My current manual implementation saves and loads the optimizer state by handling each shard individually. However, this logic breaks when the number of GPUs differs between the save and load runs (a small pre-flight check that at least detects the mismatch is sketched after the two scenarios below).

  • Scenario 1: Decreasing GPUs (e.g., save with 8 GPUs, resume with 4)
    The new run with 4 processes will only load the first 4 optimizer shards. The state corresponding to parameters managed by the original ranks 4-7 is completely lost, leading to an incorrect optimizer state.

  • Scenario 2: Increasing GPUs (e.g., save with 4 GPUs, resume with 8)
    The new processes with rank >= 4 will fail to find their corresponding shard files (optimizer.04.pt, etc.), causing the loading process to fail entirely.
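
To make the mismatch concrete, a pre-flight check like the sketch below (variable names and paths are from my setup, with `accelerator` assumed in scope) can at least detect the problem before anything is loaded, but it does not solve it:

# Hypothetical pre-flight check: compare the number of shard files on disk
# with the current world size before attempting to load anything.
import glob
import os

resume_dir = "my_checkpoint_directory"
saved_shards = sorted(glob.glob(os.path.join(resume_dir, "optimizer.*.pt")))

if len(saved_shards) != accelerator.num_processes:
    # This is exactly the case I don't know how to handle: the shards were
    # written for a different world size than the current run uses.
    raise RuntimeError(
        f"Checkpoint has {len(saved_shards)} optimizer shards, "
        f"but this run has {accelerator.num_processes} processes."
    )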


Current Implementation

Here is my current save and load logic.

Saving Logic

Each process saves its own optimizer state shard with torch.save.

# During saving...
import os
import torch
from accelerate.utils import DistributedType

if accelerator.distributed_type == DistributedType.DEEPSPEED:
    # For DeepSpeed ZeRO Stage 2, each process saves its own optimizer shard.
    model_save_dir = "my_checkpoint_directory"
    os.makedirs(model_save_dir, exist_ok=True)
    optimizer_shard_path = os.path.join(model_save_dir, f"optimizer.{accelerator.process_index:02d}.pt")
    torch.save(self.optimizer.state_dict(), optimizer_shard_path)
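
For reference, a save from 8 processes leaves optimizer.00.pt through optimizer.07.pt in my_checkpoint_directory; the resuming run then has to deal with however many shards it finds there.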

Loading Logic

This logic fails when num_processes does not match the number of saved shards.

# During resuming...
if getattr(self.args, "resume_from_checkpoint", None):
    if os.path.isdir(self.args.resume_from_checkpoint):
        logger.info(f"Resuming from checkpoint {self.args.resume_from_checkpoint}")
        resume_dir = self.args.resume_from_checkpoint

        if self.state.accelerator.distributed_type == DistributedType.DEEPSPEED:
            num_processes = self.state.accelerator.num_processes
            all_shards_found = True
            optimizer_state_list = []
            
            # This loop is problematic as it assumes the number of shards on disk
            # matches the current number of processes.
            for i in range(num_processes):
                optimizer_shard_path = os.path.join(resume_dir, f"optimizer.{i:02d}.pt")
                if os.path.isfile(optimizer_shard_path):
                    # `weights_only=False` is critical for optimizer states containing non-tensor objects.
                    optimizer_state_list.append(
                        torch.load(optimizer_shard_path, map_location=self.state.accelerator.device, weights_only=False)
                    )
                else:
                    logger.warning(f"Optimizer shard for rank {i} not found. Aborting optimizer load.")
                    all_shards_found = False
                    break
            
            if all_shards_found:
                logger.info(f"Loading {len(optimizer_state_list)} optimizer state shards.")
                # This self.optimizer.load_state_dict() call is also likely incorrect for sharded states.
                self.optimizer.load_state_dict(optimizer_state_list)

My Question

What is the canonical way to solve this problem while staying with a torch.save-based approach?

  1. Consolidation & Re-sharding: Is the recommended path to first consolidate all optimizer shards into a single, complete state dictionary on CPU, and then have the new set of processes re-shard and load that consolidated state? If so, is there an established pattern or utility for this? (Rough sketches of what I have in mind for this and for question 2 follow the list.)

  2. DeepSpeed API: Does the underlying DeepSpeed engine object offer a more direct API for this? For example, a method that can take a path to a sharded checkpoint directory and handle the redistribution of states internally, even with a different world size?

  3. Best Practice: Given that I cannot use accelerator.save_state, what is the community's best practice for creating robust, reshardable optimizer checkpoints for DeepSpeed Stage 2?
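
To make question 1 concrete, this is roughly the flow I have in mind. It is an untested sketch, and reshard_for_world_size is a hypothetical helper that stands in for exactly the step I don't know how to implement for ZeRO-2 partitioned state:

# Rough consolidation idea for question 1 (untested sketch). All shard files
# are visible to every rank via the shared filesystem, so each rank could read
# the full set onto CPU and then pick out (or rebuild) its own partition.
import glob
import os
import torch

resume_dir = "my_checkpoint_directory"
shard_paths = sorted(glob.glob(os.path.join(resume_dir, "optimizer.*.pt")))

# 1) Consolidate: load every saved shard onto CPU.
all_shards = [torch.load(p, map_location="cpu", weights_only=False) for p in shard_paths]

# 2) Re-shard for the current world size -- this is the missing piece.
#    `reshard_for_world_size` is hypothetical; I don't know how to correctly
#    repartition ZeRO-2 optimizer state for a different number of processes.
# my_shard = reshard_for_world_size(all_shards,
#                                   rank=accelerator.process_index,
#                                   world_size=accelerator.num_processes)
# self.optimizer.load_state_dict(my_shard)

For question 2, the kind of API I mean is the engine-level checkpointing that raw DeepSpeed exposes. The sketch below assumes model_engine is the DeepSpeed engine that accelerator.prepare returns for the model, and I don't know whether load_checkpoint tolerates a changed world size for ZeRO-2:

# Sketch of the engine-level API I am asking about in question 2.
# `model_engine` is assumed to be the DeepSpeed engine for the model.

# Saving: DeepSpeed writes its own sharded checkpoint layout under the directory.
model_engine.save_checkpoint("my_checkpoint_directory", tag="step_1000")

# Resuming: does this redistribute optimizer state if the world size changed?
load_path, client_state = model_engine.load_checkpoint(
    "my_checkpoint_directory",
    tag="step_1000",
    load_optimizer_states=True,
)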

Any code examples or pointers to the correct API usage would be extremely helpful.

Thank you for your time and assistance!
