Skip to content

Commit 193643c

Browse files
Add new metric usages and update RunningAverage accordingly (#2958)
* Improve the metric, update tests & docstrings * Adding epoch_bound and detach Also a few improvements * Add test for detach and epoch_bound * Fix a bug and do a refactor in test_metric and add test for SingleEpochRunningBatchWise in test_metric * Fix docstrings * autopep8 fix * Improve code, docs and tests * Improve code * Fix mypy * Update test_running_epoch_wise test * Some improvements --------- Co-authored-by: sadra-barikbin <sadra-barikbin@users.noreply.github.com> Co-authored-by: vfdev <vfdev.5@gmail.com>
1 parent eba5aae commit 193643c

6 files changed

Lines changed: 501 additions & 356 deletions

File tree

docs/source/metrics.rst

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,10 @@ Complete list of usages
290290

291291
- :class:`~ignite.metrics.metric.MetricUsage`
292292
- :class:`~ignite.metrics.metric.EpochWise`
293+
- :class:`~ignite.metrics.metric.RunningEpochWise`
293294
- :class:`~ignite.metrics.metric.BatchWise`
295+
- :class:`~ignite.metrics.metric.RunningBatchWise`
296+
- :class:`~ignite.metrics.metric.SingleEpochRunningBatchWise`
294297
- :class:`~ignite.metrics.metric.BatchFiltered`
295298

296299
Metrics and distributed computations
@@ -359,10 +362,22 @@ EpochWise
359362
~~~~~~~~~
360363
.. autoclass:: ignite.metrics.metric.EpochWise
361364

365+
RunningEpochWise
366+
~~~~~~~~~~~~~~~~
367+
.. autoclass:: ignite.metrics.metric.RunningEpochWise
368+
362369
BatchWise
363370
~~~~~~~~~
364371
.. autoclass:: ignite.metrics.metric.BatchWise
365372

373+
RunningBatchWise
374+
~~~~~~~~~~~~~~~~
375+
.. autoclass:: ignite.metrics.metric.RunningBatchWise
376+
377+
SingleEpochRunningBatchWise
378+
~~~~~~~~~~~~~~~~~~~~~~~~~~~
379+
.. autoclass:: ignite.metrics.metric.SingleEpochRunningBatchWise
380+
366381
BatchFiltered
367382
~~~~~~~~~~~~~
368383
.. autoclass:: ignite.metrics.metric.BatchFiltered

ignite/contrib/engines/common.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from ignite.handlers.checkpoint import BaseSaveHandler
3434
from ignite.handlers.param_scheduler import ParamScheduler
3535
from ignite.metrics import RunningAverage
36+
from ignite.metrics.metric import RunningBatchWise
3637
from ignite.utils import deprecated
3738

3839

@@ -209,8 +210,8 @@ def output_transform(x: Any, index: int, name: str) -> Any:
209210
)
210211

