Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions ignite/handlers/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,25 +573,27 @@ def load_objects(to_load: Mapping, checkpoint: Union[str, Mapping], **kwargs: An
warnings.warn("kwargs contains keys other than strict and these will be ignored")

is_state_dict_strict = kwargs.get("strict", True)

def _load_object(obj: Any, chkpt_obj: Any) -> None:
    """Restore a single object's state from its checkpointed state_dict.

    Unwraps (Distributed)DataParallel wrappers so the state dict is loaded
    into the underlying module, and forwards ``strict`` only for
    ``nn.Module`` targets (optimizers/schedulers/engines do not accept it).
    """
    # Load into the wrapped module, not the parallel wrapper itself.
    if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        obj = obj.module
    if isinstance(obj, torch.nn.Module):
        # `strict` comes from the enclosing load_objects(**kwargs) scope
        # (closure over is_state_dict_strict).
        obj.load_state_dict(chkpt_obj, strict=is_state_dict_strict)
    else:
        # Non-module objects (optimizers, lr schedulers, Engine, ...) have a
        # load_state_dict without a `strict` parameter.
        obj.load_state_dict(chkpt_obj)

if len(to_load) == 1:
# single object and checkpoint is directly a state_dict
key, obj = list(to_load.items())[0]
if key not in checkpoint_obj:
if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
obj = obj.module
obj.load_state_dict(checkpoint_obj, strict=is_state_dict_strict)
_load_object(obj, checkpoint_obj)
return

# multiple objects to load
for k, obj in to_load.items():
if k not in checkpoint_obj:
raise ValueError(f"Object labeled by '{k}' from `to_load` is not found in the checkpoint")
if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
obj = obj.module
if isinstance(obj, torch.nn.Module):
obj.load_state_dict(checkpoint_obj[k], strict=is_state_dict_strict)
else:
obj.load_state_dict(checkpoint_obj[k])
_load_object(obj, checkpoint_obj[k])

def state_dict(self) -> "OrderedDict[str, List[Tuple[int, str]]]":
"""Method returns state dict with saved items: list of ``(priority, filename)`` pairs.
Expand Down
19 changes: 19 additions & 0 deletions tests/ignite/handlers/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1671,3 +1671,22 @@ def test_get_default_score_fn():
score_fn = Checkpoint.get_default_score_fn("loss", -1)
score = score_fn(engine)
assert score == -0.123


@pytest.mark.parametrize("obj_to_save", ["optim", "trainer"])
def test_load_single_object(obj_to_save, dirname):
# Checks https://github.com/pytorch/ignite/issues/2479

trainer = Engine(lambda e, b: None)
if obj_to_save == "optim":
t = torch.tensor(0.0)
optim = torch.optim.SGD([t], lr=0.1)
to_save = {"optim": optim}
elif obj_to_save == "trainer":
to_save = {"trainer": trainer}

c = Checkpoint(to_save, save_handler=dirname)
c(trainer)

checkpoint_fp = dirname / c.last_checkpoint
Checkpoint.load_objects(to_load=to_save, checkpoint=str(checkpoint_fp))