|
11 | 11 |
|
12 | 12 | import torch |
13 | 13 | import torch.nn as nn |
| 14 | +from packaging.version import Version |
| 15 | + |
| 16 | +if Version(torch.__version__) >= Version("1.9.0"): |
| 17 | + from torch.distributed.optim import ZeroRedundancyOptimizer |
| 18 | + |
| 19 | + HAVE_ZERO = True |
| 20 | +else: |
| 21 | + HAVE_ZERO = False |
14 | 22 |
|
15 | 23 | import ignite.distributed as idist |
16 | 24 | from ignite.base import Serializable |
@@ -166,13 +174,14 @@ class Checkpoint(Serializable): |
166 | 174 | > checkpoint_12345.pt |
167 | 175 |
|
168 | 176 | Note: |
169 | | - This class is distributed configuration-friendly: it is not required to instantiate the class in rank 0 |
170 | | - process only. |
| 177 | +        This class is distributed configuration-friendly: it is not required to instantiate the class in the rank 0
| 178 | +        process only. This class automatically supports distributed configuration and, if used with
| 179 | +        :class:`~ignite.handlers.DiskSaver`, the checkpoint is stored by the rank 0 process.
171 | 180 |
|
172 | 181 | .. warning:: |
173 | 182 |
|
174 | | - When running on XLA devices, it should be run in all processes, otherwise application can get stuck on |
175 | | - saving the checkpoint. |
| 183 | +        When running on XLA devices or using :class:`~torch.distributed.optim.ZeroRedundancyOptimizer`, it
| 184 | +        should be run in all processes, otherwise the application can get stuck while saving the checkpoint.
176 | 185 |
|
177 | 186 | .. code-block:: python |
178 | 187 |
|
@@ -282,7 +291,7 @@ def __init__( |
282 | 291 | filename_pattern: Optional[str] = None, |
283 | 292 | include_self: bool = False, |
284 | 293 | greater_or_equal: bool = False, |
285 | | - save_on_rank: Optional[int] = 0, |
| 294 | + save_on_rank: int = 0, |
286 | 295 | ): |
287 | 296 |
|
288 | 297 | if not isinstance(to_save, collections.Mapping): |
@@ -466,6 +475,10 @@ def _setup_checkpoint(self) -> Dict[str, Dict[Any, Any]]: |
466 | 475 | for k, obj in self.to_save.items(): |
467 | 476 | if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)): |
468 | 477 | obj = obj.module |
| 478 | + elif HAVE_ZERO and isinstance(obj, ZeroRedundancyOptimizer): |
| 479 | + obj.consolidate_state_dict(to=self.save_on_rank) |
| 480 | + if self.save_on_rank != idist.get_rank(): |
| 481 | + continue |
469 | 482 | checkpoint[k] = obj.state_dict() |
470 | 483 | return checkpoint |
471 | 484 |
|
@@ -782,7 +795,7 @@ def __init__( |
782 | 795 | atomic: bool = True, |
783 | 796 | create_dir: bool = True, |
784 | 797 | require_empty: bool = True, |
785 | | - save_on_rank: Optional[int] = 0, |
| 798 | + save_on_rank: int = 0, |
786 | 799 | **kwargs: Any, |
787 | 800 | ): |
788 | 801 | self.dirname = Path(dirname).expanduser() |
@@ -948,7 +961,7 @@ def __init__( |
948 | 961 | filename_pattern: Optional[str] = None, |
949 | 962 | include_self: bool = False, |
950 | 963 | greater_or_equal: bool = False, |
951 | | - save_on_rank: Optional[int] = 0, |
| 964 | + save_on_rank: int = 0, |
952 | 965 | **kwargs: Any, |
953 | 966 | ): |
954 | 967 |
|
|
0 commit comments