Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,10 @@ Complete list of usages

- :class:`~ignite.metrics.metric.MetricUsage`
- :class:`~ignite.metrics.metric.EpochWise`
- :class:`~ignite.metrics.metric.RunningEpochWise`
- :class:`~ignite.metrics.metric.BatchWise`
- :class:`~ignite.metrics.metric.RunningBatchWise`
- :class:`~ignite.metrics.metric.SingleEpochRunningBatchWise`
- :class:`~ignite.metrics.metric.BatchFiltered`

Metrics and distributed computations
Expand Down Expand Up @@ -359,10 +362,22 @@ EpochWise
~~~~~~~~~
.. autoclass:: ignite.metrics.metric.EpochWise

RunningEpochWise
~~~~~~~~~~~~~~~~
.. autoclass:: ignite.metrics.metric.RunningEpochWise

BatchWise
~~~~~~~~~
.. autoclass:: ignite.metrics.metric.BatchWise

RunningBatchWise
~~~~~~~~~~~~~~~~
.. autoclass:: ignite.metrics.metric.RunningBatchWise

SingleEpochRunningBatchWise
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: ignite.metrics.metric.SingleEpochRunningBatchWise

BatchFiltered
~~~~~~~~~~~~~
.. autoclass:: ignite.metrics.metric.BatchFiltered
Expand Down
5 changes: 3 additions & 2 deletions ignite/contrib/engines/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from ignite.handlers.checkpoint import BaseSaveHandler
from ignite.handlers.param_scheduler import ParamScheduler
from ignite.metrics import RunningAverage
from ignite.metrics.metric import RunningBatchWise
from ignite.utils import deprecated


Expand Down Expand Up @@ -209,8 +210,8 @@ def output_transform(x: Any, index: int, name: str) -> Any:
)

