diff --git a/ignite/engine/deterministic.py b/ignite/engine/deterministic.py index cfbf4c04f92b..b82e68fef0d3 100644 --- a/ignite/engine/deterministic.py +++ b/ignite/engine/deterministic.py @@ -179,6 +179,8 @@ class DeterministicEngine(Engine): def __init__(self, process_function: Callable[[Engine, Any], Any]): super(DeterministicEngine, self).__init__(process_function) self.state_dict_user_keys.append("rng_states") + if not hasattr(self.state, "rng_states"): + setattr(self.state, "rng_states", None) self.add_event_handler(Events.STARTED, self._init_run) self.add_event_handler(Events.DATALOADER_STOP_ITERATION | Events.TERMINATE_SINGLE_EPOCH, self._setup_seed) @@ -189,9 +191,6 @@ def state_dict(self) -> OrderedDict: def _init_run(self) -> None: self.state.seed = int(torch.randint(0, int(1e9), (1,)).item()) - if not hasattr(self.state, "rng_states"): - setattr(self.state, "rng_states", None) - if torch.cuda.is_available(): torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False diff --git a/tests/ignite/engine/test_deterministic.py b/tests/ignite/engine/test_deterministic.py index ac618ba969b9..1d57d548eddb 100644 --- a/tests/ignite/engine/test_deterministic.py +++ b/tests/ignite/engine/test_deterministic.py @@ -1,6 +1,7 @@ import os import random import sys +from collections.abc import Mapping from unittest.mock import patch import numpy as np @@ -893,3 +894,13 @@ def test_engine_no_data_asserts(): with pytest.raises(ValueError, match=r"Deterministic engine does not support the option of data=None"): trainer.run(max_epochs=10, epoch_length=10) + + +def test_state_dict(): + engine = DeterministicEngine(lambda e, b: 1) + sd = engine.state_dict() + assert isinstance(sd, Mapping) and len(sd) == 4 + assert "iteration" in sd and sd["iteration"] == 0 + assert "max_epochs" in sd and sd["max_epochs"] is None + assert "epoch_length" in sd and sd["epoch_length"] is None + assert "rng_states" in sd and sd["rng_states"] is not None