Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ignite/engine/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,8 @@ def _init_run(self) -> None:

def _setup_engine(self) -> None:
if self.state.dataloader is None:
raise RuntimeError(
"Internal error, self.state.dataloader is None. Please, file an issue if you encounter this error."
raise ValueError(
"Deterministic engine does not support the option of data=None. Please, provide data as iterable"
)

self._dataloader_len = self._get_data_length(self.state.dataloader)
Expand Down
47 changes: 31 additions & 16 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,12 +600,12 @@ def switch_dataloader():

def run(
self,
data: Iterable,
data: Optional[Iterable] = None,
max_epochs: Optional[int] = None,
max_iters: Optional[int] = None,
epoch_length: Optional[int] = None,
) -> State:
"""Runs the `process_function` over the passed data.
"""Runs the ``process_function`` over the passed data.

Engine has a state and the following logic is applied in this function:

Expand All @@ -617,7 +617,8 @@ def run(
- If state is defined, engine is NOT "done", then input arguments if provided override defined state.

Args:
data: Collection of batches allowing repeated iteration (e.g., list or `DataLoader`).
data: Collection of batches allowing repeated iteration (e.g., list or `DataLoader`). If not provided, then
``epoch_length`` is required and ``batch`` argument of ``process_function`` will be ``None``.
max_epochs: Max epochs to run for (default: None).
If a new state should be created (first run or run again from ended engine), it's default value is 1.
If run is resuming from a state, provided `max_epochs` will be taken into account and should be larger
Expand Down Expand Up @@ -656,7 +657,7 @@ def switch_batch(engine):
trainer.run(train_loader, max_epochs=2)

"""
if not isinstance(data, Iterable):
if data is not None and not isinstance(data, Iterable):
raise TypeError("Argument data should be iterable")

if self.state.max_epochs is not None:
Expand All @@ -680,6 +681,9 @@ def switch_batch(engine):
if self.state.max_epochs is None or self._is_done(self.state):
# Create new state
if epoch_length is None:
if data is None:
raise ValueError("epoch_length should be provided if data is None")

epoch_length = self._get_data_length(data)
if epoch_length is not None and epoch_length < 1:
raise ValueError("Input data has zero size. Please provide non-empty data")
Expand Down Expand Up @@ -707,6 +711,8 @@ def switch_batch(engine):
f"Engine run resuming from iteration {self.state.iteration}, "
f"epoch {self.state.epoch} until {self.state.max_epochs} epochs"
)
if self.state.epoch_length is None and data is None:
raise ValueError("epoch_length should be provided if data is None")

self.state.dataloader = data
return self._internal_run()
Expand All @@ -725,14 +731,20 @@ def _get_data_length(self, data: Iterable) -> Optional[int]:
pass
return None

def _setup_engine(self) -> None:
def _setup_dataloader_iter(self) -> None:
if self.state.dataloader is None:
raise RuntimeError(
"Internal error, self.state.dataloader is None. Please, file an issue if you encounter this error."
)
if self.state.epoch_length is None:
raise RuntimeError(
"Internal error, self.state.epoch_length is None. "
"Please, file an issue if you encounter this error."
)
self._dataloader_iter = _get_none_data_iter(self.state.epoch_length)
else:
self._dataloader_iter = iter(self.state.dataloader)

def _setup_engine(self) -> None:
self._setup_dataloader_iter()
iteration = self.state.iteration
self._dataloader_iter = iter(self.state.dataloader)

# Below we define initial counter value for _run_once_on_dataset to measure a single epoch
if self.state.epoch_length is not None:
Expand Down Expand Up @@ -796,11 +808,8 @@ def _run_once_on_dataset(self) -> float:
try:
if self._dataloader_iter is None:
raise RuntimeError(
"Internal error, self._dataloader_iter is None. Please, file an issue if you encounter this error."
)
if self.state.dataloader is None:
raise RuntimeError(
"Internal error, self.state.dataloader is None. Please, file an issue if you encounter this error."
"Internal error, self._dataloader_iter is None. "
"Please, file an issue if you encounter this error."
)

while True:
Expand Down Expand Up @@ -839,7 +848,7 @@ def _run_once_on_dataset(self) -> float:
break

self._fire_event(Events.DATALOADER_STOP_ITERATION)
self.set_data(self.state.dataloader)
self._setup_dataloader_iter()

should_exit = True

Expand All @@ -853,7 +862,7 @@ def _run_once_on_dataset(self) -> float:
if self.should_terminate or self.should_terminate_single_epoch:
self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter)
self.should_terminate_single_epoch = False
self.set_data(self.state.dataloader)
self._setup_dataloader_iter()
break

if self.state.epoch_length is not None and iter_counter == self.state.epoch_length:
Expand All @@ -868,3 +877,9 @@ def _run_once_on_dataset(self) -> float:
self._handle_exception(e)

return time.time() - start_time


def _get_none_data_iter(size: int) -> Iterator:
# Sized iterator for data as None
for _ in range(size):
yield None
7 changes: 7 additions & 0 deletions tests/ignite/engine/test_deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,3 +894,10 @@ def test__set_rng_states_cuda():
rng_states = [random.getstate(), torch.get_rng_state().cuda(), np.random.get_state()]
_set_rng_states(rng_states)
assert rng_states[1].device.type == "cpu"


def test_engine_no_data_asserts():
    # DeterministicEngine rejects data=None outright, even when epoch_length
    # is supplied (unlike the base Engine, which allows it).
    engine = DeterministicEngine(lambda engine, batch: None)

    with pytest.raises(ValueError, match=r"Deterministic engine does not support the option of data=None"):
        engine.run(max_epochs=10, epoch_length=10)
94 changes: 65 additions & 29 deletions tests/ignite/engine/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,33 +44,35 @@ def data():
engine.run(data)


def test_current_epoch_counter_increases_every_epoch():
@pytest.mark.parametrize("data", [None, [1, 2]])
def test_current_epoch_counter_increases_every_epoch(data):
engine = Engine(MagicMock(return_value=1))
max_epochs = 5

counter = EpochCounter()
engine.add_event_handler(Events.EPOCH_STARTED, counter)

state = engine.run([1, 2], max_epochs=max_epochs)
state = engine.run(data, max_epochs=max_epochs, epoch_length=2)
assert state.epoch == max_epochs
counter.current_epoch_count = 1
state = engine.run([1, 2], max_epochs=max_epochs)
state = engine.run(data, max_epochs=max_epochs, epoch_length=2)
assert state.epoch == max_epochs


def test_current_iteration_counter_increases_every_iteration():
batches = [1, 2, 3]
@pytest.mark.parametrize("data", [None, [1, 2, 3]])
def test_current_iteration_counter_increases_every_iteration(data):
engine = Engine(MagicMock(return_value=1))
max_epochs = 5

counter = IterationCounter()
engine.add_event_handler(Events.ITERATION_STARTED, counter)

state = engine.run(batches, max_epochs=max_epochs)
assert state.iteration == max_epochs * len(batches)
epoch_length = 3
state = engine.run(data, max_epochs=max_epochs, epoch_length=epoch_length)
assert state.iteration == max_epochs * epoch_length
counter.current_iteration_count = 1
state = engine.run(batches, max_epochs=max_epochs)
assert state.iteration == max_epochs * len(batches)
state = engine.run(data, max_epochs=max_epochs, epoch_length=epoch_length)
assert state.iteration == max_epochs * epoch_length


def test_stopping_criterion_is_max_epochs():
Expand All @@ -80,7 +82,8 @@ def test_stopping_criterion_is_max_epochs():
assert state.epoch == max_epochs


def test_terminate_at_end_of_epoch_stops_run():
@pytest.mark.parametrize("data", [None, [1, 2]])
def test_terminate_at_end_of_epoch_stops_run(data):
max_epochs = 5
last_epoch_to_run = 3

Expand All @@ -94,16 +97,17 @@ def end_of_epoch_handler(engine):

assert not engine.should_terminate

state = engine.run([1], max_epochs=max_epochs)
state = engine.run(data, max_epochs=max_epochs, epoch_length=2)

assert state.epoch == last_epoch_to_run
assert engine.should_terminate


def test_terminate_at_start_of_epoch_stops_run_after_completing_iteration():
@pytest.mark.parametrize("data", [None, [1, 2, 3]])
def test_terminate_at_start_of_epoch_stops_run_after_completing_iteration(data):
max_epochs = 5
epoch_to_terminate_on = 3
batches_per_epoch = [1, 2, 3]
epoch_length = 3

engine = Engine(MagicMock(return_value=1))

Expand All @@ -115,17 +119,18 @@ def start_of_epoch_handler(engine):

assert not engine.should_terminate

state = engine.run(batches_per_epoch, max_epochs=max_epochs)
state = engine.run(data, max_epochs=max_epochs, epoch_length=epoch_length)

# epoch is not completed so counter is not incremented
assert state.epoch == epoch_to_terminate_on
assert engine.should_terminate
# completes first iteration
assert state.iteration == ((epoch_to_terminate_on - 1) * len(batches_per_epoch)) + 1
assert state.iteration == ((epoch_to_terminate_on - 1) * epoch_length) + 1


def test_terminate_stops_run_mid_epoch():
num_iterations_per_epoch = 10
@pytest.mark.parametrize("data", [None, list(range(10))])
def test_terminate_stops_run_mid_epoch(data):
num_iterations_per_epoch = len(data) if data is not None else 10
iteration_to_stop = num_iterations_per_epoch + 3

engine = Engine(MagicMock(return_value=1))
Expand All @@ -135,14 +140,15 @@ def start_of_iteration_handler(engine):
engine.terminate()

engine.add_event_handler(Events.ITERATION_STARTED, start_of_iteration_handler)
state = engine.run(data=[None] * num_iterations_per_epoch, max_epochs=3)
state = engine.run(data, max_epochs=3, epoch_length=num_iterations_per_epoch)
# completes the iteration but doesn't increment counter (this happens just before a new iteration starts)
assert state.iteration == iteration_to_stop
assert state.epoch == np.ceil(iteration_to_stop / num_iterations_per_epoch) # it starts from 0


def test_terminate_epoch_stops_mid_epoch():
num_iterations_per_epoch = 10
@pytest.mark.parametrize("data", [None, list(range(10))])
def test_terminate_epoch_stops_mid_epoch(data):
num_iterations_per_epoch = len(data) if data is not None else 10
iteration_to_stop = num_iterations_per_epoch + 4

engine = Engine(MagicMock(return_value=1))
Expand All @@ -153,7 +159,7 @@ def start_of_iteration_handler(engine):

max_epochs = 3
engine.add_event_handler(Events.ITERATION_STARTED, start_of_iteration_handler)
state = engine.run(data=[None] * num_iterations_per_epoch, max_epochs=max_epochs)
state = engine.run(data, max_epochs=max_epochs, epoch_length=num_iterations_per_epoch)
# completes the iteration but doesn't increment counter (this happens just before a new iteration starts)
true_value = num_iterations_per_epoch * (max_epochs - 1) + iteration_to_stop % num_iterations_per_epoch
assert state.iteration == true_value
Expand All @@ -170,10 +176,13 @@ def _create_mock_data_loader(epochs, batches_per_epoch):
return data_loader_manager


def test_iteration_events_are_fired():
@pytest.mark.parametrize("data", [None, "mock_data_loader"])
def test_iteration_events_are_fired(data):
max_epochs = 5
num_batches = 3
data = _create_mock_data_loader(max_epochs, num_batches)
num_batches = epoch_length = 3
if isinstance(data, str) and data == "mock_data_loader":
data = _create_mock_data_loader(max_epochs, num_batches)
epoch_length = None

engine = Engine(MagicMock(return_value=1))

Expand All @@ -187,7 +196,7 @@ def test_iteration_events_are_fired():
mock_manager.attach_mock(iteration_started, "iteration_started")
mock_manager.attach_mock(iteration_complete, "iteration_complete")

engine.run(data, max_epochs=max_epochs)
engine.run(data, max_epochs=max_epochs, epoch_length=epoch_length)

assert iteration_started.call_count == num_batches * max_epochs
assert iteration_complete.call_count == num_batches * max_epochs
Expand All @@ -200,7 +209,8 @@ def test_iteration_events_are_fired():
assert mock_manager.mock_calls == expected_calls


def test_last_event_name():
@pytest.mark.parametrize("data", [None, [1, 2]])
def test_last_event_name(data):
engine = Engine(MagicMock(return_value=1))
assert engine.last_event_name is None

Expand All @@ -224,7 +234,8 @@ def _(_engine):
def _(_engine):
assert _engine.last_event_name == Events.EPOCH_COMPLETED

engine.run([0, 1])
epoch_length = 2 if data is None else None
engine.run(data, epoch_length=epoch_length)
assert engine.last_event_name == Events.COMPLETED


Expand Down Expand Up @@ -343,7 +354,6 @@ def test__setup_engine():
engine.state.dataloader = data
engine._setup_engine()
assert len(engine._init_iter) == 1 and engine._init_iter[0] == 10
# assert engine._dataloader_len == len(data)


def test_run_asserts():
Expand Down Expand Up @@ -455,6 +465,7 @@ def _test_run_check_triggered_events():
_test_check_triggered_events(list(range(100)), max_epochs=5, epoch_length=100)
_test_check_triggered_events(list(range(100)), max_epochs=5, epoch_length=50, exp_iter_stops=50 * 5 // 100)
_test_check_triggered_events(list(range(100)), max_epochs=5, epoch_length=150, exp_iter_stops=150 * 5 // 100)
_test_check_triggered_events(None, max_epochs=5, epoch_length=150)


def test_run_check_triggered_events_list():
Expand Down Expand Up @@ -493,7 +504,6 @@ def limited_data_iterator():


def test_run_check_triggered_events_on_iterator():

_test_run_check_triggered_events_on_iterator()


Expand Down Expand Up @@ -986,3 +996,29 @@ def update_fn(engine, batch):
assert len(set(mem_consumption2)) == 2

assert mem_consumption1 == mem_consumption2


def test_engine_no_data_asserts():
    # Running without data requires an explicit epoch_length, since the
    # engine has nothing to measure an epoch against.
    engine = Engine(lambda engine, batch: None)

    with pytest.raises(ValueError, match=r"epoch_length should be provided if data is None"):
        engine.run(max_epochs=10)


def test_engine_no_data():
    # With data=None the engine runs epoch_length iterations per epoch and
    # hands batch=None to the process function on every iteration.
    def process(engine, batch):
        assert batch is None

    engine = Engine(process)
    engine.run(max_epochs=10, epoch_length=10)

    assert engine.state.epoch == 10
    assert engine.state.iteration == 10 * 10
    assert engine.state.dataloader is None

    # Resume the same engine: epoch_length is retained from the first run,
    # so only the new max_epochs needs to be provided.
    engine.run(max_epochs=20)

    assert engine.state.epoch == 20
    assert engine.state.iteration == 20 * 10
    assert engine.state.dataloader is None