for i, n in enumerate(output_names):
RunningAverage(output_transform=partial(output_transform, index=i, name=n), epoch_bound=False).attach(
trainer, n
RunningAverage(output_transform=partial(output_transform, index=i, name=n)).attach(
trainer, n, RunningBatchWise()
)

if with_pbars:
Expand Down
7 changes: 6 additions & 1 deletion ignite/distributed/comp_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,12 @@ def _collective_op(
device = self.device()
if isinstance(tensor, (Number, float)):
tensor_to_number = True
tensor = torch.tensor(tensor, device=device, dtype=self._collective_op_dtype)
# We had an undocumented precision loss (double to torch.float32) here.
if self._collective_op_dtype is None:
dtype = torch.double if isinstance(tensor, float) else torch.int64
else:
dtype = self._collective_op_dtype
tensor = torch.tensor(tensor, device=device, dtype=dtype)
elif isinstance(tensor, str):
tensor_to_str = True
max_length = self._get_max_length(tensor, device)
Expand Down
99 changes: 97 additions & 2 deletions ignite/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,16 @@
if TYPE_CHECKING:
from ignite.metrics.metrics_lambda import MetricsLambda

__all__ = ["Metric", "MetricUsage", "EpochWise", "BatchWise", "BatchFiltered"]
__all__ = [
"Metric",
"MetricUsage",
"EpochWise",
"BatchWise",
"BatchFiltered",
"RunningEpochWise",
"RunningBatchWise",
"SingleEpochRunningBatchWise",
]


class MetricUsage:
Expand Down Expand Up @@ -74,6 +83,33 @@ def __init__(self) -> None:
)


class RunningEpochWise(EpochWise):
"""
Running epoch-wise usage of Metrics. It's the running version of the :class:`~.metrics.metric.EpochWise` metric
usage. A metric with such a usage most likely accompanies an :class:`~.metrics.metric.EpochWise` one to compute
a running measure of it e.g. running average.

Metric's methods are triggered on the following engine events:

- :meth:`~ignite.metrics.metric.Metric.started` on every ``STARTED``
(See :class:`~ignite.engine.events.Events`).
- :meth:`~ignite.metrics.metric.Metric.iteration_completed` on every ``EPOCH_COMPLETED``.
- :meth:`~ignite.metrics.metric.Metric.completed` on every ``EPOCH_COMPLETED``.

Attributes:
usage_name: usage name string
"""

usage_name: str = "running_epoch_wise"

def __init__(self) -> None:
super(EpochWise, self).__init__(
started=Events.STARTED,
completed=Events.EPOCH_COMPLETED,
iteration_completed=Events.ITERATION_COMPLETED,
)


class BatchWise(MetricUsage):
"""
Batch-wise usage of Metrics.
Expand All @@ -99,6 +135,59 @@ def __init__(self) -> None:
)


class RunningBatchWise(BatchWise):
"""
Running batch-wise usage of Metrics. It's the running version of the :class:`~.metrics.metric.EpochWise` metric
usage. A metric with such a usage could for example accompany a :class:`~.metrics.metric.BatchWise` one to compute
a running measure of it e.g. running average.

Metric's methods are triggered on the following engine events:

- :meth:`~ignite.metrics.metric.Metric.started` on every ``STARTED``
(See :class:`~ignite.engine.events.Events`).
- :meth:`~ignite.metrics.metric.Metric.iteration_completed` on every ``ITERATION_COMPLETED``.
- :meth:`~ignite.metrics.metric.Metric.completed` on every ``ITERATION_COMPLETED``.

Attributes:
usage_name: usage name string
"""

usage_name: str = "running_batch_wise"

def __init__(self) -> None:
super(BatchWise, self).__init__(
started=Events.STARTED,
completed=Events.ITERATION_COMPLETED,
iteration_completed=Events.ITERATION_COMPLETED,
)


class SingleEpochRunningBatchWise(BatchWise):
"""
Running batch-wise usage of Metrics in a single epoch. It's like :class:`~.metrics.metric.RunningBatchWise` metric
usage with the difference that is used during a single epoch.

Metric's methods are triggered on the following engine events:

- :meth:`~ignite.metrics.metric.Metric.started` on every ``EPOCH_STARTED``
(See :class:`~ignite.engine.events.Events`).
- :meth:`~ignite.metrics.metric.Metric.iteration_completed` on every ``ITERATION_COMPLETED``.
- :meth:`~ignite.metrics.metric.Metric.completed` on every ``ITERATION_COMPLETED``.

Attributes:
usage_name: usage name string
"""

usage_name: str = "single_epoch_running_batch_wise"

def __init__(self) -> None:
super(BatchWise, self).__init__(
started=Events.EPOCH_STARTED,
completed=Events.ITERATION_COMPLETED,
iteration_completed=Events.ITERATION_COMPLETED,
)


class BatchFiltered(MetricUsage):
"""
Batch filtered usage of Metrics. This usage is similar to epoch-wise but update event is filtered.
Expand Down Expand Up @@ -346,10 +435,16 @@ def _check_usage(self, usage: Union[str, MetricUsage]) -> MetricUsage:
if isinstance(usage, str):
if usage == EpochWise.usage_name:
usage = EpochWise()
elif usage == RunningEpochWise.usage_name:
usage = RunningEpochWise()
elif usage == BatchWise.usage_name:
usage = BatchWise()
elif usage == RunningBatchWise.usage_name:
usage = RunningBatchWise()
else:
raise ValueError(f"usage should be 'EpochWise.usage_name' or 'BatchWise.usage_name', get {usage}")
raise ValueError(
f"usage should be '(Running)EpochWise.usage_name' or '(Running)BatchWise.usage_name', got {usage}"
)
if not isinstance(usage, MetricUsage):
raise TypeError(f"Unhandled usage type {type(usage)}")
return usage
Expand Down
87 changes: 66 additions & 21 deletions ignite/metrics/running_average.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,16 @@
import torch

import ignite.distributed as idist
from ignite.engine import Engine, Events
from ignite.metrics.metric import EpochWise, Metric, MetricUsage, reinit__is_reduced, sync_all_reduce
from ignite.engine import Engine
from ignite.metrics.metric import (
BatchWise,
EpochWise,
Metric,
MetricUsage,
reinit__is_reduced,
SingleEpochRunningBatchWise,
sync_all_reduce,
)

__all__ = ["RunningAverage"]

Expand All @@ -18,8 +26,6 @@ class RunningAverage(Metric):
alpha: running average decay factor, default 0.98
output_transform: a function to use to transform the output if `src` is None and
corresponds the output of process function. Otherwise it should be None.
epoch_bound: whether the running average should be reset after each epoch (defaults
to True).
device: specifies which device updates are accumulated on. Should be
None when ``src`` is an instance of :class:`~ignite.metrics.metric.Metric`, as the running average will
use the ``src``'s device. Otherwise, defaults to CPU. Only applicable when the computed value
Expand Down Expand Up @@ -90,7 +96,6 @@ def __init__(
src: Optional[Metric] = None,
alpha: float = 0.98,
output_transform: Optional[Callable] = None,
epoch_bound: bool = True,
device: Optional[Union[str, torch.device]] = None,
):
if not (isinstance(src, Metric) or src is None):
Expand All @@ -105,7 +110,6 @@ def __init__(
raise ValueError("Argument device should be None if src is a Metric.")
self.src = src
self._get_src_value = self._get_metric_value
setattr(self, "iteration_completed", self._metric_iteration_completed)
device = src._device
else:
if output_transform is None:
Expand All @@ -119,7 +123,6 @@ def __init__(
device = torch.device("cpu")

self.alpha = alpha
self.epoch_bound = epoch_bound
super(RunningAverage, self).__init__(output_transform=output_transform, device=device) # type: ignore[arg-type]

@reinit__is_reduced
Expand All @@ -139,16 +142,62 @@ def compute(self) -> Union[torch.Tensor, float]:

return self._value

def attach(self, engine: Engine, name: str, _usage: Union[str, MetricUsage] = EpochWise()) -> None:
if self.epoch_bound:
# restart average every epoch
engine.add_event_handler(Events.EPOCH_STARTED, self.started)
else:
engine.add_event_handler(Events.STARTED, self.started)
# compute metric
engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)
# apply running average
engine.add_event_handler(Events.ITERATION_COMPLETED, self.completed, name)
def attach(self, engine: Engine, name: str, usage: Union[str, MetricUsage] = SingleEpochRunningBatchWise()) -> None:
r"""
Attach the metric to the ``engine`` using the events determined by the ``usage``.

Args:
engine: the engine to get attached to.
name: by which, the metric is inserted into ``engine.state.metrics`` dictionary.
usage: the usage determining on which events the metric is reset, updated and computed. It should be an
instance of the :class:`~ignite.metrics.metric.MetricUsage`\ s in the following table.

======================================================= ===========================================
``usage`` **class** **Description**
======================================================= ===========================================
:class:`~.metrics.metric.RunningBatchWise` Running average of the ``src`` metric or
``engine.state.output`` is computed across
batches. In the former case, on each batch,
``src`` is reset, updated and computed then
its value is retrieved.
:class:`~.metrics.metric.SingleEpochRunningBatchWise` Same as above but the running average is
computed across batches in an epoch so it
is reset at the end of the epoch. Default.
:class:`~.metrics.metric.RunningEpochWise` Running average of the ``src`` metric or
``engine.state.output`` is computed across
epochs. In the former case, ``src`` works
as if it was attached in a
:class:`~ignite.metrics.metric.EpochWise`
manner and its computed value is retrieved
at the end of the epoch. The latter case
doesn't make much sense for this usage as
the ``engine.state.output`` of the last
batch is retrieved then.
======================================================= ===========================================

``RunningAverage`` retrieves ``engine.state.output`` at ``usage.ITERATION_COMPLETED`` if the ``src`` is not
given and it's computed and updated using ``src`, by manually calling its ``compute`` method, or
``engine.state.output`` at ``usage.COMPLETED``event.
Also if ``src`` is given, it is updated at ``usage.ITERATION_COMPLETED``, but its reset event is determined by
``usage`` type. If ``isinstance(usage, BatchWise)`` holds true, ``src`` is reset on ``BatchWise().STARTED``,
otherwise on ``EpochWise().STARTED`` if ``isinstance(usage, EpochWise)``.
"""
usage = self._check_usage(usage)
has_src = hasattr(self, "src") and isinstance(self.src, Metric)

src_usage = EpochWise() if isinstance(usage, EpochWise) else BatchWise()
if has_src and not engine.has_event_handler(self.src.started, src_usage.STARTED):
engine.add_event_handler(src_usage.STARTED, self.src.started)
engine.add_event_handler(usage.ITERATION_COMPLETED, self.src.iteration_completed)

if not has_src:
engine.add_event_handler(usage.ITERATION_COMPLETED, self.iteration_completed)

if not engine.has_event_handler(self.started, usage.STARTED):
engine.add_event_handler(usage.STARTED, self.started)
engine.add_event_handler(usage.COMPLETED, self.completed, name)

# detach?

def _get_metric_value(self) -> Union[torch.Tensor, float]:
return self.src.compute()
Expand All @@ -159,10 +208,6 @@ def _get_output_value(self) -> Union[torch.Tensor, float]:
output = cast(Union[torch.Tensor, float], self.src) / idist.get_world_size()
return output

def _metric_iteration_completed(self, engine: Engine) -> None:
self.src.started(engine)
self.src.iteration_completed(engine)

@reinit__is_reduced
def _output_update(self, output: Union[torch.Tensor, float]) -> None:
if isinstance(output, torch.Tensor):
Expand Down
Loading