From 4f0733b8f59ee32c5b6d7f2ddb13c4823c8026ca Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Sun, 20 Feb 2022 00:23:27 +0000 Subject: [PATCH 1/2] Fixed issue when loading a single non-nn.Module object Fixed #2479 --- ignite/handlers/checkpoint.py | 20 +++++++++++--------- tests/ignite/handlers/test_checkpoint.py | 22 ++++++++++++++++++++++ 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py index 709024ab7c26..dbd2bfe6de67 100644 --- a/ignite/handlers/checkpoint.py +++ b/ignite/handlers/checkpoint.py @@ -565,25 +565,27 @@ def load_objects(to_load: Mapping, checkpoint: Union[str, Mapping], **kwargs: An warnings.warn("kwargs contains keys other than strict and these will be ignored") is_state_dict_strict = kwargs.get("strict", True) + + def _load_object(obj, chkpt_obj): + if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)): + obj = obj.module + if isinstance(obj, torch.nn.Module): + obj.load_state_dict(chkpt_obj, strict=is_state_dict_strict) + else: + obj.load_state_dict(chkpt_obj) + if len(to_load) == 1: # single object and checkpoint is directly a state_dict key, obj = list(to_load.items())[0] if key not in checkpoint_obj: - if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)): - obj = obj.module - obj.load_state_dict(checkpoint_obj, strict=is_state_dict_strict) + _load_object(obj, checkpoint_obj) return # multiple objects to load for k, obj in to_load.items(): if k not in checkpoint_obj: raise ValueError(f"Object labeled by '{k}' from `to_load` is not found in the checkpoint") - if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)): - obj = obj.module - if isinstance(obj, torch.nn.Module): - obj.load_state_dict(checkpoint_obj[k], strict=is_state_dict_strict) - else: - obj.load_state_dict(checkpoint_obj[k]) + _load_object(obj, checkpoint_obj[k]) def state_dict(self) -> "OrderedDict[str, List[Tuple[int, str]]]": """Method returns state dict with saved items: list of ``(priority, filename)`` pairs. diff --git a/tests/ignite/handlers/test_checkpoint.py b/tests/ignite/handlers/test_checkpoint.py index 12afa520d01b..8e0858107a52 100644 --- a/tests/ignite/handlers/test_checkpoint.py +++ b/tests/ignite/handlers/test_checkpoint.py @@ -1669,3 +1669,25 @@ def test_get_default_score_fn(): score_fn = Checkpoint.get_default_score_fn("loss", -1) score = score_fn(engine) assert score == -0.123 + + +@pytest.mark.parametrize("obj_to_save", ["optim", "trainer"]) +def test_load_single_object(obj_to_save, dirname): + # Checks https://github.com/pytorch/ignite/issues/2479 + + trainer = Engine(lambda e, b: None) + if obj_to_save == "optim": + t = torch.tensor(0.0) + optim = torch.optim.SGD([t], lr=0.1) + to_save = {"optim": optim} + elif obj_to_save == "trainer": + to_save = {"trainer": trainer} + + c = Checkpoint(to_save, save_handler=dirname) + c(trainer) + + # Update this code once merged https://github.com/pytorch/ignite/pull/2461 + from pathlib import Path + + checkpoint_fp = Path(dirname) / c.last_checkpoint + Checkpoint.load_objects(to_load=to_save, checkpoint=str(checkpoint_fp)) From 14643732b80da22fe987a476cb9180e3dc751b7c Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Sun, 20 Feb 2022 14:32:31 +0000 Subject: [PATCH 2/2] Fixed mypy issue and removed a todo --- ignite/handlers/checkpoint.py | 2 +- tests/ignite/handlers/test_checkpoint.py | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py index d24ea6a318f1..7305a815c9e8 100644 --- a/ignite/handlers/checkpoint.py +++ b/ignite/handlers/checkpoint.py @@ -574,7 +574,7 @@ def load_objects(to_load: Mapping, checkpoint: Union[str, Mapping], **kwargs: An is_state_dict_strict = kwargs.get("strict", True) - def _load_object(obj, chkpt_obj): + def _load_object(obj: Any, chkpt_obj: Any) -> None: if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)): obj = obj.module if isinstance(obj, torch.nn.Module): diff --git a/tests/ignite/handlers/test_checkpoint.py b/tests/ignite/handlers/test_checkpoint.py index 3edfc74deb8d..ad36ba5fbb47 100644 --- a/tests/ignite/handlers/test_checkpoint.py +++ b/tests/ignite/handlers/test_checkpoint.py @@ -1688,8 +1688,5 @@ def test_load_single_object(obj_to_save, dirname): c = Checkpoint(to_save, save_handler=dirname) c(trainer) - # Update this code once merged https://github.com/pytorch/ignite/pull/2461 - from pathlib import Path - - checkpoint_fp = Path(dirname) / c.last_checkpoint + checkpoint_fp = dirname / c.last_checkpoint Checkpoint.load_objects(to_load=to_save, checkpoint=str(checkpoint_fp))