Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions ignite/handlers/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def __init__(
filename_prefix: str = "",
score_function: Optional[Callable] = None,
score_name: Optional[str] = None,
n_saved: Optional[int] = 1,
n_saved: Union[int, None] = 1,
global_step_transform: Optional[Callable] = None,
filename_pattern: Optional[str] = None,
include_self: bool = False,
Expand Down Expand Up @@ -358,7 +358,11 @@ def reset(self) -> None:
def last_checkpoint(self) -> Optional[Union[str, Path]]:
if len(self._saved) < 1:
return None
return self._saved[-1].filename

if not isinstance(self.save_handler, DiskSaver):
return self._saved[-1].filename

return self.save_handler.dirname / self._saved[-1].filename

def _check_lt_n_saved(self, or_equal: bool = False) -> bool:
if self.n_saved is None:
Expand Down Expand Up @@ -798,13 +802,22 @@ class ModelCheckpoint(Checkpoint):
Input of the function is `(engine, event_name)`. Output of function should be an integer.
Default is None, global_step based on attached engine. If provided, uses function output as global_step.
To setup global step from another engine, please use :meth:`~ignite.handlers.global_step_from_engine`.
filename_pattern: If ``filename_pattern`` is provided, this pattern will be used to render
checkpoint filenames. If the pattern is not defined, the default pattern would be used.
See :class:`~ignite.handlers.checkpoint.Checkpoint` for details.
include_self: Whether to include the `state_dict` of this object in the checkpoint. If `True`, then
there must not be another object in ``to_save`` with key ``checkpointer``.
greater_or_equal: if `True`, the latest equally scored model is stored. Otherwise, the first model.
Default, `False`.
kwargs: Accepted keyword arguments for `torch.save` or `xm.save` in `DiskSaver`.

.. versionchanged:: 0.4.2
Accept ``kwargs`` for `torch.save` or `xm.save`

.. versionchanged:: 0.5.0
Accept ``filename_pattern`` and ``greater_or_equal`` for parity
with :class:`~ignite.handlers.checkpoint.Checkpoint`

Examples:
.. code-block:: python

Expand All @@ -826,15 +839,17 @@ class ModelCheckpoint(Checkpoint):
def __init__(
self,
dirname: Union[str, Path],
filename_prefix: str,
filename_prefix: str = "",
score_function: Optional[Callable] = None,
score_name: Optional[str] = None,
n_saved: Union[int, None] = 1,
atomic: bool = True,
require_empty: bool = True,
create_dir: bool = True,
global_step_transform: Optional[Callable] = None,
filename_pattern: Optional[str] = None,
include_self: bool = False,
greater_or_equal: bool = False,
**kwargs: Any,
):

Expand All @@ -848,7 +863,9 @@ def __init__(
score_name=score_name,
n_saved=n_saved,
global_step_transform=global_step_transform,
filename_pattern=filename_pattern,
include_self=include_self,
greater_or_equal=greater_or_equal,
)

@property
Expand All @@ -857,9 +874,7 @@ def last_checkpoint(self) -> Optional[Union[str, Path]]:
return None

if not isinstance(self.save_handler, DiskSaver):
raise RuntimeError(
f"Unable to save checkpoint, save_handler should be DiskSaver, got {type(self.save_handler)}."
)
raise RuntimeError(f"Internal error, save_handler should be DiskSaver, but has {type(self.save_handler)}.")

return self.save_handler.dirname / self._saved[-1].filename

Expand Down
177 changes: 133 additions & 44 deletions tests/ignite/handlers/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,7 @@ def test_model_checkpoint_invalid_save_handler(dirname):
h(Engine(lambda x, y: None), to_save)

with pytest.raises(
RuntimeError, match=rf"Unable to save checkpoint, save_handler should be DiskSaver, got {type(h.save_handler)}."
RuntimeError, match=rf"Internal error, save_handler should be DiskSaver, but has {type(h.save_handler)}."
):
h.last_checkpoint

Expand Down Expand Up @@ -1063,7 +1063,7 @@ def _check_state_dict(original, loaded):
# If Checkpoint's state was restored correctly, it should continue to respect n_saved
# and delete old checkpoints, and have the correct last_checkpoint.
assert os.listdir(dirname) == ["checkpoint_4.pt"]
assert checkpointer2.last_checkpoint == "checkpoint_4.pt"
assert checkpointer2.last_checkpoint == dirname / "checkpoint_4.pt"


def test_save_model_optimizer_lr_scheduler_with_validation(dirname):
Expand Down Expand Up @@ -1336,61 +1336,109 @@ def test_distrib_xla_nprocs(xmp_executor, dirname):
xmp_executor(_test_tpu_saves_to_cpu_nprocs, args=(dirname,), nprocs=n)


def test_checkpoint_filename_pattern():
def _test(
def _test_checkpoint_filename_pattern_helper(
to_save,
filename_prefix="",
score_function=None,
score_name=None,
global_step_transform=None,
filename_pattern=None,
dirname=None,
):
save_handler = MagicMock(spec=BaseSaveHandler)

checkpointer = Checkpoint(
to_save,
filename_prefix="",
score_function=None,
score_name=None,
global_step_transform=None,
filename_pattern=None,
):
save_handler = MagicMock(spec=BaseSaveHandler)
save_handler=save_handler,
filename_prefix=filename_prefix,
score_function=score_function,
score_name=score_name,
global_step_transform=global_step_transform,
filename_pattern=filename_pattern,
)

checkpointer = Checkpoint(
to_save,
save_handler=save_handler,
filename_prefix=filename_prefix,
score_function=score_function,
score_name=score_name,
global_step_transform=global_step_transform,
filename_pattern=filename_pattern,
)
trainer = Engine(lambda e, b: None)
trainer.state = State(epoch=12, iteration=203, score=0.9999)

trainer = Engine(lambda e, b: None)
trainer.state = State(epoch=12, iteration=203, score=0.9999)
checkpointer(trainer)
return checkpointer.last_checkpoint


def _test_model_checkpoint_filename_pattern_helper(
to_save,
filename_prefix="",
score_function=None,
score_name=None,
global_step_transform=None,
filename_pattern=None,
dirname=None,
):
checkpointer = ModelCheckpoint(
dirname=dirname,
filename_prefix=filename_prefix,
score_function=score_function,
score_name=score_name,
global_step_transform=global_step_transform,
filename_pattern=filename_pattern,
require_empty=False,
)

checkpointer(trainer)
return checkpointer.last_checkpoint
trainer = Engine(lambda e, b: None)
trainer.state = State(epoch=12, iteration=203, score=0.9999)

checkpointer(trainer, to_save)
return Path(checkpointer.last_checkpoint).name


@pytest.mark.parametrize("test_class", ["checkpoint", "model_checkpoint"])
def test_checkpoint_filename_pattern(test_class, dirname):

if test_class == "checkpoint":
_test = _test_checkpoint_filename_pattern_helper
elif test_class == "model_checkpoint":
_test = _test_model_checkpoint_filename_pattern_helper

model = DummyModel()
to_save = {"model": model}

assert _test(to_save) == "model_203.pt"
assert _test(to_save, "best") == "best_model_203.pt"
assert _test(to_save, score_function=lambda e: e.state.score) == "model_0.9999.pt"
assert _test(to_save, dirname=dirname) == "model_203.pt"
assert _test(to_save, "best", dirname=dirname) == "best_model_203.pt"
assert _test(to_save, score_function=lambda e: e.state.score, dirname=dirname) == "model_0.9999.pt"

res = _test(to_save, score_function=lambda e: e.state.score, global_step_transform=lambda e, _: e.state.epoch)
res = _test(
to_save,
score_function=lambda e: e.state.score,
global_step_transform=lambda e, _: e.state.epoch,
dirname=dirname,
)
assert res == "model_12_0.9999.pt"

assert _test(to_save, score_function=lambda e: e.state.score, score_name="acc") == "model_acc=0.9999.pt"
assert (
_test(to_save, score_function=lambda e: e.state.score, score_name="acc", dirname=dirname)
== "model_acc=0.9999.pt"
)

res = _test(
to_save,
score_function=lambda e: e.state.score,
score_name="acc",
global_step_transform=lambda e, _: e.state.epoch,
dirname=dirname,
)
assert res == "model_12_acc=0.9999.pt"

assert _test(to_save, "best", score_function=lambda e: e.state.score) == "best_model_0.9999.pt"
assert _test(to_save, "best", score_function=lambda e: e.state.score, dirname=dirname) == "best_model_0.9999.pt"

res = _test(
to_save, "best", score_function=lambda e: e.state.score, global_step_transform=lambda e, _: e.state.epoch
to_save,
"best",
score_function=lambda e: e.state.score,
global_step_transform=lambda e, _: e.state.epoch,
dirname=dirname,
)
assert res == "best_model_12_0.9999.pt"

res = _test(to_save, "best", score_function=lambda e: e.state.score, score_name="acc")
res = _test(to_save, "best", score_function=lambda e: e.state.score, score_name="acc", dirname=dirname)
assert res == "best_model_acc=0.9999.pt"

res = _test(
Expand All @@ -1399,29 +1447,36 @@ def _test(
score_function=lambda e: e.state.score,
score_name="acc",
global_step_transform=lambda e, _: e.state.epoch,
dirname=dirname,
)
assert res == "best_model_12_acc=0.9999.pt"

pattern = "{name}.{ext}"
assert _test(to_save, filename_pattern=pattern) == "model.pt"
assert _test(to_save, filename_pattern=pattern, dirname=dirname) == "model.pt"

pattern = "chk-{name}--{global_step}.{ext}"
assert _test(to_save, to_save, filename_pattern=pattern) == "chk-model--203.pt"
assert _test(to_save, to_save, filename_pattern=pattern, dirname=dirname) == "chk-model--203.pt"
pattern = "chk-{filename_prefix}--{name}--{global_step}.{ext}"
assert _test(to_save, "best", filename_pattern=pattern) == "chk-best--model--203.pt"
assert _test(to_save, "best", filename_pattern=pattern, dirname=dirname) == "chk-best--model--203.pt"
pattern = "chk-{name}--{score}.{ext}"
assert _test(to_save, score_function=lambda e: e.state.score, filename_pattern=pattern) == "chk-model--0.9999.pt"
assert (
_test(to_save, score_function=lambda e: e.state.score, filename_pattern=pattern, dirname=dirname)
== "chk-model--0.9999.pt"
)
pattern = "{global_step}-{name}-{score}.chk.{ext}"
res = _test(
to_save,
score_function=lambda e: e.state.score,
global_step_transform=lambda e, _: e.state.epoch,
filename_pattern=pattern,
dirname=dirname,
)
assert res == "12-model-0.9999.chk.pt"

pattern = "chk-{name}--{score_name}--{score}.{ext}"
res = _test(to_save, score_function=lambda e: e.state.score, score_name="acc", filename_pattern=pattern)
res = _test(
to_save, score_function=lambda e: e.state.score, score_name="acc", filename_pattern=pattern, dirname=dirname
)
assert res == "chk-model--acc--0.9999.pt"

pattern = "chk-{name}-{global_step}-{score_name}-{score}.{ext}"
Expand All @@ -1431,11 +1486,12 @@ def _test(
score_name="acc",
global_step_transform=lambda e, _: e.state.epoch,
filename_pattern=pattern,
dirname=dirname,
)
assert res == "chk-model-12-acc-0.9999.pt"

pattern = "{filename_prefix}-{name}-{score}.chk"
res = _test(to_save, "best", score_function=lambda e: e.state.score, filename_pattern=pattern)
res = _test(to_save, "best", score_function=lambda e: e.state.score, filename_pattern=pattern, dirname=dirname)
assert res == "best-model-0.9999.chk"

pattern = "resnet-{filename_prefix}-{name}-{global_step}-{score}.chk"
Expand All @@ -1445,11 +1501,19 @@ def _test(
score_function=lambda e: e.state.score,
global_step_transform=lambda e, _: e.state.epoch,
filename_pattern=pattern,
dirname=dirname,
)
assert res == "resnet-best-model-12-0.9999.chk"

pattern = "{filename_prefix}-{name}-{score_name}-{score}.chk"
res = _test(to_save, "best", score_function=lambda e: e.state.score, score_name="acc", filename_pattern=pattern)
res = _test(
to_save,
"best",
score_function=lambda e: e.state.score,
score_name="acc",
filename_pattern=pattern,
dirname=dirname,
)
assert res == "best-model-acc-0.9999.chk"

pattern = "{global_step}-{filename_prefix}-{name}-{score_name}-{score}"
Expand All @@ -1460,27 +1524,29 @@ def _test(
score_name="acc",
global_step_transform=lambda e, _: e.state.epoch,
filename_pattern=pattern,
dirname=dirname,
)
assert res == "12-best-model-acc-0.9999"

pattern = "SAVE:{name}-{score_name}-{score}.pth"
pattern = "SAVE-{name}-{score_name}-{score}.pth"
res = _test(
to_save,
"best",
score_function=lambda e: e.state.score,
score_name="acc",
global_step_transform=lambda e, _: e.state.epoch,
filename_pattern=pattern,
dirname=dirname,
)

assert res == "SAVE:model-acc-0.9999.pth"
assert res == "SAVE-model-acc-0.9999.pth"

pattern = "{global_step}-chk-{filename_prefix}-{name}-{score_name}-{score}.{ext}"
assert _test(to_save, filename_pattern=pattern) == "203-chk--model-None-None.pt"
assert _test(to_save, filename_pattern=pattern, dirname=dirname) == "203-chk--model-None-None.pt"

with pytest.raises(KeyError, match=r"random_key"):
pattern = "SAVE:{random_key}.{ext}"
_test(to_save, filename_pattern=pattern)
pattern = "SAVE-{random_key}.{ext}"
_test(to_save, filename_pattern=pattern, dirname=dirname)


def test_setup_filename_pattern():
Expand Down Expand Up @@ -1655,6 +1721,29 @@ def __call__(self, c, f, m):
assert handler.counter == 4


def test_greater_or_equal_model_checkpoint(dirname):
    """Check ``ModelCheckpoint`` with ``greater_or_equal=True`` on tied scores.

    The score function yields 1, 2, 2, 2 over four consecutive calls. After the
    first call ``last_checkpoint`` must be named ``model_1.pt``; after each of
    the remaining calls it must be named ``model_2.pt``.
    """
    score_stream = iter([1, 2, 2, 2])

    def _next_score(_):
        # The engine argument is ignored; scores come from the fixed sequence.
        return next(score_stream)

    handler = ModelCheckpoint(
        dirname,
        score_function=_next_score,
        n_saved=2,
        greater_or_equal=True,
    )
    engine = Engine(lambda e, b: None)
    objects = {"model": DummyModel()}

    # One expected file name per handler call, in call order.
    expected_names = ("model_1.pt", "model_2.pt", "model_2.pt", "model_2.pt")
    for expected in expected_names:
        handler(engine, objects)
        assert Path(handler.last_checkpoint).name == expected


def test_get_default_score_fn():

with pytest.raises(ValueError, match=r"Argument score_sign should be 1 or -1"):
Expand Down