211212
for i, n in enumerate(output_names):
212-
RunningAverage(output_transform=partial(output_transform, index=i, name=n), epoch_bound=False).attach(
213-
trainer, n
213+
RunningAverage(output_transform=partial(output_transform, index=i, name=n)).attach(
214+
trainer, n, usage=RunningBatchWise()
214215
)
215216

216217
if with_pbars:

ignite/metrics/metric.py

Lines changed: 102 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,16 @@
1212
if TYPE_CHECKING:
1313
from ignite.metrics.metrics_lambda import MetricsLambda
1414

15-
__all__ = ["Metric", "MetricUsage", "EpochWise", "BatchWise", "BatchFiltered"]
15+
__all__ = [
16+
"Metric",
17+
"MetricUsage",
18+
"EpochWise",
19+
"BatchWise",
20+
"BatchFiltered",
21+
"RunningEpochWise",
22+
"RunningBatchWise",
23+
"SingleEpochRunningBatchWise",
24+
]
1625

1726

1827
class MetricUsage:
@@ -31,6 +40,8 @@ class MetricUsage:
3140
:meth:`~ignite.metrics.metric.Metric.iteration_completed`.
3241
"""
3342

43+
usage_name: str
44+
3445
def __init__(self, started: Events, completed: Events, iteration_completed: CallableEventWithFilter) -> None:
3546
self.__started = started
3647
self.__completed = completed
@@ -74,6 +85,33 @@ def __init__(self) -> None:
7485
)
7586

7687

88+
class RunningEpochWise(EpochWise):
89+
"""
90+
Running epoch-wise usage of Metrics. It's the running version of the :class:`~.metrics.metric.EpochWise` metric
91+
usage. A metric with such a usage most likely accompanies an :class:`~.metrics.metric.EpochWise` one to compute
92+
a running measure of it e.g. running average.
93+
94+
Metric's methods are triggered on the following engine events:
95+
96+
- :meth:`~ignite.metrics.metric.Metric.started` on every ``STARTED``
97+
(See :class:`~ignite.engine.events.Events`).
98+
- :meth:`~ignite.metrics.metric.Metric.iteration_completed` on every ``EPOCH_COMPLETED``.
99+
- :meth:`~ignite.metrics.metric.Metric.completed` on every ``EPOCH_COMPLETED``.
100+
101+
Attributes:
102+
usage_name: usage name string
103+
"""
104+
105+
usage_name: str = "running_epoch_wise"
106+
107+
def __init__(self) -> None:
108+
super(EpochWise, self).__init__(
109+
started=Events.STARTED,
110+
completed=Events.EPOCH_COMPLETED,
111+
iteration_completed=Events.EPOCH_COMPLETED,
112+
)
113+
114+
77115
class BatchWise(MetricUsage):
78116
"""
79117
Batch-wise usage of Metrics.
@@ -99,6 +137,59 @@ def __init__(self) -> None:
99137
)
100138

101139

