27 changes: 20 additions & 7 deletions ignite/handlers/checkpoint.py
@@ -11,6 +11,14 @@

import torch
import torch.nn as nn
from packaging.version import Version

if Version(torch.__version__) >= Version("1.9.0"):
from torch.distributed.optim import ZeroRedundancyOptimizer

HAVE_ZERO = True
else:
HAVE_ZERO = False

import ignite.distributed as idist
from ignite.base import Serializable
@@ -166,13 +174,14 @@ class Checkpoint(Serializable):
> checkpoint_12345.pt

Note:
This class is distributed configuration-friendly: it is not required to instantiate the class in rank 0
process only.
This class is distributed configuration-friendly: it is not required to instantiate the class in the rank 0
process only. This class handles distributed configuration automatically and, if used with
:class:`~ignite.handlers.DiskSaver`, the checkpoint is stored by the rank 0 process.

.. warning::

When running on XLA devices, it should be run in all processes, otherwise application can get stuck on
saving the checkpoint.
When running on XLA devices or using :class:`~torch.distributed.optim.ZeroRedundancyOptimizer`, it
should be run in all processes, otherwise the application can get stuck while saving the checkpoint.

.. code-block:: python

@@ -282,7 +291,7 @@ def __init__(
filename_pattern: Optional[str] = None,
include_self: bool = False,
greater_or_equal: bool = False,
save_on_rank: Optional[int] = 0,
save_on_rank: int = 0,
):

if not isinstance(to_save, collections.Mapping):
@@ -466,6 +475,10 @@ def _setup_checkpoint(self) -> Dict[str, Dict[Any, Any]]:
for k, obj in self.to_save.items():
if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
obj = obj.module
elif HAVE_ZERO and isinstance(obj, ZeroRedundancyOptimizer):
obj.consolidate_state_dict(to=self.save_on_rank)
if self.save_on_rank != idist.get_rank():
continue
checkpoint[k] = obj.state_dict()
return checkpoint

@@ -782,7 +795,7 @@ def __init__(
atomic: bool = True,
create_dir: bool = True,
require_empty: bool = True,
save_on_rank: Optional[int] = 0,
save_on_rank: int = 0,
**kwargs: Any,
):
self.dirname = Path(dirname).expanduser()
@@ -948,7 +961,7 @@ def __init__(
filename_pattern: Optional[str] = None,
include_self: bool = False,
greater_or_equal: bool = False,
save_on_rank: Optional[int] = 0,
save_on_rank: int = 0,
**kwargs: Any,
):

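To make the docstring note and warning above concrete, here is a minimal, hypothetical end-to-end sketch of the behavior this diff enables. The model, training step, event, directory, and two-process launch are illustrative assumptions, not taken from the PR; the key point is that every rank constructs and invokes the `Checkpoint` handler, each rank's ZeRO shard is consolidated onto `save_on_rank`, and only that rank writes the file.

import torch
import torch.nn as nn

import ignite.distributed as idist
from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint
from torch.distributed.optim import ZeroRedundancyOptimizer  # requires torch >= 1.9.0


def training(local_rank, dirname):
    device = idist.device()
    model = nn.Linear(1, 1).to(device)  # stand-in model for illustration
    optimizer = ZeroRedundancyOptimizer(model.parameters(), torch.optim.SGD, lr=0.01)

    def train_step(engine, batch):
        optimizer.zero_grad()
        loss = model(torch.tensor([[1.0]], device=device)).sum()
        loss.backward()
        optimizer.step()
        return loss.item()

    trainer = Engine(train_step)

    # Every rank attaches and calls the handler: _setup_checkpoint() invokes
    # consolidate_state_dict(to=save_on_rank) on the ZeRO optimizer, and only
    # rank 1 (here) writes the checkpoint file to `dirname`.
    handler = Checkpoint({"model": model, "optimizer": optimizer}, dirname, save_on_rank=1)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)

    trainer.run([0.0] * 4, max_epochs=2)


if __name__ == "__main__":
    # Assumes an empty (or nonexistent) checkpoint directory and 2 CPU processes.
    with idist.Parallel(backend="gloo", nproc_per_node=2) as parallel:
        parallel.run(training, "/tmp/zero_checkpoints")

Restoring should work through the existing `Checkpoint.load_objects` API on every rank; `ZeroRedundancyOptimizer.load_state_dict` is documented to extract the local shard from the consolidated state dict.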
35 changes: 34 additions & 1 deletion tests/ignite/handlers/test_checkpoint.py
@@ -1243,9 +1243,37 @@ def _test_checkpoint_load_objects_ddp(device):
Checkpoint.load_objects(to_load, checkpoint)


def _test_checkpoint_with_ZeRO(device, dirname, local_rank):

from torch.distributed.optim import ZeroRedundancyOptimizer

model = DummyModel().to(device)
opt = ZeroRedundancyOptimizer(model.parameters(), torch.optim.SGD, lr=0.01)
mocked_opt = MagicMock(ZeroRedundancyOptimizer, wraps=opt)

# A `step` call is needed so that the optimizer state gets populated.
out = model(torch.Tensor([1.0]))
out.backward()
mocked_opt.step()

to_save = {"model": model, "optim": mocked_opt}
checkpointer = Checkpoint(to_save, dirname, save_on_rank=1)

engine = Engine(lambda e, b: None)
checkpointer(engine)

mocked_opt.consolidate_state_dict.assert_called_once_with(to=1)

if local_rank == 1:

loaded_state_dict = torch.load(dirname / "checkpoint_0.pt", map_location=device)["optim"]
state_dict = opt.state_dict()
assert loaded_state_dict == state_dict


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo, get_rank_zero_dirname):
def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo, dirname, get_rank_zero_dirname, local_rank):

device = idist.device()
rank_zero_dirname = get_rank_zero_dirname()
@@ -1254,6 +1282,11 @@ def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo, get_rank_
_test_checkpoint_with_ddp(device)
_test_checkpoint_load_objects_ddp(device)

from ignite.handlers.checkpoint import HAVE_ZERO

if HAVE_ZERO:
_test_checkpoint_with_ZeRO(device, dirname, local_rank)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")