Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ignite/engine/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,8 @@ def _init_run(self) -> None:

def _setup_engine(self) -> None:
if self.state.dataloader is None:
raise RuntimeError(
"Internal error, self.state.dataloader is None. Please, file an issue if you encounter this error."
raise ValueError(
"Deterministic engine does not support the option of data=None. Please, provide data as iterable"
)

self._dataloader_len = self._get_data_length(self.state.dataloader)
Expand Down
47 changes: 31 additions & 16 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,12 +600,12 @@ def switch_dataloader():

def run(
self,
data: Iterable,
data: Optional[Iterable] = None,
max_epochs: Optional[int] = None,
max_iters: Optional[int] = None,
epoch_length: Optional[int] = None,
) -> State:
"""Runs the `process_function` over the passed data.
"""Runs the ``process_function`` over the passed data.

Engine has a state and the following logic is applied in this function:

Expand All @@ -617,7 +617,8 @@ def run(
- If state is defined, engine is NOT "done", then input arguments if provided override defined state.

Args:
data: Collection of batches allowing repeated iteration (e.g., list or `DataLoader`).
data: Collection of batches allowing repeated iteration (e.g., list or `DataLoader`). If not provided, then
``epoch_length`` is required and ``batch`` argument of ``process_function`` will be ``None``.
max_epochs: Max epochs to run for (default: None).
If a new state should be created (first run or run again from ended engine), its default value is 1.
If run is resuming from a state, provided `max_epochs` will be taken into account and should be larger
Expand Down Expand Up @@ -656,7 +657,7 @@ def switch_batch(engine):
trainer.run(train_loader, max_epochs=2)

"""
if not isinstance(data, Iterable):
if data is not None and not isinstance(data, Iterable):
raise TypeError("Argument data should be iterable")

if self.state.max_epochs is not None:
Expand All @@ -680,6 +681,9 @@ def switch_batch(engine):
if self.state.max_epochs is None or self._is_done(self.state):
# Create new state
if epoch_length is None:
if data is None:
raise ValueError("epoch_length should be provided if data is None")

epoch_length = self._get_data_length(data)
if epoch_length is not None and epoch_length < 1:
raise ValueError("Input data has zero size. Please provide non-empty data")
Expand Down Expand Up @@ -707,6 +711,8 @@ def switch_batch(engine):
f"Engine run resuming from iteration {self.state.iteration}, "
f"epoch {self.state.epoch} until {self.state.max_epochs} epochs"
)
if self.state.epoch_length is None and data is None:
raise ValueError("epoch_length should be provided if data is None")

self.state.dataloader = data
return self._internal_run()
Expand All @@ -725,14 +731,20 @@ def _get_data_length(self, data: Iterable) -> Optional[int]:
pass
return None

def _setup_engine(self) -> None:
def _setup_dataloader_iter(self) -> None:
if self.state.dataloader is None:
raise RuntimeError(
"Internal error, self.state.dataloader is None. Please, file an issue if you encounter this error."
)
if self.state.epoch_length is None:
raise RuntimeError(
"Internal error, self.state.epoch_length is None. "
"Please, file an issue if you encounter this error."
)
self._dataloader_iter = _get_none_data_iter(self.state.epoch_length)
else:
self._dataloader_iter = iter(self.state.dataloader)

def _setup_engine(self) -> None:
self._setup_dataloader_iter()
iteration = self.state.iteration
self._dataloader_iter = iter(self.state.dataloader)

# Below we define initial counter value for _run_once_on_dataset to measure a single epoch
if self.state.epoch_length is not None:
Expand Down Expand Up @@ -796,11 +808,8 @@ def _run_once_on_dataset(self) -> float:
try:
if self._dataloader_iter is None:
raise RuntimeError(
"Internal error, self._dataloader_iter is None. Please, file an issue if you encounter this error."
)
if self.state.dataloader is None:
raise RuntimeError(
"Internal error, self.state.dataloader is None. Please, file an issue if you encounter this error."
"Internal error, self._dataloader_iter is None. "
"Please, file an issue if you encounter this error."
)

while True:
Expand Down Expand Up @@ -839,7 +848,7 @@ def _run_once_on_dataset(self) -> float:
break

self._fire_event(Events.DATALOADER_STOP_ITERATION)
self.set_data(self.state.dataloader)
self._setup_dataloader_iter()

should_exit = True

Expand All @@ -853,7 +862,7 @@ def _run_once_on_dataset(self) -> float:
if self.should_terminate or self.should_terminate_single_epoch:
self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter)
self.should_terminate_single_epoch = False
self.set_data(self.state.dataloader)
self._setup_dataloader_iter()
break

if self.state.epoch_length is not None and iter_counter == self.state.epoch_length:
Expand All @@ -868,3 +877,9 @@ def _run_once_on_dataset(self) -> float:
self._handle_exception(e)

return time.time() - start_time


def _get_none_data_iter(size: int) -> Iterator:
# Sized iterator for data as None
for _ in range(size):
yield None
7 changes: 7 additions & 0 deletions tests/ignite/engine/test_deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,3 +894,10 @@ def test__set_rng_states_cuda():
rng_states = [random.getstate(), torch.get_rng_state().cuda(), np.random.get_state()]
_set_rng_states(rng_states)
assert rng_states[1].device.type == "cpu"


def test_engine_no_data_asserts():
    """DeterministicEngine must reject running without data, even with epoch_length."""
    engine = DeterministicEngine(lambda engine, batch: None)

    with pytest.raises(ValueError, match=r"Deterministic engine does not support the option of data=None"):
        engine.run(max_epochs=10, epoch_length=10)
94 changes: 65 additions & 29 deletions tests/ignite/engine/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,33 +44,35 @@ def data():
engine.run(data)


def test_current_epoch_counter_increases_every_epoch():
@pytest.mark.parametrize("data", [None, [1, 2]])
def test_current_epoch_counter_increases_every_epoch(data):
engine = Engine(MagicMock(return_value=1))
max_epochs = 5

counter = EpochCounter()
engine.add_event_handler(Events.EPOCH_STARTED, counter)

state = engine.run([1, 2], max_epochs=max_epochs)
state = engine.run(data, max_epochs=max_epochs, epoch_length=2)
assert state.epoch == max_epochs
counter.current_epoch_count = 1
state = engine.run([1, 2], max_epochs=max_epochs)
state = engine.run(data, max_epochs=max_epochs, epoch_length=2)
assert state.epoch == max_epochs


def test_current_iteration_counter_increases_every_iteration():
batches = [1, 2, 3]
@pytest.mark.parametrize("data", [None, [1, 2, 3]])
def test_current_iteration_counter_increases_every_iteration(data):
engine = Engine(MagicMock(return_value=1))
max_epochs = 5

counter = IterationCounter()
engine.add_event_handler(Events.ITERATION_STARTED, counter)

state = engine.run(batches, max_epochs=max_epochs)
assert state.iteration == max_epochs * len(batches)
epoch_length = 3
state = engine.run(data, max_epochs=max_epochs, epoch_length=epoch_length)
assert state.iteration == max_epochs * epoch_length
counter.current_iteration_count = 1
state = engine.run(batches, max_epochs=max_epochs)
assert state.iteration == max_epochs * len(batches)
state = engine.run(data, max_epochs=max_epochs, epoch_length=epoch_length)
assert state.iteration == max_epochs * epoch_length


def test_stopping_criterion_is_max_epochs():
Expand All @@ -80,7 +82,8 @@ def test_stopping_criterion_is_max_epochs():
assert state.epoch == max_epochs


def test_terminate_at_end_of_epoch_stops_run():
@pytest.mark.parametrize("data", [None, [1, 2]])
def test_terminate_at_end_of_epoch_stops_run(data):
max_epochs = 5
last_epoch_to_run = 3

Expand All @@ -94,16 +97,17 @@ def end_of_epoch_handler(engine):

assert not engine.should_terminate

state = engine.run([1], max_epochs=max_epochs)
state = engine.run(data, max_epochs=max_epochs, epoch_length=2)

assert state.epoch == last_epoch_to_run
assert engine.should_terminate


def test_terminate_at_start_of_epoch_stops_run_after_completing_iteration():
@pytest.mark.parametrize("data", [None, [1, 2, 3]])
def test_terminate_at_start_of_epoch_stops_run_after_completing_iteration(data):
max_epochs = 5
epoch_to_terminate_on = 3
batches_per_epoch = [1, 2, 3]
epoch_length = 3

engine = Engine(MagicMock(return_value=1))

Expand All @@ -115,17 +119,18 @@ def start_of_epoch_handler(engine):

assert not engine.should_terminate

state = engine.run(batches_per_epoch, max_epochs=max_epochs)
state = engine.run(data, max_epochs=max_epochs, epoch_length=epoch_length)

# epoch is not completed so counter is not incremented
assert state.epoch == epoch_to_terminate_on
assert engine.should_terminate
# completes first iteration
assert state.iteration == ((epoch_to_terminate_on - 1) * len(batches_per_epoch)) + 1
assert state.iteration == ((epoch_to_terminate_on - 1) * epoch_length) + 1


def test_terminate_stops_run_mid_epoch():
num_iterations_per_epoch = 10
@pytest.mark.parametrize("data", [None, list(range(10))])
def test_terminate_stops_run_mid_epoch(data):
num_iterations_per_epoch = len(data) if data is not None else 10
iteration_to_stop = num_iterations_per_epoch + 3

engine = Engine(MagicMock(return_value=1))
Expand All @@ -135,14 +140,15 @@ def start_of_iteration_handler(engine):
engine.terminate()

engine.add_event_handler(Events.ITERATION_STARTED, start_of_iteration_handler)
state = engine.run(data=[None] * num_iterations_per_epoch, max_epochs=3)
state = engine.run(data, max_epochs=3, epoch_length=num_iterations_per_epoch)
# completes the iteration but doesn't increment counter (this happens just before a new iteration starts)
assert state.iteration == iteration_to_stop
assert state.epoch == np.ceil(iteration_to_stop / num_iterations_per_epoch) # it starts from 0


def test_terminate_epoch_stops_mid_epoch():
num_iterations_per_epoch = 10
@pytest.mark.parametrize("data", [None, list(range(10))])
def test_terminate_epoch_stops_mid_epoch(data):
num_iterations_per_epoch = len(data) if data is not None else 10
iteration_to_stop = num_iterations_per_epoch + 4

engine = Engine(MagicMock(return_value=1))
Expand All @@ -153,7 +159,7 @@ def start_of_iteration_handler(engine):

max_epochs = 3
engine.add_event_handler(Events.ITERATION_STARTED, start_of_iteration_handler)
state = engine.run(data=[None] * num_iterations_per_epoch, max_epochs=max_epochs)
state = engine.run(data, max_epochs=max_epochs, epoch_length=num_iterations_per_epoch)
# completes the iteration but doesn't increment counter (this happens just before a new iteration starts)
true_value = num_iterations_per_epoch * (max_epochs - 1) + iteration_to_stop % num_iterations_per_epoch
assert state.iteration == true_value
Expand All @@ -170,10 +176,13 @@ def _create_mock_data_loader(epochs, batches_per_epoch):
return data_loader_manager


def test_iteration_events_are_fired():
@pytest.mark.parametrize("data", [None, "mock_data_loader"])
def test_iteration_events_are_fired(data):
max_epochs = 5
num_batches = 3
data = _create_mock_data_loader(max_epochs, num_batches)
num_batches = epoch_length = 3
if isinstance(data, str) and data == "mock_data_loader":
data = _create_mock_data_loader(max_epochs, num_batches)
epoch_length = None

engine = Engine(MagicMock(return_value=1))

Expand All @@ -187,7 +196,7 @@ def test_iteration_events_are_fired():
mock_manager.attach_mock(iteration_started, "iteration_started")
mock_manager.attach_mock(iteration_complete, "iteration_complete")

engine.run(data, max_epochs=max_epochs)
engine.run(data, max_epochs=max_epochs, epoch_length=epoch_length)

assert iteration_started.call_count == num_batches * max_epochs
assert iteration_complete.call_count == num_batches * max_epochs
Expand All @@ -200,7 +209,8 @@ def test_iteration_events_are_fired():
assert mock_manager.mock_calls == expected_calls


def test_last_event_name():
@pytest.mark.parametrize("data", [None, [1, 2]])
def test_last_event_name(data):
engine = Engine(MagicMock(return_value=1))
assert engine.last_event_name is None

Expand All @@ -224,7 +234,8 @@ def _(_engine):
def _(_engine):
assert _engine.last_event_name == Events.EPOCH_COMPLETED

engine.run([0, 1])
epoch_length = 2 if data is None else None
engine.run(data, epoch_length=epoch_length)
assert engine.last_event_name == Events.COMPLETED


Expand Down Expand Up @@ -343,7 +354,6 @@ def test__setup_engine():
engine.state.dataloader = data
engine._setup_engine()
assert len(engine._init_iter) == 1 and engine._init_iter[0] == 10
# assert engine._dataloader_len == len(data)


def test_run_asserts():
Expand Down Expand Up @@ -455,6 +465,7 @@ def _test_run_check_triggered_events():
_test_check_triggered_events(list(range(100)), max_epochs=5, epoch_length=100)
_test_check_triggered_events(list(range(100)), max_epochs=5, epoch_length=50, exp_iter_stops=50 * 5 // 100)
_test_check_triggered_events(list(range(100)), max_epochs=5, epoch_length=150, exp_iter_stops=150 * 5 // 100)
_test_check_triggered_events(None, max_epochs=5, epoch_length=150)


def test_run_check_triggered_events_list():
Expand Down Expand Up @@ -493,7 +504,6 @@ def limited_data_iterator():


def test_run_check_triggered_events_on_iterator():

_test_run_check_triggered_events_on_iterator()


Expand Down Expand Up @@ -986,3 +996,29 @@ def update_fn(engine, batch):
assert len(set(mem_consumption2)) == 2

assert mem_consumption1 == mem_consumption2


def test_engine_no_data_asserts():
    """Engine.run without data must require an explicit epoch_length."""
    engine = Engine(lambda engine, batch: None)

    with pytest.raises(ValueError, match=r"epoch_length should be provided if data is None"):
        engine.run(max_epochs=10)


def test_engine_no_data():
    """Running with data=None feeds ``None`` batches and supports resuming."""

    def check_none_batch(engine, batch):
        # With data=None every batch delivered to process_function is None.
        assert batch is None

    engine = Engine(check_none_batch)
    engine.run(max_epochs=10, epoch_length=10)

    assert engine.state.epoch == 10
    assert engine.state.iteration == 100
    assert engine.state.dataloader is None

    # Resume the finished run up to 20 epochs; epoch_length is taken from state.
    engine.run(max_epochs=20)

    assert engine.state.epoch == 20
    assert engine.state.iteration == 200
    assert engine.state.dataloader is None