diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index afc477f457e1..bd5038f08140 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -290,7 +290,10 @@ Complete list of usages - :class:`~ignite.metrics.metric.MetricUsage` - :class:`~ignite.metrics.metric.EpochWise` + - :class:`~ignite.metrics.metric.RunningEpochWise` - :class:`~ignite.metrics.metric.BatchWise` + - :class:`~ignite.metrics.metric.RunningBatchWise` + - :class:`~ignite.metrics.metric.SingleEpochRunningBatchWise` - :class:`~ignite.metrics.metric.BatchFiltered` Metrics and distributed computations @@ -359,10 +362,22 @@ EpochWise ~~~~~~~~~ .. autoclass:: ignite.metrics.metric.EpochWise +RunningEpochWise +~~~~~~~~~~~~~~~~ +.. autoclass:: ignite.metrics.metric.RunningEpochWise + BatchWise ~~~~~~~~~ .. autoclass:: ignite.metrics.metric.BatchWise +RunningBatchWise +~~~~~~~~~~~~~~~~ +.. autoclass:: ignite.metrics.metric.RunningBatchWise + +SingleEpochRunningBatchWise +~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: ignite.metrics.metric.SingleEpochRunningBatchWise + BatchFiltered ~~~~~~~~~~~~~ .. autoclass:: ignite.metrics.metric.BatchFiltered diff --git a/ignite/contrib/engines/common.py b/ignite/contrib/engines/common.py index ad456c80a923..acff852e8c19 100644 --- a/ignite/contrib/engines/common.py +++ b/ignite/contrib/engines/common.py @@ -33,6 +33,7 @@ from ignite.handlers.checkpoint import BaseSaveHandler from ignite.handlers.param_scheduler import ParamScheduler from ignite.metrics import RunningAverage +from ignite.metrics.metric import RunningBatchWise from ignite.utils import deprecated @@ -209,8 +210,8 @@ def output_transform(x: Any, index: int, name: str) -> Any: ) for i, n in enumerate(output_names): - RunningAverage(output_transform=partial(output_transform, index=i, name=n), epoch_bound=False).attach( - trainer, n + RunningAverage(output_transform=partial(output_transform, index=i, name=n)).attach( + trainer, n, usage=RunningBatchWise() ) if with_pbars: diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index 330311c8a78e..57c8f0d797a3 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -12,7 +12,16 @@ if TYPE_CHECKING: from ignite.metrics.metrics_lambda import MetricsLambda -__all__ = ["Metric", "MetricUsage", "EpochWise", "BatchWise", "BatchFiltered"] +__all__ = [ + "Metric", + "MetricUsage", + "EpochWise", + "BatchWise", + "BatchFiltered", + "RunningEpochWise", + "RunningBatchWise", + "SingleEpochRunningBatchWise", +] class MetricUsage: @@ -31,6 +40,8 @@ class MetricUsage: :meth:`~ignite.metrics.metric.Metric.iteration_completed`. """ + usage_name: str + def __init__(self, started: Events, completed: Events, iteration_completed: CallableEventWithFilter) -> None: self.__started = started self.__completed = completed @@ -74,6 +85,33 @@ def __init__(self) -> None: ) +class RunningEpochWise(EpochWise): + """ + Running epoch-wise usage of Metrics. It's the running version of the :class:`~.metrics.metric.EpochWise` metric + usage. A metric with such a usage most likely accompanies an :class:`~.metrics.metric.EpochWise` one to compute + a running measure of it e.g. running average. + + Metric's methods are triggered on the following engine events: + + - :meth:`~ignite.metrics.metric.Metric.started` on every ``STARTED`` + (See :class:`~ignite.engine.events.Events`). + - :meth:`~ignite.metrics.metric.Metric.iteration_completed` on every ``EPOCH_COMPLETED``. + - :meth:`~ignite.metrics.metric.Metric.completed` on every ``EPOCH_COMPLETED``. 
+
+    Attributes:
+        usage_name: usage name string
+    """
+
+    usage_name: str = "running_epoch_wise"
+
+    def __init__(self) -> None:
+        super(EpochWise, self).__init__(
+            started=Events.STARTED,
+            completed=Events.EPOCH_COMPLETED,
+            iteration_completed=Events.EPOCH_COMPLETED,
+        )
+
+
 class BatchWise(MetricUsage):
     """
     Batch-wise usage of Metrics.
@@ -99,6 +137,59 @@ def __init__(self) -> None:
         )
 
 
+class RunningBatchWise(BatchWise):
+    """
+    Running batch-wise usage of Metrics. It's the running version of the :class:`~.metrics.metric.BatchWise` metric
+    usage. A metric with such a usage could for example accompany a :class:`~.metrics.metric.BatchWise` one to compute
+    a running measure of it e.g. running average.
+
+    Metric's methods are triggered on the following engine events:
+
+    - :meth:`~ignite.metrics.metric.Metric.started` on every ``STARTED``
+      (See :class:`~ignite.engine.events.Events`).
+    - :meth:`~ignite.metrics.metric.Metric.iteration_completed` on every ``ITERATION_COMPLETED``.
+    - :meth:`~ignite.metrics.metric.Metric.completed` on every ``ITERATION_COMPLETED``.
+
+    Attributes:
+        usage_name: usage name string
+    """
+
+    usage_name: str = "running_batch_wise"
+
+    def __init__(self) -> None:
+        super(BatchWise, self).__init__(
+            started=Events.STARTED,
+            completed=Events.ITERATION_COMPLETED,
+            iteration_completed=Events.ITERATION_COMPLETED,
+        )
+
+
+class SingleEpochRunningBatchWise(BatchWise):
+    """
+    Running batch-wise usage of Metrics in a single epoch. It's like :class:`~.metrics.metric.RunningBatchWise` metric
+    usage with the difference that it is used during a single epoch.
+
+    Metric's methods are triggered on the following engine events:
+
+    - :meth:`~ignite.metrics.metric.Metric.started` on every ``EPOCH_STARTED``
+      (See :class:`~ignite.engine.events.Events`).
+    - :meth:`~ignite.metrics.metric.Metric.iteration_completed` on every ``ITERATION_COMPLETED``.
+    - :meth:`~ignite.metrics.metric.Metric.completed` on every ``ITERATION_COMPLETED``.
+
+    Attributes:
+        usage_name: usage name string
+    """
+
+    usage_name: str = "single_epoch_running_batch_wise"
+
+    def __init__(self) -> None:
+        super(BatchWise, self).__init__(
+            started=Events.EPOCH_STARTED,
+            completed=Events.ITERATION_COMPLETED,
+            iteration_completed=Events.ITERATION_COMPLETED,
+        )
+
+
 class BatchFiltered(MetricUsage):
     """
     Batch filtered usage of Metrics. This usage is similar to epoch-wise but update event is filtered.
@@ -344,12 +435,16 @@ def completed(self, engine: Engine, name: str) -> None:
 
     def _check_usage(self, usage: Union[str, MetricUsage]) -> MetricUsage:
         if isinstance(usage, str):
-            if usage == EpochWise.usage_name:
-                usage = EpochWise()
-            elif usage == BatchWise.usage_name:
-                usage = BatchWise()
-            else:
-                raise ValueError(f"usage should be 'EpochWise.usage_name' or 'BatchWise.usage_name', get {usage}")
+            usages = [EpochWise, RunningEpochWise, BatchWise, RunningBatchWise, SingleEpochRunningBatchWise]
+            for usage_cls in usages:
+                if usage == usage_cls.usage_name:
+                    usage = usage_cls()
+                    break
+            if not isinstance(usage, MetricUsage):
+                raise ValueError(
+                    "Argument usage should be '(Running)EpochWise.usage_name' or "
+                    f"'((SingleEpoch)Running)BatchWise.usage_name', got {usage}"
+                )
         if not isinstance(usage, MetricUsage):
             raise TypeError(f"Unhandled usage type {type(usage)}")
         return usage
diff --git a/ignite/metrics/running_average.py b/ignite/metrics/running_average.py
index 6f73ab4f277d..16ced296dad6 100644
--- a/ignite/metrics/running_average.py
+++ b/ignite/metrics/running_average.py
@@ -1,10 +1,11 @@
-from typing import Callable, cast, Optional, Sequence, Union
+import warnings
+from typing import Any, Callable, cast, Optional, Union
 
 import torch
 
 import ignite.distributed as idist
 from ignite.engine import Engine, Events
-from ignite.metrics.metric import EpochWise, Metric, MetricUsage, reinit__is_reduced, sync_all_reduce
+from ignite.metrics.metric import Metric, MetricUsage, reinit__is_reduced, RunningBatchWise, SingleEpochRunningBatchWise
 
 __all__ = ["RunningAverage"]
 
@@ -18,8 +19,10 @@ class RunningAverage(Metric):
         alpha: running average decay factor, default 0.98
         output_transform: a function to use to transform the output if `src` is None and corresponds the output
             of process function. Otherwise it should be None.
-        epoch_bound: whether the running average should be reset after each epoch (defaults
-            to True).
+        epoch_bound: whether the running average should be reset after each epoch. It is deprecated in favor of
+            the ``usage`` argument of the :meth:`attach` method. Setting ``epoch_bound`` to ``True`` is equivalent
+            to ``usage=SingleEpochRunningBatchWise()`` and setting it to ``False`` is equivalent to
+            ``usage=RunningBatchWise()`` in the :meth:`attach` method. Default None.
         device: specifies which device updates are accumulated on. Should be
             None when ``src`` is an instance of :class:`~ignite.metrics.metric.Metric`, as the running average will
             use the ``src``'s device. Otherwise, defaults to CPU.
Only applicable when the computed value @@ -90,7 +93,7 @@ def __init__( src: Optional[Metric] = None, alpha: float = 0.98, output_transform: Optional[Callable] = None, - epoch_bound: bool = True, + epoch_bound: Optional[bool] = None, device: Optional[Union[str, torch.device]] = None, ): if not (isinstance(src, Metric) or src is None): @@ -101,11 +104,13 @@ def __init__( if isinstance(src, Metric): if output_transform is not None: raise ValueError("Argument output_transform should be None if src is a Metric.") + + def output_transform(x: Any) -> Any: + return x + if device is not None: raise ValueError("Argument device should be None if src is a Metric.") - self.src = src - self._get_src_value = self._get_metric_value - setattr(self, "iteration_completed", self._metric_iteration_completed) + self.src: Union[Metric, None] = src device = src._device else: if output_transform is None: @@ -113,58 +118,105 @@ def __init__( "Argument output_transform should not be None if src corresponds " "to the output of process function." ) - self._get_src_value = self._get_output_value - setattr(self, "update", self._output_update) + self.src = None if device is None: device = torch.device("cpu") - self.alpha = alpha + if epoch_bound is not None: + warnings.warn( + "`epoch_bound` is deprecated and will be removed in the future. Consider using `usage` argument of" + "`attach` method instead. `epoch_bound=True` is equivalent with `usage=SingleEpochRunningBatchWise()`" + " and `epoch_bound=False` is equivalent with `usage=RunningBatchWise()`." + ) self.epoch_bound = epoch_bound - super(RunningAverage, self).__init__(output_transform=output_transform, device=device) # type: ignore[arg-type] + self.alpha = alpha + super(RunningAverage, self).__init__(output_transform=output_transform, device=device) @reinit__is_reduced def reset(self) -> None: self._value: Optional[Union[float, torch.Tensor]] = None + if isinstance(self.src, Metric): + self.src.reset() @reinit__is_reduced - def update(self, output: Sequence) -> None: - # Implement abstract method - pass - - def compute(self) -> Union[torch.Tensor, float]: - if self._value is None: - self._value = self._get_src_value() + def update(self, output: Union[torch.Tensor, float]) -> None: + if self.src is None: + output = output.detach().to(self._device, copy=True) if isinstance(output, torch.Tensor) else output + value = idist.all_reduce(output) / idist.get_world_size() else: - self._value = self._value * self.alpha + (1.0 - self.alpha) * self._get_src_value() + value = self.src.compute() + self.src.reset() - return self._value - - def attach(self, engine: Engine, name: str, _usage: Union[str, MetricUsage] = EpochWise()) -> None: - if self.epoch_bound: - # restart average every epoch - engine.add_event_handler(Events.EPOCH_STARTED, self.started) + if self._value is None: + self._value = value else: - engine.add_event_handler(Events.STARTED, self.started) - # compute metric - engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed) - # apply running average - engine.add_event_handler(Events.ITERATION_COMPLETED, self.completed, name) - - def _get_metric_value(self) -> Union[torch.Tensor, float]: - return self.src.compute() - - @sync_all_reduce("src") - def _get_output_value(self) -> Union[torch.Tensor, float]: - # we need to compute average instead of sum produced by @sync_all_reduce("src") - output = cast(Union[torch.Tensor, float], self.src) / idist.get_world_size() - return output + self._value = self._value * self.alpha + (1.0 - self.alpha) * 
value - def _metric_iteration_completed(self, engine: Engine) -> None: - self.src.started(engine) - self.src.iteration_completed(engine) - - @reinit__is_reduced - def _output_update(self, output: Union[torch.Tensor, float]) -> None: - if isinstance(output, torch.Tensor): - output = output.detach().to(self._device, copy=True) - self.src = output # type: ignore[assignment] + def compute(self) -> Union[torch.Tensor, float]: + return cast(Union[torch.Tensor, float], self._value) + + def attach(self, engine: Engine, name: str, usage: Union[str, MetricUsage] = RunningBatchWise()) -> None: + r""" + Attach the metric to the ``engine`` using the events determined by the ``usage``. + + Args: + engine: the engine to get attached to. + name: by which, the metric is inserted into ``engine.state.metrics`` dictionary. + usage: the usage determining on which events the metric is reset, updated and computed. It should be an + instance of the :class:`~ignite.metrics.metric.MetricUsage`\ s in the following table. + + ======================================================= =========================================== + ``usage`` **class** **Description** + ======================================================= =========================================== + :class:`~.metrics.metric.RunningBatchWise` Running average of the ``src`` metric or + ``engine.state.output`` is computed across + batches. In the former case, on each batch, + ``src`` is reset, updated and computed then + its value is retrieved. Default. + :class:`~.metrics.metric.SingleEpochRunningBatchWise` Same as above but the running average is + computed across batches in an epoch so it + is reset at the end of the epoch. + :class:`~.metrics.metric.RunningEpochWise` Running average of the ``src`` metric or + ``engine.state.output`` is computed across + epochs. In the former case, ``src`` works + as if it was attached in a + :class:`~ignite.metrics.metric.EpochWise` + manner and its computed value is retrieved + at the end of the epoch. The latter case + doesn't make much sense for this usage as + the ``engine.state.output`` of the last + batch is retrieved then. + ======================================================= =========================================== + + ``RunningAverage`` retrieves ``engine.state.output`` at ``usage.ITERATION_COMPLETED`` if the ``src`` is not + given and it's computed and updated using ``src``, by manually calling its ``compute`` method, or + ``engine.state.output`` at ``usage.COMPLETED`` event. + Also if ``src`` is given, it is updated at ``usage.ITERATION_COMPLETED``, but its reset event is determined by + ``usage`` type. If ``isinstance(usage, BatchWise)`` holds true, ``src`` is reset on ``BatchWise().STARTED``, + otherwise on ``EpochWise().STARTED`` if ``isinstance(usage, EpochWise)``. + + .. 
versionchanged:: 0.5.1 + Added `usage` argument + """ + usage = self._check_usage(usage) + if self.epoch_bound is not None: + usage = SingleEpochRunningBatchWise() if self.epoch_bound else RunningBatchWise() + + if isinstance(self.src, Metric) and not engine.has_event_handler( + self.src.iteration_completed, Events.ITERATION_COMPLETED + ): + engine.add_event_handler(Events.ITERATION_COMPLETED, self.src.iteration_completed) + + super().attach(engine, name, usage) + + def detach(self, engine: Engine, usage: Union[str, MetricUsage] = RunningBatchWise()) -> None: + usage = self._check_usage(usage) + if self.epoch_bound is not None: + usage = SingleEpochRunningBatchWise() if self.epoch_bound else RunningBatchWise() + + if isinstance(self.src, Metric) and engine.has_event_handler( + self.src.iteration_completed, Events.ITERATION_COMPLETED + ): + engine.remove_event_handler(self.src.iteration_completed, Events.ITERATION_COMPLETED) + + super().detach(engine, usage) diff --git a/tests/ignite/metrics/test_metric.py b/tests/ignite/metrics/test_metric.py index 592e711640d3..4b8555b0c758 100644 --- a/tests/ignite/metrics/test_metric.py +++ b/tests/ignite/metrics/test_metric.py @@ -11,7 +11,17 @@ import ignite.distributed as idist from ignite.engine import Engine, Events, State from ignite.metrics import ConfusionMatrix, Precision, Recall -from ignite.metrics.metric import BatchFiltered, BatchWise, EpochWise, Metric, reinit__is_reduced, sync_all_reduce +from ignite.metrics.metric import ( + BatchFiltered, + BatchWise, + EpochWise, + Metric, + reinit__is_reduced, + RunningBatchWise, + RunningEpochWise, + SingleEpochRunningBatchWise, + sync_all_reduce, +) class DummyMetric1(Metric): @@ -839,80 +849,133 @@ def test_usage_exception(): m = DummyMetric2() with pytest.raises(TypeError, match=r"Unhandled usage type"): m.attach(engine, "dummy", usage=1) - with pytest.raises(ValueError, match=r"usage should be 'EpochWise.usage_name' or 'BatchWise.usage_name'"): + with pytest.raises( + ValueError, + match=r"usage should be '\(Running\)EpochWise.usage_name' or '\(\(SingleEpoch\)Running\)BatchWise.usage_name'", + ): m.attach(engine, "dummy", usage="fake") -def test_epochwise_usage(): - class MyMetric(Metric): - def __init__(self): - super(MyMetric, self).__init__() - self.value = [] +class DummyAccumulateInListMetric(Metric): + def __init__(self): + super(DummyAccumulateInListMetric, self).__init__() + self.value = [] - def reset(self): - self.value = [] + def reset(self): + self.value = [] - def compute(self): - return self.value + def compute(self): + return self.value + + def update(self, output): + self.value.append(output) - def update(self, output): - self.value.append(output) - def test(usage): - engine = Engine(lambda e, b: b) +@pytest.mark.parametrize("usage", ["epoch_wise", EpochWise.usage_name, EpochWise()]) +def test_epochwise_usage(usage): + engine = Engine(lambda e, b: b) - m = MyMetric() + m = DummyAccumulateInListMetric() - m.attach(engine, "ewm", usage=usage) + m.attach(engine, "ewm", usage=usage) - @engine.on(Events.EPOCH_COMPLETED) - def _(): - ewm = engine.state.metrics["ewm"] - assert len(ewm) == 3 - assert ewm == [0, 1, 2] + @engine.on(Events.EPOCH_COMPLETED) + def _(): + ewm = engine.state.metrics["ewm"] + assert len(ewm) == 3 + assert ewm == [0, 1, 2] - engine.run([0, 1, 2], max_epochs=10) - m.detach(engine, usage=usage) + engine.run([0, 1, 2], max_epochs=10) + m.detach(engine, usage=usage) - test("epoch_wise") - test(EpochWise.usage_name) - test(EpochWise()) +class 
DummyAccumulateMetric(Metric): + def __init__(self): + super(DummyAccumulateMetric, self).__init__() + self.value = 0 -def test_batchwise_usage(): - class MyMetric(Metric): - def __init__(self): - super(MyMetric, self).__init__() - self.value = [] + def reset(self): + self.value = 0 - def reset(self): - self.value = [] + def compute(self): + return self.value - def compute(self): - return self.value + def update(self, output): + self.value += output - def update(self, output): - self.value.append(output) - def test(usage): - engine = Engine(lambda e, b: b) +@pytest.mark.parametrize("usage", ["running_epoch_wise", RunningEpochWise.usage_name, RunningEpochWise()]) +def test_running_epochwise_usage(usage): + engine = Engine(lambda e, b: e.state.metrics["ewm"]) - m = MyMetric() + engine.state.metrics["ewm"] = 0 + + @engine.on(Events.EPOCH_STARTED) + def _(): + engine.state.metrics["ewm"] += 1 + + m = DummyAccumulateMetric() + m.attach(engine, "rewm", usage=usage) + + @engine.on(Events.EPOCH_COMPLETED) + def _(): + assert engine.state.metrics["rewm"] == sum(range(engine.state.epoch + 1)) + + engine.run([0, 1, 2], max_epochs=10) + + m.detach(engine, usage=usage) + + +@pytest.mark.parametrize("usage", ["batch_wise", BatchWise.usage_name, BatchWise()]) +def test_batchwise_usage(usage): + engine = Engine(lambda e, b: b) + + m = DummyAccumulateInListMetric() + + m.attach(engine, "bwm", usage=usage) + + @engine.on(Events.ITERATION_COMPLETED) + def _(): + bwm = engine.state.metrics["bwm"] + assert len(bwm) == 1 + assert bwm[0] == (engine.state.iteration - 1) % 3 + + engine.run([0, 1, 2], max_epochs=10) + m.detach(engine, usage=usage) - m.attach(engine, "bwm", usage=usage) - @engine.on(Events.ITERATION_COMPLETED) - def _(): - bwm = engine.state.metrics["bwm"] - assert len(bwm) == 1 - assert bwm[0] == (engine.state.iteration - 1) % 3 +@pytest.mark.parametrize("usage", ["running_batch_wise", RunningBatchWise.usage_name, RunningBatchWise()]) +def test_running_batchwise_usage(usage): + engine = Engine(lambda e, b: b) + + m = DummyAccumulateMetric() + m.attach(engine, "rbwm", usage=usage) + + @engine.on(Events.EPOCH_COMPLETED) + def _(): + assert engine.state.metrics["rbwm"] == 6 * engine.state.epoch + + engine.run([0, 1, 2, 3], max_epochs=10) + + m.detach(engine, usage=usage) - engine.run([0, 1, 2], max_epochs=10) - m.detach(engine, usage=usage) - test("batch_wise") - test(BatchWise.usage_name) - test(BatchWise()) +@pytest.mark.parametrize( + "usage", ["single_epoch_running_batch_wise", SingleEpochRunningBatchWise.usage_name, SingleEpochRunningBatchWise()] +) +def test_single_epoch_running_batchwise_usage(usage): + engine = Engine(lambda e, b: b) + + m = DummyAccumulateMetric() + m.attach(engine, "rbwm", usage=usage) + + @engine.on(Events.EPOCH_COMPLETED) + def _(): + assert engine.state.metrics["rbwm"] == 6 + + engine.run([0, 1, 2, 3], max_epochs=10) + + m.detach(engine, usage=usage) def test_batchfiltered_usage(): diff --git a/tests/ignite/metrics/test_running_average.py b/tests/ignite/metrics/test_running_average.py index d5e0a3bcc7f1..4511bd73bf20 100644 --- a/tests/ignite/metrics/test_running_average.py +++ b/tests/ignite/metrics/test_running_average.py @@ -1,5 +1,6 @@ -import os +import warnings from functools import partial +from itertools import accumulate import numpy as np import pytest @@ -8,6 +9,7 @@ import ignite.distributed as idist from ignite.engine import Engine, Events from ignite.metrics import Accuracy, RunningAverage +from ignite.metrics.metric import RunningBatchWise, 
RunningEpochWise, SingleEpochRunningBatchWise def test_wrong_input_args(): @@ -26,171 +28,138 @@ def test_wrong_input_args(): with pytest.raises(ValueError, match=r"Argument device should be None if src is a Metric"): RunningAverage(Accuracy(), device="cpu") + with pytest.warns(UserWarning, match=r"`epoch_bound` is deprecated and will be removed in the future."): + m = RunningAverage(Accuracy(), epoch_bound=True) -def test_integration(): - n_iters = 100 + +@pytest.mark.filterwarnings("ignore") +@pytest.mark.parametrize("epoch_bound, usage", [(False, RunningBatchWise()), (True, SingleEpochRunningBatchWise())]) +def test_epoch_bound(epoch_bound, usage): + with warnings.catch_warnings(): + metric = RunningAverage(output_transform=lambda _: _, epoch_bound=epoch_bound) + e1 = Engine(lambda _, __: None) + e2 = Engine(lambda _, __: None) + metric.attach(e1, "") + metric.epoch_bound = None + metric.attach(e2, "", usage) + e1._event_handlers == e2._event_handlers + + +@pytest.mark.parametrize("usage", [RunningBatchWise(), SingleEpochRunningBatchWise()]) +def test_integration_batchwise(usage): + torch.manual_seed(10) + alpha = 0.98 + n_iters = 10 batch_size = 10 n_classes = 10 - y_true_batch_values = iter(np.random.randint(0, n_classes, size=(n_iters, batch_size))) - y_pred_batch_values = iter(np.random.rand(n_iters, batch_size, n_classes)) - loss_values = iter(range(n_iters)) + max_epochs = 3 + data = list(range(n_iters)) + loss = torch.arange(n_iters) + y_true = torch.randint(0, n_classes, size=(n_iters, batch_size)) + y_pred = torch.rand(n_iters, batch_size, n_classes) + + accuracy_running_averages = torch.tensor( + list( + accumulate( + map( + lambda y_yp: torch.sum(y_yp[1].argmax(dim=-1) == y_yp[0]).item() / y_yp[0].size(0), + zip( + y_true if isinstance(usage, SingleEpochRunningBatchWise) else y_true.repeat(max_epochs, 1), + y_pred if isinstance(usage, SingleEpochRunningBatchWise) else y_pred.repeat(max_epochs, 1, 1), + ), + ), + lambda ra, acc: ra * alpha + (1 - alpha) * acc, + ) + ) + ) + if isinstance(usage, SingleEpochRunningBatchWise): + accuracy_running_averages = accuracy_running_averages.repeat(max_epochs) + + loss_running_averages = torch.tensor( + list( + accumulate( + loss if isinstance(usage, SingleEpochRunningBatchWise) else loss.repeat(max_epochs), + lambda ra, loss_item: ra * alpha + (1 - alpha) * loss_item, + ) + ) + ) + if isinstance(usage, SingleEpochRunningBatchWise): + loss_running_averages = loss_running_averages.repeat(max_epochs) - def update_fn(engine, batch): - loss_value = next(loss_values) - y_true_batch = next(y_true_batch_values) - y_pred_batch = next(y_pred_batch_values) - return loss_value, torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch) + def update_fn(_, i): + loss_value = loss[i] + y_true_batch = y_true[i] + y_pred_batch = y_pred[i] + return loss_value, y_pred_batch, y_true_batch trainer = Engine(update_fn) - alpha = 0.98 acc_metric = RunningAverage(Accuracy(output_transform=lambda x: [x[1], x[2]]), alpha=alpha) - acc_metric.attach(trainer, "running_avg_accuracy") + acc_metric.attach(trainer, "running_avg_accuracy", usage) avg_output = RunningAverage(output_transform=lambda x: x[0], alpha=alpha) - avg_output.attach(trainer, "running_avg_output") - - running_avg_acc = [ - None, - ] - - @trainer.on(Events.ITERATION_COMPLETED) - def manual_running_avg_acc(engine): - _, y_pred, y = engine.state.output - indices = torch.max(y_pred, 1)[1] - correct = torch.eq(indices, y).view(-1) - num_correct = torch.sum(correct).item() - num_examples = 
correct.shape[0] - batch_acc = num_correct * 1.0 / num_examples - if running_avg_acc[0] is None: - running_avg_acc[0] = batch_acc - else: - running_avg_acc[0] = running_avg_acc[0] * alpha + (1.0 - alpha) * batch_acc - engine.state.running_avg_acc = running_avg_acc[0] - - @trainer.on(Events.EPOCH_STARTED) - def running_avg_output_init(engine): - engine.state.running_avg_output = None + avg_output.attach(trainer, "running_avg_loss", usage) - @trainer.on(Events.ITERATION_COMPLETED) - def running_avg_output_update(engine): - if engine.state.running_avg_output is None: - engine.state.running_avg_output = engine.state.output[0] - else: - engine.state.running_avg_output = ( - engine.state.running_avg_output * alpha + (1.0 - alpha) * engine.state.output[0] - ) + metric_acc_running_averages = [] + metric_loss_running_averages = [] @trainer.on(Events.ITERATION_COMPLETED) - def assert_equal_running_avg_acc_values(engine): - assert ( - engine.state.running_avg_acc == engine.state.metrics["running_avg_accuracy"] - ), f"{engine.state.running_avg_acc} vs {engine.state.metrics['running_avg_accuracy']}" + def _(engine): + metric_acc_running_averages.append(engine.state.metrics["running_avg_accuracy"]) + metric_loss_running_averages.append(engine.state.metrics["running_avg_loss"]) - @trainer.on(Events.ITERATION_COMPLETED) - def assert_equal_running_avg_output_values(engine): - assert ( - engine.state.running_avg_output == engine.state.metrics["running_avg_output"] - ), f"{engine.state.running_avg_output} vs {engine.state.metrics['running_avg_output']}" + trainer.run(data, max_epochs=3) - np.random.seed(10) - running_avg_acc = [ - None, - ] - n_iters = 10 - batch_size = 10 - n_classes = 10 - data = list(range(n_iters)) - loss_values = iter(range(n_iters)) - y_true_batch_values = iter(np.random.randint(0, n_classes, size=(n_iters, batch_size))) - y_pred_batch_values = iter(np.random.rand(n_iters, batch_size, n_classes)) - trainer.run(data, max_epochs=1) - - running_avg_acc = [ - None, - ] - n_iters = 10 - batch_size = 10 - n_classes = 10 - data = list(range(n_iters)) - loss_values = iter(range(n_iters)) - y_true_batch_values = iter(np.random.randint(0, n_classes, size=(n_iters, batch_size))) - y_pred_batch_values = iter(np.random.rand(n_iters, batch_size, n_classes)) - trainer.run(data, max_epochs=1) + assert (torch.tensor(metric_acc_running_averages) == accuracy_running_averages).all() + assert (torch.tensor(metric_loss_running_averages) == loss_running_averages).all() -def test_epoch_unbound(): +def test_integration_epochwise(): + torch.manual_seed(10) + alpha = 0.98 n_iters = 10 - n_epochs = 3 batch_size = 10 n_classes = 10 + max_epochs = 3 data = list(range(n_iters)) - loss_values = iter(range(2 * n_epochs * n_iters)) - y_true_batch_values = iter(np.random.randint(0, n_classes, size=(2 * n_epochs * n_iters, batch_size))) - y_pred_batch_values = iter(np.random.rand(2 * n_epochs * n_iters, batch_size, n_classes)) + y_true = torch.randint(0, n_classes, size=(n_iters, batch_size)) + y_pred = torch.rand(max_epochs, n_iters, batch_size, n_classes) + + accuracy_running_averages = torch.tensor( + list( + accumulate( + map( + lambda y_pred_epoch: torch.sum(y_pred_epoch.argmax(dim=-1) == y_true).item() / y_true.numel(), + y_pred, + ), + lambda ra, acc: ra * alpha + (1 - alpha) * acc, + ) + ) + ) - def update_fn(engine, batch): - loss_value = next(loss_values) - y_true_batch = next(y_true_batch_values) - y_pred_batch = next(y_pred_batch_values) - return loss_value, torch.from_numpy(y_pred_batch), 
torch.from_numpy(y_true_batch) + def update_fn(engine, i): + y_true_batch = y_true[i] + y_pred_batch = y_pred[engine.state.epoch - 1, i] + return y_pred_batch, y_true_batch trainer = Engine(update_fn) - alpha = 0.98 - - acc_metric = RunningAverage(Accuracy(output_transform=lambda x: [x[1], x[2]]), alpha=alpha, epoch_bound=False) - acc_metric.attach(trainer, "running_avg_accuracy") - - avg_output = RunningAverage(output_transform=lambda x: x[0], alpha=alpha, epoch_bound=False) - avg_output.attach(trainer, "running_avg_output") - - running_avg_acc = [None] - - trainer.state.running_avg_output = None - @trainer.on(Events.ITERATION_COMPLETED, running_avg_acc) - def manual_running_avg_acc(engine, running_avg_acc): - _, y_pred, y = engine.state.output - indices = torch.max(y_pred, 1)[1] - correct = torch.eq(indices, y).view(-1) - num_correct = torch.sum(correct).item() - num_examples = correct.shape[0] - batch_acc = num_correct * 1.0 / num_examples - if running_avg_acc[0] is None: - running_avg_acc[0] = batch_acc - else: - running_avg_acc[0] = running_avg_acc[0] * alpha + (1.0 - alpha) * batch_acc - engine.state.running_avg_acc = running_avg_acc[0] + acc_metric = RunningAverage(Accuracy(), alpha=alpha) + acc_metric.attach(trainer, "running_avg_accuracy", RunningEpochWise()) - @trainer.on(Events.ITERATION_COMPLETED) - def running_avg_output_update(engine): - if engine.state.running_avg_output is None: - engine.state.running_avg_output = engine.state.output[0] - else: - engine.state.running_avg_output = ( - engine.state.running_avg_output * alpha + (1.0 - alpha) * engine.state.output[0] - ) + metric_acc_running_averages = [] - @trainer.on(Events.ITERATION_COMPLETED) - def assert_equal_running_avg_acc_values(engine): - assert ( - engine.state.running_avg_acc == engine.state.metrics["running_avg_accuracy"] - ), f"{engine.state.running_avg_acc} vs {engine.state.metrics['running_avg_accuracy']}" - - @trainer.on(Events.ITERATION_COMPLETED) - def assert_equal_running_avg_output_values(engine): - assert ( - engine.state.running_avg_output == engine.state.metrics["running_avg_output"] - ), f"{engine.state.running_avg_output} vs {engine.state.metrics['running_avg_output']}" + @trainer.on(Events.EPOCH_COMPLETED) + def _(engine): + metric_acc_running_averages.append(engine.state.metrics["running_avg_accuracy"]) trainer.run(data, max_epochs=3) - running_avg_acc[0] = None - trainer.state.running_avg_output = None - trainer.run(data, max_epochs=3) + assert (torch.tensor(metric_acc_running_averages) == accuracy_running_averages).all() -def test_multiple_attach(): +@pytest.mark.parametrize("usage", [RunningBatchWise(), SingleEpochRunningBatchWise(), RunningEpochWise()]) +def test_multiple_attach(usage): n_iters = 100 errD_values = iter(np.random.rand(n_iters)) errG_values = iter(np.random.rand(n_iters)) @@ -214,9 +183,9 @@ def update_fn(engine, batch): monitoring_metrics = ["errD", "errG", "D_x", "D_G_z1", "D_G_z2"] for metric in monitoring_metrics: foo = partial(lambda x, metric: x[metric], metric=metric) - RunningAverage(alpha=alpha, output_transform=foo).attach(trainer, metric) + RunningAverage(alpha=alpha, output_transform=foo).attach(trainer, metric, usage) - @trainer.on(Events.ITERATION_COMPLETED) + @trainer.on(usage.COMPLETED) def check_values(engine): values = [] for metric in monitoring_metrics: @@ -229,6 +198,22 @@ def check_values(engine): trainer.run(data) +@pytest.mark.filterwarnings("ignore") +@pytest.mark.parametrize("epoch_bound", [True, False, None]) +@pytest.mark.parametrize("src", [Accuracy(), 
None]) +@pytest.mark.parametrize("usage", [RunningBatchWise(), SingleEpochRunningBatchWise(), RunningEpochWise()]) +def test_detach(epoch_bound, src, usage): + with warnings.catch_warnings(): + m = RunningAverage(src, output_transform=(lambda _: _) if src is None else None, epoch_bound=epoch_bound) + e = Engine(lambda _, __: None) + m.attach(e, "m", usage) + for event_handlers in e._event_handlers.values(): + assert len(event_handlers) != 0 + m.detach(e, usage) + for event_handlers in e._event_handlers.values(): + assert len(event_handlers) == 0 + + def test_output_is_tensor(): m = RunningAverage(output_transform=lambda x: x) m.update(torch.rand(10, requires_grad=True).mean()) @@ -247,17 +232,18 @@ def test_output_is_tensor(): assert not v.requires_grad -def _test_distrib_on_output(device): +@pytest.mark.parametrize("usage", [RunningBatchWise(), SingleEpochRunningBatchWise()]) +def test_distrib_on_output(distributed, usage): + device = idist.device() rank = idist.get_rank() n_iters = 10 n_epochs = 3 - batch_size = 10 # Data per rank data = list(range(n_iters)) - k = n_epochs * batch_size * n_iters - all_loss_values = torch.arange(0, k * idist.get_world_size(), dtype=torch.float64).to(device) - loss_values = iter(all_loss_values[k * rank : k * (rank + 1)]) + rank_loss_count = n_epochs * n_iters + all_loss_values = torch.arange(0, rank_loss_count * idist.get_world_size(), dtype=torch.float64).to(device) + loss_values = iter(all_loss_values[rank_loss_count * rank : rank_loss_count * (rank + 1)]) def update_fn(engine, batch): loss_value = next(loss_values) @@ -266,35 +252,37 @@ def update_fn(engine, batch): trainer = Engine(update_fn) alpha = 0.98 - metric_device = idist.device() if torch.device(device).type != "xla" else "cpu" - avg_output = RunningAverage(output_transform=lambda x: x, alpha=alpha, epoch_bound=False, device=metric_device) - avg_output.attach(trainer, "running_avg_output") + metric_device = device if device.type != "xla" else "cpu" + avg_output = RunningAverage(output_transform=lambda x: x, alpha=alpha, device=metric_device) + avg_output.attach(trainer, "running_avg_output", usage) - @trainer.on(Events.STARTED) - def running_avg_output_init(engine): + @trainer.on(usage.STARTED) + def reset_running_avg_output(engine): engine.state.running_avg_output = None - @trainer.on(Events.ITERATION_COMPLETED) + @trainer.on(usage.ITERATION_COMPLETED) def running_avg_output_update(engine): i = engine.state.iteration - 1 - o = sum([all_loss_values[i + j * k] for j in range(idist.get_world_size())]).item() + o = sum([all_loss_values[i + r * rank_loss_count] for r in range(idist.get_world_size())]).item() o /= idist.get_world_size() if engine.state.running_avg_output is None: engine.state.running_avg_output = o else: engine.state.running_avg_output = engine.state.running_avg_output * alpha + (1.0 - alpha) * o - @trainer.on(Events.ITERATION_COMPLETED) + @trainer.on(usage.COMPLETED) def assert_equal_running_avg_output_values(engine): it = engine.state.iteration - assert engine.state.running_avg_output == pytest.approx( - engine.state.metrics["running_avg_output"] + assert ( + engine.state.running_avg_output == engine.state.metrics["running_avg_output"] ), f"{it}: {engine.state.running_avg_output} vs {engine.state.metrics['running_avg_output']}" trainer.run(data, max_epochs=3) -def _test_distrib_on_metric(device): +@pytest.mark.parametrize("usage", [RunningBatchWise(), SingleEpochRunningBatchWise(), RunningEpochWise()]) +def test_distrib_on_metric(distributed, usage): + device = idist.device() 
rank = idist.get_rank() n_iters = 10 n_epochs = 3 @@ -320,10 +308,8 @@ def update_fn(engine, batch): trainer = Engine(update_fn) alpha = 0.98 - acc_metric = RunningAverage( - Accuracy(output_transform=lambda x: [x[0], x[1]], device=metric_device), alpha=alpha, epoch_bound=False - ) - acc_metric.attach(trainer, "running_avg_accuracy") + acc_metric = RunningAverage(Accuracy(device=metric_device), alpha=alpha) + acc_metric.attach(trainer, "running_avg_accuracy", usage) running_avg_acc = [ None, @@ -332,29 +318,37 @@ def update_fn(engine, batch): @trainer.on(Events.ITERATION_COMPLETED) def manual_running_avg_acc(engine): - i = engine.state.iteration - 1 + iteration = engine.state.iteration - true_acc_metric.reset() + if not isinstance(usage, RunningEpochWise) or ((iteration - 1) % n_iters) == 0: + true_acc_metric.reset() + if ((iteration - 1) % n_iters) == 0 and isinstance(usage, SingleEpochRunningBatchWise): + running_avg_acc[0] = None for j in range(idist.get_world_size()): output = ( - torch.from_numpy(all_y_pred_batch_values[j, i, :, :]), - torch.from_numpy(all_y_true_batch_values[j, i, :]), + torch.from_numpy(all_y_pred_batch_values[j, iteration - 1, :, :]), + torch.from_numpy(all_y_true_batch_values[j, iteration - 1, :]), ) true_acc_metric.update(output) - batch_acc = true_acc_metric._num_correct.item() * 1.0 / true_acc_metric._num_examples + if not isinstance(usage, RunningEpochWise) or (iteration % n_iters) == 0: + batch_acc = true_acc_metric._num_correct.item() * 1.0 / true_acc_metric._num_examples - if running_avg_acc[0] is None: - running_avg_acc[0] = batch_acc - else: - running_avg_acc[0] = running_avg_acc[0] * alpha + (1.0 - alpha) * batch_acc - engine.state.running_avg_acc = running_avg_acc[0] + if running_avg_acc[0] is None: + running_avg_acc[0] = batch_acc + else: + running_avg_acc[0] = running_avg_acc[0] * alpha + (1.0 - alpha) * batch_acc + engine.state.running_avg_acc = running_avg_acc[0] @trainer.on(Events.ITERATION_COMPLETED) def assert_equal_running_avg_acc_values(engine): - assert ( - engine.state.running_avg_acc == engine.state.metrics["running_avg_accuracy"] - ), f"{engine.state.running_avg_acc} vs {engine.state.metrics['running_avg_accuracy']}" + print(engine.state.iteration) + if not isinstance(usage, RunningEpochWise) or ( + (engine.state.iteration > 1) and ((engine.state.iteration % n_iters) == 1) + ): + assert ( + engine.state.running_avg_acc == engine.state.metrics["running_avg_accuracy"] + ), f"{engine.state.running_avg_acc} vs {engine.state.metrics['running_avg_accuracy']}" trainer.run(data, max_epochs=3) @@ -363,7 +357,8 @@ def assert_equal_running_avg_acc_values(engine): _test(idist.device()) -def _test_distrib_accumulator_device(device): +def test_distrib_accumulator_device(distributed): + device = idist.device() metric_devices = [torch.device("cpu")] if device.type != "xla": metric_devices.append(idist.device()) @@ -381,79 +376,3 @@ def _test_distrib_accumulator_device(device): assert ( avg._value.device == metric_device ), f"{type(avg._value.device)}:{avg._value.device} vs {type(metric_device)}:{metric_device}" - - -@pytest.mark.distributed -@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") -@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") -def test_distrib_nccl_gpu(distributed_context_single_node_nccl): - device = idist.device() - _test_distrib_on_output(device) - _test_distrib_on_metric(device) - _test_distrib_accumulator_device(device) - - -@pytest.mark.distributed 
-@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") -def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo): - device = idist.device() - _test_distrib_on_output(device) - _test_distrib_on_metric(device) - _test_distrib_accumulator_device(device) - - -@pytest.mark.distributed -@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support") -@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc") -def test_distrib_hvd(gloo_hvd_executor): - device = torch.device("cpu" if not torch.cuda.is_available() else "cuda") - nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count() - - gloo_hvd_executor(_test_distrib_on_output, (device,), np=nproc, do_init=True) - gloo_hvd_executor(_test_distrib_on_metric, (device,), np=nproc, do_init=True) - gloo_hvd_executor(_test_distrib_accumulator_device, (device,), np=nproc, do_init=True) - - -@pytest.mark.multinode_distributed -@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") -@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") -def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo): - device = idist.device() - _test_distrib_on_output(device) - _test_distrib_on_metric(device) - _test_distrib_accumulator_device(device) - - -@pytest.mark.multinode_distributed -@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") -@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") -def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl): - device = idist.device() - _test_distrib_on_output(device) - _test_distrib_on_metric(device) - _test_distrib_accumulator_device(device) - - -@pytest.mark.tpu -@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars") -@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package") -def test_distrib_single_device_xla(): - device = idist.device() - _test_distrib_on_output(device) - _test_distrib_on_metric(device) - _test_distrib_accumulator_device(device) - - -def _test_distrib_xla_nprocs(index): - device = idist.device() - _test_distrib_on_output(device) - _test_distrib_on_metric(device) - _test_distrib_accumulator_device(device) - - -@pytest.mark.tpu -@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars") -@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package") -def test_distrib_xla_nprocs(xmp_executor): - n = int(os.environ["NUM_TPU_WORKERS"]) - xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n)
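
The snippets below are illustrative sketches for review context, not part of the patch; they assume an ignite build that already contains the changes above. As a first orientation, the new usage classes only differ in which engine events they bind ``started``, ``iteration_completed`` and ``completed`` to, which can be inspected directly through the ``MetricUsage`` properties:

from ignite.metrics.metric import RunningBatchWise, RunningEpochWise, SingleEpochRunningBatchWise

# Print the event wiring of each new usage; STARTED / ITERATION_COMPLETED / COMPLETED
# are the MetricUsage properties consumed by Metric.attach.
for usage in (RunningEpochWise(), RunningBatchWise(), SingleEpochRunningBatchWise()):
    print(usage.usage_name, usage.STARTED, usage.ITERATION_COMPLETED, usage.COMPLETED)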
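
A sketch of how the deprecated ``epoch_bound`` flag maps onto the new usages when a ``RunningAverage`` tracks ``engine.state.output``; the toy trainer and its step function are invented for the example:

from ignite.engine import Engine, Events
from ignite.metrics import RunningAverage
from ignite.metrics.metric import RunningBatchWise, SingleEpochRunningBatchWise

# Toy trainer whose step output is just the "loss" carried by the batch.
trainer = Engine(lambda engine, batch: float(batch))

# Old style: RunningAverage(..., epoch_bound=False).attach(trainer, "running_loss")
# New style: the average keeps running across epoch boundaries.
RunningAverage(output_transform=lambda loss: loss, alpha=0.98).attach(
    trainer, "running_loss", usage=RunningBatchWise()
)

# Old style: epoch_bound=True (the previous default); the average restarts every epoch.
RunningAverage(output_transform=lambda loss: loss, alpha=0.98).attach(
    trainer, "epoch_running_loss", usage=SingleEpochRunningBatchWise()
)

@trainer.on(Events.EPOCH_COMPLETED)
def log_running_averages(engine):
    print(
        engine.state.epoch,
        engine.state.metrics["running_loss"],
        engine.state.metrics["epoch_running_loss"],
    )

trainer.run([0.5, 1.0, 1.5], max_epochs=2)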
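
Since ``_check_usage`` now resolves the new ``usage_name`` strings as well, the same attachment can be written with plain strings; another small sketch under the same assumption:

from ignite.engine import Engine
from ignite.metrics import RunningAverage

trainer = Engine(lambda engine, batch: float(batch))

avg = RunningAverage(output_transform=lambda loss: loss)
avg.attach(trainer, "ra_loss", usage="running_batch_wise")  # same as usage=RunningBatchWise()
avg.detach(trainer, usage="running_batch_wise")

# The single-epoch and epoch-wise variants have string aliases too.
avg.attach(trainer, "ra_loss", usage="single_epoch_running_batch_wise")
trainer.run([1.0, 2.0], max_epochs=1)
print(trainer.state.metrics["ra_loss"])  # ~1.02 with the default alpha of 0.98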
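
A sketch of ``RunningEpochWise`` wrapping a ``src`` metric: per the new ``attach``, the wrapped ``Accuracy`` is updated on every iteration and its per-epoch value is folded into the running average at ``EPOCH_COMPLETED``; the tiny evaluator data is made up:

import torch

from ignite.engine import Engine
from ignite.metrics import Accuracy, RunningAverage
from ignite.metrics.metric import RunningEpochWise

# The step function already yields (y_pred, y), which is what Accuracy expects.
evaluator = Engine(lambda engine, batch: batch)

# Accuracy is updated on every iteration; its per-epoch value feeds the running average.
RunningAverage(Accuracy(), alpha=0.9).attach(evaluator, "running_acc", usage=RunningEpochWise())

data = [
    (torch.tensor([[0.9, 0.1], [0.2, 0.8]]), torch.tensor([0, 1])),  # 2 of 2 correct
    (torch.tensor([[0.3, 0.7], [0.6, 0.4]]), torch.tensor([0, 1])),  # 0 of 2 correct
]
evaluator.run(data, max_epochs=3)
print(evaluator.state.metrics["running_acc"])  # stays at 0.5 since every epoch scores 0.5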
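
Finally, the update rule kept by ``RunningAverage.update``/``compute`` is a plain exponential moving average, ``v_0 = x_0`` and ``v_t = alpha * v_(t-1) + (1 - alpha) * x_t``; a standalone check of that recurrence against the metric, with no engine involved:

from ignite.metrics import RunningAverage

alpha = 0.98
stream = [1.0, 2.0, 3.0]

# Hand-rolled recurrence: v_0 = x_0, v_t = alpha * v_{t-1} + (1 - alpha) * x_t.
expected = None
for x in stream:
    expected = x if expected is None else alpha * expected + (1 - alpha) * x

# The same stream pushed through RunningAverage directly via update()/compute().
ra = RunningAverage(output_transform=lambda x: x, alpha=alpha)
for x in stream:
    ra.update(x)

assert abs(ra.compute() - expected) < 1e-12
print(expected)  # ~1.0596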