Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions docs/source/concepts.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,33 @@ Attaching an event handler is simple using method :meth:`~ignite.engine.Engine.a

trainer.add_event_handler(Events.COMPLETED, on_training_ended, mydata)

Event handlers can be detached via :meth:`~ignite.engine.Engine.remove_event_handler` or via the :class:`~ignite.engine.RemovableEventHandler`
reference returned by :meth:`~ignite.engine.Engine.add_event_handler`. This can be used to reuse a configured engine for multiple loops:

.. code-block:: python

model = ...
train_loader, validation_loader, test_loader = ...

trainer = create_supervised_trainer(model, optimizer, loss)
evaluator = create_supervised_evaluator(model, metrics={'acc': Accuracy()})

def log_metrics(engine, title):
print("Epoch: {} - {} accuracy: {:.2f}"
.format(trainer.state.epoch, title, engine.state.metrics['acc']))

@trainer.on(Events.EPOCH_COMPLETED)
def evaluate(trainer):
with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "train"):
evaluator.run(train_loader)

with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "validation"):
evaluator.run(validation_loader)

with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "test"):
evaluator.run(test_loader)

trainer.run(train_loader, max_epochs=100)

.. Note ::

Expand Down
4 changes: 4 additions & 0 deletions docs/source/engine.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,7 @@ ignite.engine
:undoc-members:

.. autoclass:: State

.. autoclass:: RemovableEventHandler
:members:
:undoc-members:
58 changes: 58 additions & 0 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import time
from collections import defaultdict
from enum import Enum
import weakref

from ignite._utils import _to_hours_mins_secs

Expand Down Expand Up @@ -48,6 +49,58 @@ def get_event_attrib_value(self, event_name):
return getattr(self, State.event_to_attr[event_name])


class RemovableEventHandle(object):
"""A weakref handle to remove a registered event.

A handle that may be used to remove a registered event handler via the
remove method, with-statement, or context manager protocol. Returned from
:meth:`~ignite.engine.Engine.add_event_handler`.


Args:
event_name: Registered event name.
handler: Registered event handler, stored as weakref.
engine: Target engine, stored as weakref.

Example usage:

.. code-block:: python

engine = Engine()

def print_epoch(engine):
print("Epoch: {}".format(engine.state.epoch))

with engine.add_event_handler(Events.EPOCH_COMPLETED, print_epoch):
# print_epoch handler registered for a single run
engine.run(data)

# print_epoch handler is now unregistered
"""

def __init__(self, event_name, handler, engine):
self.event_name = event_name
self.handler = weakref.ref(handler)
self.engine = weakref.ref(engine)

def remove(self):
"""Remove handler from engine."""
handler = self.handler()
engine = self.engine()

if handler is None or engine is None:
return

if engine.has_event_handler(handler, self.event_name):
engine.remove_event_handler(handler, self.event_name)

def __enter__(self):
return self

def __exit__(self, type, value, tb):
self.remove()


class Engine(object):
"""Runs a given process_function over each batch of a dataset, emitting events as it goes.

Expand Down Expand Up @@ -164,6 +217,9 @@ def add_event_handler(self, event_name, handler, *args, **kwargs):
Note that other arguments can be passed to the handler in addition to the `*args` and `**kwargs`
passed here, for example during :attr:`~ignite.engine.Events.EXCEPTION_RAISED`.

Returns:
:class:`~ignite.engine.RemovableEventHandler`, which can be used to remove the handler.

Example usage:

.. code-block:: python
Expand All @@ -186,6 +242,8 @@ def print_epoch(engine):
self._event_handlers[event_name].append((handler, args, kwargs))
self._logger.debug("added handler for event %s.", event_name)

return RemovableEventHandle(event_name, handler, self)

def has_event_handler(self, handler, event_name=None):
"""Check if the specified event has the specified handler.

Expand Down
80 changes: 80 additions & 0 deletions tests/ignite/engine/test_engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import division
from enum import Enum
import gc

import pytest
from mock import call, MagicMock, Mock
Expand Down Expand Up @@ -130,6 +131,85 @@ def test_adding_multiple_event_handlers():
handler.assert_called_once_with(engine)


def test_event_removable_handle():

# Removable handle removes event from engine.
engine = DummyEngine()
handler = MagicMock()

removable_handle = engine.add_event_handler(Events.STARTED, handler)
assert engine.has_event_handler(handler, Events.STARTED)

engine.run(1)
handler.assert_called_once_with(engine)

removable_handle.remove()
assert not engine.has_event_handler(handler, Events.STARTED)

# Second engine pass does not fire handle again.
engine.run(1)
handler.assert_called_once_with(engine)

# Removable handle can be used as a context manager
handler = MagicMock()

with engine.add_event_handler(Events.STARTED, handler):
assert engine.has_event_handler(handler, Events.STARTED)
engine.run(1)

assert not engine.has_event_handler(handler, Events.STARTED)
handler.assert_called_once_with(engine)

engine.run(1)
handler.assert_called_once_with(engine)

# Removeable handle only effects a single event registration
handler = MagicMock()

with engine.add_event_handler(Events.STARTED, handler):
with engine.add_event_handler(Events.COMPLETED, handler):
assert engine.has_event_handler(handler, Events.STARTED)
assert engine.has_event_handler(handler, Events.COMPLETED)
assert engine.has_event_handler(handler, Events.STARTED)
assert not engine.has_event_handler(handler, Events.COMPLETED)
assert not engine.has_event_handler(handler, Events.STARTED)
assert not engine.has_event_handler(handler, Events.COMPLETED)

# Removeable handle is re-enter and re-exitable

handler = MagicMock()

remove = engine.add_event_handler(Events.STARTED, handler)

with remove:
with remove:
assert engine.has_event_handler(handler, Events.STARTED)
assert not engine.has_event_handler(handler, Events.STARTED)
assert not engine.has_event_handler(handler, Events.STARTED)

# Removeable handle is a weakref, does not keep engine or event alive
def _add_in_closure():
_engine = DummyEngine()

def _handler(_):
pass

_handle = _engine.add_event_handler(Events.STARTED, _handler)
assert _handle.engine() is _engine
assert _handle.handler() is _handler

return _handle

removable_handle = _add_in_closure()

# gc.collect, resolving reference cycles in engine/state
# required to ensure object deletion in python2
gc.collect()

assert removable_handle.engine() is None
assert removable_handle.handler() is None


def test_has_event_handler():
engine = DummyEngine()
handlers = [MagicMock(), MagicMock()]
Expand Down