Skip to content

Commit 01f2f5f

Browse files
Alex Fordvfdev-5
authored andcommitted
Return removable handle from Engine.add_event_handler(). (#588)
* Add tests for event removable handle. Add feature tests for engine.add_event_handler returning removable event handles. * Return RemovableEventHandle from Engine.add_event_handler. * Fixup removable event handle test in python 2.7. Explicitly trigger gc, allowing cycle detection between engine and state, in removable handle weakref test. Python 2.7 cycle detection appears to be less aggressive than python 3+. * Add removable event handler docs. Add autodoc configuration for RemovableEventHandler, expand "concepts" documentation with event remove example following event add example. * Update concepts.rst
1 parent c5e6c70 commit 01f2f5f

File tree

4 files changed

+169
-0
lines changed

4 files changed

+169
-0
lines changed

docs/source/concepts.rst

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,33 @@ Attaching an event handler is simple using method :meth:`~ignite.engine.Engine.a
8888
8989
trainer.add_event_handler(Events.COMPLETED, on_training_ended, mydata)
9090
91+
Event handlers can be detached via :meth:`~ignite.engine.Engine.remove_event_handler` or via the :class:`~ignite.engine.RemovableEventHandler`
92+
reference returned by :meth:`~ignite.engine.Engine.add_event_handler`. This can be used to reuse a configured engine for multiple loops:
93+
94+
.. code-block:: python
95+
96+
model = ...
97+
train_loader, validation_loader, test_loader = ...
98+
99+
trainer = create_supervised_trainer(model, optimizer, loss)
100+
evaluator = create_supervised_evaluator(model, metrics={'acc': Accuracy()})
101+
102+
def log_metrics(engine, title):
103+
print("Epoch: {} - {} accuracy: {:.2f}"
104+
.format(trainer.state.epoch, title, engine.state.metrics['acc']))
105+
106+
@trainer.on(Events.EPOCH_COMPLETED)
107+
def evaluate(trainer):
108+
with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "train"):
109+
evaluator.run(train_loader)
110+
111+
with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "validation"):
112+
evaluator.run(validation_loader)
113+
114+
with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "test"):
115+
evaluator.run(test_loader)
116+
117+
trainer.run(train_loader, max_epochs=100)
91118
92119
.. Note ::
93120

docs/source/engine.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,7 @@ ignite.engine
1414
:undoc-members:
1515

1616
.. autoclass:: State
17+
18+
.. autoclass:: RemovableEventHandler
19+
:members:
20+
:undoc-members:

ignite/engine/engine.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import time
55
from collections import defaultdict
66
from enum import Enum
7+
import weakref
78

89
from ignite._utils import _to_hours_mins_secs
910

@@ -48,6 +49,58 @@ def get_event_attrib_value(self, event_name):
4849
return getattr(self, State.event_to_attr[event_name])
4950

5051

52+
class RemovableEventHandle(object):
53+
"""A weakref handle to remove a registered event.
54+
55+
A handle that may be used to remove a registered event handler via the
56+
remove method, with-statement, or context manager protocol. Returned from
57+
:meth:`~ignite.engine.Engine.add_event_handler`.
58+
59+
60+
Args:
61+
event_name: Registered event name.
62+
handler: Registered event handler, stored as weakref.
63+
engine: Target engine, stored as weakref.
64+
65+
Example usage:
66+
67+
.. code-block:: python
68+
69+
engine = Engine()
70+
71+
def print_epoch(engine):
72+
print("Epoch: {}".format(engine.state.epoch))
73+
74+
with engine.add_event_handler(Events.EPOCH_COMPLETED, print_epoch):
75+
# print_epoch handler registered for a single run
76+
engine.run(data)
77+
78+
# print_epoch handler is now unregistered
79+
"""
80+
81+
def __init__(self, event_name, handler, engine):
82+
self.event_name = event_name
83+
self.handler = weakref.ref(handler)
84+
self.engine = weakref.ref(engine)
85+
86+
def remove(self):
87+
"""Remove handler from engine."""
88+
handler = self.handler()
89+
engine = self.engine()
90+
91+
if handler is None or engine is None:
92+
return
93+
94+
if engine.has_event_handler(handler, self.event_name):
95+
engine.remove_event_handler(handler, self.event_name)
96+
97+
def __enter__(self):
98+
return self
99+
100+
def __exit__(self, type, value, tb):
101+
self.remove()
102+
103+
51104
class Engine(object):
52105
"""Runs a given process_function over each batch of a dataset, emitting events as it goes.
53106
@@ -164,6 +217,9 @@ def add_event_handler(self, event_name, handler, *args, **kwargs):
164217
Note that other arguments can be passed to the handler in addition to the `*args` and `**kwargs`
165218
passed here, for example during :attr:`~ignite.engine.Events.EXCEPTION_RAISED`.
166219
220+
Returns:
221+
:class:`~ignite.engine.RemovableEventHandler`, which can be used to remove the handler.
222+
167223
Example usage:
168224
169225
.. code-block:: python
@@ -186,6 +242,8 @@ def print_epoch(engine):
186242
self._event_handlers[event_name].append((handler, args, kwargs))
187243
self._logger.debug("added handler for event %s.", event_name)
188244

245+
return RemovableEventHandle(event_name, handler, self)
246+
189247
def has_event_handler(self, handler, event_name=None):
190248
"""Check if the specified event has the specified handler.
191249

tests/ignite/engine/test_engine.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import division
22
from enum import Enum
3+
import gc
34

45
import pytest
56
from mock import call, MagicMock, Mock
@@ -130,6 +131,85 @@ def test_adding_multiple_event_handlers():
130131
handler.assert_called_once_with(engine)
131132

132133

134+
def test_event_removable_handle():
135+
136+
# Removable handle removes event from engine.
137+
engine = DummyEngine()
138+
handler = MagicMock()
139+
140+
removable_handle = engine.add_event_handler(Events.STARTED, handler)
141+
assert engine.has_event_handler(handler, Events.STARTED)
142+
143+
engine.run(1)
144+
handler.assert_called_once_with(engine)
145+
146+
removable_handle.remove()
147+
assert not engine.has_event_handler(handler, Events.STARTED)
148+
149+
# Second engine pass does not fire handle again.
150+
engine.run(1)
151+
handler.assert_called_once_with(engine)
152+
153+
# Removable handle can be used as a context manager
154+
handler = MagicMock()
155+
156+
with engine.add_event_handler(Events.STARTED, handler):
157+
assert engine.has_event_handler(handler, Events.STARTED)
158+
engine.run(1)
159+
160+
assert not engine.has_event_handler(handler, Events.STARTED)
161+
handler.assert_called_once_with(engine)
162+
163+
engine.run(1)
164+
handler.assert_called_once_with(engine)
165+
166+
# Removeable handle only effects a single event registration
167+
handler = MagicMock()
168+
169+
with engine.add_event_handler(Events.STARTED, handler):
170+
with engine.add_event_handler(Events.COMPLETED, handler):
171+
assert engine.has_event_handler(handler, Events.STARTED)
172+
assert engine.has_event_handler(handler, Events.COMPLETED)
173+
assert engine.has_event_handler(handler, Events.STARTED)
174+
assert not engine.has_event_handler(handler, Events.COMPLETED)
175+
assert not engine.has_event_handler(handler, Events.STARTED)
176+
assert not engine.has_event_handler(handler, Events.COMPLETED)
177+
178+
# Removeable handle is re-enter and re-exitable
179+
180+
handler = MagicMock()
181+
182+
remove = engine.add_event_handler(Events.STARTED, handler)
183+
184+
with remove:
185+
with remove:
186+
assert engine.has_event_handler(handler, Events.STARTED)
187+
assert not engine.has_event_handler(handler, Events.STARTED)
188+
assert not engine.has_event_handler(handler, Events.STARTED)
189+
190+
# Removeable handle is a weakref, does not keep engine or event alive
191+
def _add_in_closure():
192+
_engine = DummyEngine()
193+
194+
def _handler(_):
195+
pass
196+
197+
_handle = _engine.add_event_handler(Events.STARTED, _handler)
198+
assert _handle.engine() is _engine
199+
assert _handle.handler() is _handler
200+
201+
return _handle
202+
203+
removable_handle = _add_in_closure()
204+
205+
# gc.collect, resolving reference cycles in engine/state
206+
# required to ensure object deletion in python2
207+
gc.collect()
208+
209+
assert removable_handle.engine() is None
210+
assert removable_handle.handler() is None
211+
212+
133213
def test_has_event_handler():
134214
engine = DummyEngine()
135215
handlers = [MagicMock(), MagicMock()]

0 commit comments

Comments
 (0)