140+
class RunningBatchWise(BatchWise):
141+
"""
142+
Running batch-wise usage of Metrics. It's the running version of the :class:`~.metrics.metric.BatchWise` metric
143+
usage. A metric with such a usage could for example accompany a :class:`~.metrics.metric.BatchWise` one to compute
144+
a running measure of it e.g. running average.
145+
146+
Metric's methods are triggered on the following engine events:
147+
148+
- :meth:`~ignite.metrics.metric.Metric.started` on every ``STARTED``
149+
(See :class:`~ignite.engine.events.Events`).
150+
- :meth:`~ignite.metrics.metric.Metric.iteration_completed` on every ``ITERATION_COMPLETED``.
151+
- :meth:`~ignite.metrics.metric.Metric.completed` on every ``ITERATION_COMPLETED``.
152+
153+
Attributes:
154+
usage_name: usage name string
155+
"""
156+
157+
usage_name: str = "running_batch_wise"
158+
159+
def __init__(self) -> None:
160+
super(BatchWise, self).__init__(
161+
started=Events.STARTED,
162+
completed=Events.ITERATION_COMPLETED,
163+
iteration_completed=Events.ITERATION_COMPLETED,
164+
)
165+
166+
167+
class SingleEpochRunningBatchWise(BatchWise):
168+
"""
169+
Running batch-wise usage of Metrics in a single epoch. It's like :class:`~.metrics.metric.RunningBatchWise` metric
170+
usage with the difference that it is used during a single epoch.
171+
172+
Metric's methods are triggered on the following engine events:
173+
174+
- :meth:`~ignite.metrics.metric.Metric.started` on every ``EPOCH_STARTED``
175+
(See :class:`~ignite.engine.events.Events`).
176+
- :meth:`~ignite.metrics.metric.Metric.iteration_completed` on every ``ITERATION_COMPLETED``.
177+
- :meth:`~ignite.metrics.metric.Metric.completed` on every ``ITERATION_COMPLETED``.
178+
179+
Attributes:
180+
usage_name: usage name string
181+
"""
182+
183+
usage_name: str = "single_epoch_running_batch_wise"
184+
185+
def __init__(self) -> None:
186+
super(BatchWise, self).__init__(
187+
started=Events.EPOCH_STARTED,
188+
completed=Events.ITERATION_COMPLETED,
189+
iteration_completed=Events.ITERATION_COMPLETED,
190+
)
191+
192+
102193
class BatchFiltered(MetricUsage):
103194
"""
104195
Batch filtered usage of Metrics. This usage is similar to epoch-wise but update event is filtered.
@@ -344,12 +435,16 @@ def completed(self, engine: Engine, name: str) -> None:
344435

345436
def _check_usage(self, usage: Union[str, MetricUsage]) -> MetricUsage:
346437
if isinstance(usage, str):
347-
if usage == EpochWise.usage_name:
348-
usage = EpochWise()
349-
elif usage == BatchWise.usage_name:
350-
usage = BatchWise()
351-
else:
352-
raise ValueError(f"usage should be 'EpochWise.usage_name' or 'BatchWise.usage_name', get {usage}")
438+
usages = [EpochWise, RunningEpochWise, BatchWise, RunningBatchWise, SingleEpochRunningBatchWise]
439+
for usage_cls in usages:
440+
if usage == usage_cls.usage_name:
441+
usage = usage_cls()
442+
break
443+
if not isinstance(usage, MetricUsage):
444+
raise ValueError(
445+
"Argument usage should be '(Running)EpochWise.usage_name' or "
446+
f"'((SingleEpoch)Running)BatchWise.usage_name', got {usage}"
447+
)
353448
if not isinstance(usage, MetricUsage):
354449
raise TypeError(f"Unhandled usage type {type(usage)}")
355450
return usage

ignite/metrics/running_average.py

Lines changed: 101 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
from typing import Callable, cast, Optional, Sequence, Union
1+
import warnings
2+
from typing import Any, Callable, cast, Optional, Union
23

34
import torch
45

56
import ignite.distributed as idist
67
from ignite.engine import Engine, Events
7-
from ignite.metrics.metric import EpochWise, Metric, MetricUsage, reinit__is_reduced, sync_all_reduce
8+
from ignite.metrics.metric import Metric, MetricUsage, reinit__is_reduced, RunningBatchWise, SingleEpochRunningBatchWise
89

910
__all__ = ["RunningAverage"]
1011

@@ -18,8 +19,10 @@ class RunningAverage(Metric):
1819
alpha: running average decay factor, default 0.98
1920
output_transform: a function to use to transform the output if `src` is None and
2021
corresponds the output of process function. Otherwise it should be None.
21-
epoch_bound: whether the running average should be reset after each epoch (defaults
22-
to True).
22+
epoch_bound: whether the running average should be reset after each epoch. It is deprecated in favor of
23+
``usage`` argument in :meth:`attach` method. Setting ``epoch_bound`` to ``True`` is equivalent to
24+
``usage=SingleEpochRunningBatchWise()`` and setting it to ``False`` is equivalent to
25+
``usage=RunningBatchWise()`` in the :meth:`attach` method. Default None.
2326
device: specifies which device updates are accumulated on. Should be
2427
None when ``src`` is an instance of :class:`~ignite.metrics.metric.Metric`, as the running average will
2528
use the ``src``'s device. Otherwise, defaults to CPU. Only applicable when the computed value
@@ -90,7 +93,7 @@ def __init__(
9093
src: Optional[Metric] = None,
9194
alpha: float = 0.98,
9295
output_transform: Optional[Callable] = None,
93-
epoch_bound: bool = True,
96+
epoch_bound: Optional[bool] = None,
9497
device: Optional[Union[str, torch.device]] = None,
9598
):
9699
if not (isinstance(src, Metric) or src is None):
@@ -101,70 +104,119 @@ def __init__(
101104
if isinstance(src, Metric):
102105
if output_transform is not None:
103106
raise ValueError("Argument output_transform should be None if src is a Metric.")
107+
108+
def output_transform(x: Any) -> Any:
109+
return x
110+
104111
if device is not None:
105112
raise ValueError("Argument device should be None if src is a Metric.")
106-
self.src = src
107-
self._get_src_value = self._get_metric_value
108-
setattr(self, "iteration_completed", self._metric_iteration_completed)
113+
self.src: Union[Metric, None] = src
109114
device = src._device
110115
else:
111116
if output_transform is None:
112117
raise ValueError(
113118
"Argument output_transform should not be None if src corresponds "
114119
"to the output of process function."
115120
)
116-
self._get_src_value = self._get_output_value
117-
setattr(self, "update", self._output_update)
121+
self.src = None
118122
if device is None:
119123
device = torch.device("cpu")
120124

121-
self.alpha = alpha
125+
if epoch_bound is not None:
126+
warnings.warn(
127+
"`epoch_bound` is deprecated and will be removed in the future. Consider using `usage` argument of"
128+
"`attach` method instead. `epoch_bound=True` is equivalent with `usage=SingleEpochRunningBatchWise()`"
129+
" and `epoch_bound=False` is equivalent with `usage=RunningBatchWise()`."
130+
)
122131
self.epoch_bound = epoch_bound
123-
super(RunningAverage, self).__init__(output_transform=output_transform, device=device) # type: ignore[arg-type]
132+
self.alpha = alpha
133+
super(RunningAverage, self).__init__(output_transform=output_transform, device=device)
124134

125135
@reinit__is_reduced
126136
def reset(self) -> None:
127137
self._value: Optional[Union[float, torch.Tensor]] = None
138+
if isinstance(self.src, Metric):
139+
self.src.reset()
128140

129141
@reinit__is_reduced
130-
def update(self, output: Sequence) -> None:
131-
# Implement abstract method
132-
pass
133-
134-
def compute(self) -> Union[torch.Tensor, float]:
135-
if self._value is None:
136-
self._value = self._get_src_value()
142+
def update(self, output: Union[torch.Tensor, float]) -> None:
143+
if self.src is None:
144+
output = output.detach().to(self._device, copy=True) if isinstance(output, torch.Tensor) else output
145+
value = idist.all_reduce(output) / idist.get_world_size()
137146
else:
138-
self._value = self._value * self.alpha + (1.0 - self.alpha) * self._get_src_value()
147+
value = self.src.compute()
148+
self.src.reset()
139149

140-
return self._value
141-
142-
def attach(self, engine: Engine, name: str, _usage: Union[str, MetricUsage] = EpochWise()) -> None:
143-
if self.epoch_bound:
144-
# restart average every epoch
145-
engine.add_event_handler(Events.EPOCH_STARTED, self.started)
150+
if self._value is None:
151+
self._value = value
146152
else:
147-
engine.add_event_handler(Events.STARTED, self.started)
148-
# compute metric
149-
engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)
150-
# apply running average
151-
engine.add_event_handler(Events.ITERATION_COMPLETED, self.completed, name)
152-
153-
def _get_metric_value(self) -> Union[torch.Tensor, float]:
154-
return self.src.compute()
155-
156-
@sync_all_reduce("src")
157-
def _get_output_value(self) -> Union[torch.Tensor, float]:
158-
# we need to compute average instead of sum produced by @sync_all_reduce("src")
159-
output = cast(Union[torch.Tensor, float], self.src) / idist.get_world_size()
160-
return output
153+
self._value = self._value * self.alpha + (1.0 - self.alpha) * value
161154

162-
def _metric_iteration_completed(self, engine: Engine) -> None:
163-
self.src.started(engine)
164-
self.src.iteration_completed(engine)
165-
166-
@reinit__is_reduced
167-
def _output_update(self, output: Union[torch.Tensor, float]) -> None:
168-
if isinstance(output, torch.Tensor):
169-
output = output.detach().to(self._device, copy=True)
170-
self.src = output # type: ignore[assignment]
155+
def compute(self) -> Union[torch.Tensor, float]:
156+
return cast(Union[torch.Tensor, float], self._value)
157+
158+
def attach(self, engine: Engine, name: str, usage: Union[str, MetricUsage] = RunningBatchWise()) -> None:
159+
r"""
160+
Attach the metric to the ``engine`` using the events determined by the ``usage``.
161+
162+
Args:
163+
engine: the engine to get attached to.
164+
name: by which, the metric is inserted into ``engine.state.metrics`` dictionary.
165+
usage: the usage determining on which events the metric is reset, updated and computed. It should be an
166+
instance of the :class:`~ignite.metrics.metric.MetricUsage`\ s in the following table.
167+
168+
======================================================= ===========================================
169+
``usage`` **class** **Description**
170+
======================================================= ===========================================
171+
:class:`~.metrics.metric.RunningBatchWise` Running average of the ``src`` metric or
172+
``engine.state.output`` is computed across
173+
batches. In the former case, on each batch,
174+
``src`` is reset, updated and computed then
175+
its value is retrieved. Default.
176+
:class:`~.metrics.metric.SingleEpochRunningBatchWise` Same as above but the running average is
177+
computed across batches in an epoch so it
178+
is reset at the start of each epoch.
179+
:class:`~.metrics.metric.RunningEpochWise` Running average of the ``src`` metric or
180+
``engine.state.output`` is computed across
181+
epochs. In the former case, ``src`` works
182+
as if it was attached in a
183+
:class:`~ignite.metrics.metric.EpochWise`
184+
manner and its computed value is retrieved
185+
at the end of the epoch. The latter case
186+
doesn't make much sense for this usage as
187+
the ``engine.state.output`` of the last
188+
batch is retrieved then.
189+
======================================================= ===========================================
190+
191+
``RunningAverage`` retrieves ``engine.state.output`` at ``usage.ITERATION_COMPLETED`` if the ``src`` is not
192+
given and it's computed and updated using ``src``, by manually calling its ``compute`` method, or
193+
``engine.state.output`` at ``usage.COMPLETED`` event.
194+
Also if ``src`` is given, it is updated at ``usage.ITERATION_COMPLETED``, but its reset event is determined by
195+
``usage`` type. If ``isinstance(usage, BatchWise)`` holds true, ``src`` is reset on ``BatchWise().STARTED``,
196+
otherwise on ``EpochWise().STARTED`` if ``isinstance(usage, EpochWise)``.
197+
198+
.. versionchanged:: 0.5.1
199+
Added `usage` argument
200+
"""
201+
usage = self._check_usage(usage)
202+
if self.epoch_bound is not None:
203+
usage = SingleEpochRunningBatchWise() if self.epoch_bound else RunningBatchWise()
204+
205+
if isinstance(self.src, Metric) and not engine.has_event_handler(
206+
self.src.iteration_completed, Events.ITERATION_COMPLETED
207+
):
208+
engine.add_event_handler(Events.ITERATION_COMPLETED, self.src.iteration_completed)
209+
210+
super().attach(engine, name, usage)
211+
212+
def detach(self, engine: Engine, usage: Union[str, MetricUsage] = RunningBatchWise()) -> None:
213+
usage = self._check_usage(usage)
214+
if self.epoch_bound is not None:
215+
usage = SingleEpochRunningBatchWise() if self.epoch_bound else RunningBatchWise()
216+
217+
if isinstance(self.src, Metric) and engine.has_event_handler(
218+
self.src.iteration_completed, Events.ITERATION_COMPLETED
219+
):
220+
engine.remove_event_handler(self.src.iteration_completed, Events.ITERATION_COMPLETED)
221+
222+
super().detach(engine, usage)

0 commit comments

Comments
 (0)