diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py
index 7e8d9396bd55..26cb3c12560d 100644
--- a/ignite/metrics/metric.py
+++ b/ignite/metrics/metric.py
@@ -2,7 +2,7 @@
 from collections.abc import Mapping
 from functools import wraps
 from numbers import Number
-from typing import Any, Callable, Dict, Optional, Sequence, Tuple, TYPE_CHECKING, Union
+from typing import Any, Callable, cast, Dict, Optional, Sequence, Tuple, TYPE_CHECKING, Union
 
 import torch
 
@@ -214,7 +214,6 @@ def __init__(
             raise ValueError("Cannot create metric on an XLA device. Use device='cpu' instead.")
 
         self._device = torch.device(device)
-        self._is_reduced = False
        self.reset()
 
     @abstractmethod
@@ -556,25 +555,37 @@ def another_wrapper(self: Metric, *args: Any, **kwargs: Any) -> Callable:
                     "Decorator sync_all_reduce should be used on ignite.metric.Metric class methods only"
                 )
             ws = idist.get_world_size()
-            if len(attrs) > 0 and not self._is_reduced:
-                if ws > 1:
-                    for attr in attrs:
-                        op_kwargs = {}
-                        if ":" in attr:
-                            attr, op = attr.split(":")
-                            valid_ops = ["MIN", "MAX", "SUM", "PRODUCT"]
-                            if op not in valid_ops:
-                                raise ValueError(f"Reduction operation is not valid (expected : {valid_ops}, got: {op}")
-                            op_kwargs["op"] = op
-                        t = getattr(self, attr, None)
-                        if t is not None:
-                            t = idist.all_reduce(t, **op_kwargs)
-                            self._is_reduced = True
-                            setattr(self, attr, t)
-                else:
-                    self._is_reduced = True
-
-            return func(self, *args, **kwargs)
+            unreduced_attrs = {}
+            if len(attrs) > 0 and ws > 1:
+                for attr in attrs:
+                    op_kwargs = {}
+                    if ":" in attr:
+                        attr, op = attr.split(":")
+                        valid_ops = ["MIN", "MAX", "SUM", "PRODUCT"]
+                        if op not in valid_ops:
+                            raise ValueError(f"Reduction operation is not valid (expected : {valid_ops}, got: {op}")
+                        op_kwargs["op"] = op
+                    if attr not in self.__dict__:
+                        raise ValueError(f"Metric {type(self)} has no attribute named `{attr}`.")
+                    t = getattr(self, attr)
+                    if not isinstance(t, (Number, torch.Tensor)):
+                        raise TypeError(
+                            "Attribute provided to sync_all_reduce should be a "
+                            f"number or tensor but `{attr}` has type {type(t)}"
+                        )
+                    unreduced_attrs[attr] = t
+                    # Here `clone` is necessary since `idist.all_reduce` modifies `t` inplace in the case
+                    # `t` is a tensor and its `device` is same as that of the process.
+                    # TODO: Remove this dual behavior of `all_reduce` to always either return a new tensor or
+                    # modify it in-place.
+                    t_reduced = idist.all_reduce(cast(float, t) if isinstance(t, Number) else t.clone(), **op_kwargs)
+                    setattr(self, attr, t_reduced)
+
+            result = func(self, *args, **kwargs)
+
+            for attr, value in unreduced_attrs.items():
+                setattr(self, attr, value)
+            return result
 
         return another_wrapper
 
@@ -594,7 +605,6 @@ def reinit__is_reduced(func: Callable) -> Callable:
     @wraps(func)
     def wrapper(self: Metric, *args: Any, **kwargs: Any) -> None:
         func(self, *args, **kwargs)
-        self._is_reduced = False
         if "_result" in self.__dict__:
             self._result = None  # type: ignore[attr-defined]
 
diff --git a/ignite/metrics/precision.py b/ignite/metrics/precision.py
index f2fed6fde5b7..090651720a7f 100644
--- a/ignite/metrics/precision.py
+++ b/ignite/metrics/precision.py
@@ -6,7 +6,7 @@
 import ignite.distributed as idist
 from ignite.exceptions import NotComputableError
 from ignite.metrics.accuracy import _BaseClassification
-from ignite.metrics.metric import reinit__is_reduced
+from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce
 from ignite.utils import to_onehot
 
 __all__ = ["Precision"]
@@ -121,6 +121,7 @@ def reset(self) -> None:
 
         super(_BasePrecisionRecall, self).reset()
 
+    @sync_all_reduce("_numerator", "_denominator")
     def compute(self) -> Union[torch.Tensor, float]:
 
         # Return value of the metric for `average` options `'weighted'` and `'macro'` is computed as follows.
@@ -138,18 +139,13 @@ def compute(self) -> Union[torch.Tensor, float]:
             raise NotComputableError(
                 f"{self.__class__.__name__} must have at least one example before it can be computed."
             )
-        if not self._is_reduced:
-            self._numerator = idist.all_reduce(self._numerator)  # type: ignore[assignment]
-            self._denominator = idist.all_reduce(self._denominator)  # type: ignore[assignment]
-            if self._average == "weighted":
-                self._weight = idist.all_reduce(self._weight)  # type: ignore[assignment]
-            self._is_reduced: bool = True
 
         fraction = self._numerator / (self._denominator + (self.eps if self._average != "samples" else 0))
 
         if self._average == "weighted":
-            sum_of_weights = cast(torch.Tensor, self._weight).sum() + self.eps
-            return ((fraction @ self._weight) / sum_of_weights).item()  # type: ignore
+            _weight = idist.all_reduce(self._weight.clone())  # type: ignore[union-attr]
+            sum_of_weights = cast(torch.Tensor, _weight).sum() + self.eps
+            return ((fraction @ _weight) / sum_of_weights).item()  # type: ignore
         elif self._average == "micro" or self._average == "samples":
             return cast(torch.Tensor, fraction).item()
         elif self._average == "macro":
diff --git a/tests/ignite/metrics/test_accuracy.py b/tests/ignite/metrics/test_accuracy.py
index d313449d2c78..de827b4b0733 100644
--- a/tests/ignite/metrics/test_accuracy.py
+++ b/tests/ignite/metrics/test_accuracy.py
@@ -275,6 +275,9 @@ def _test(metric_device):
             acc._num_correct.device == metric_device
         ), f"{type(acc._num_correct.device)}:{acc._num_correct.device} vs {type(metric_device)}:{metric_device}"
 
+        n = acc._num_examples
+        assert n == y.numel() / y.size(dim=1)
+
         # gather y_pred, y
         y_pred = idist.all_gather(y_pred)
         y = idist.all_gather(y)
@@ -282,9 +285,8 @@ def _test(metric_device):
         np_y_pred = to_numpy_multilabel(y_pred.cpu())  # (N, C, H, W, ...) -> (N * H * W ..., C)
         np_y = to_numpy_multilabel(y.cpu())  # (N, C, H, W, ...) -> (N * H * W ..., C)
 
         assert acc._type == "multilabel"
-        n = acc._num_examples
         res = acc.compute()
-        assert n * idist.get_world_size() == acc._num_examples
+        assert n == acc._num_examples
         assert isinstance(res, float)
         assert accuracy_score(np_y, np_y_pred) == pytest.approx(res)
@@ -298,6 +300,9 @@ def _test(metric_device):
             acc._num_correct.device == metric_device
         ), f"{type(acc._num_correct.device)}:{acc._num_correct.device} vs {type(metric_device)}:{metric_device}"
 
+        n = acc._num_examples
+        assert n == y.numel() / y.size(dim=1)
+
         # gather y_pred, y
         y_pred = idist.all_gather(y_pred)
         y = idist.all_gather(y)
@@ -306,14 +311,13 @@ def _test(metric_device):
         np_y = to_numpy_multilabel(y.cpu())  # (N, C, H, W, ...) -> (N * H * W ..., C)
 
         assert acc._type == "multilabel"
-        n = acc._num_examples
         res = acc.compute()
-        assert n * idist.get_world_size() == acc._num_examples
+        assert n == acc._num_examples
         assert isinstance(res, float)
         assert accuracy_score(np_y, np_y_pred) == pytest.approx(res)
 
         # check that result is not changed
         res = acc.compute()
-        assert n * idist.get_world_size() == acc._num_examples
+        assert n == acc._num_examples
         assert isinstance(res, float)
         assert accuracy_score(np_y, np_y_pred) == pytest.approx(res)
@@ -334,6 +338,9 @@ def _test(metric_device):
             acc._num_correct.device == metric_device
         ), f"{type(acc._num_correct.device)}:{acc._num_correct.device} vs {type(metric_device)}:{metric_device}"
 
+        n = acc._num_examples
+        assert n == y.numel() / y.size(dim=1)
+
         # gather y_pred, y
         y_pred = idist.all_gather(y_pred)
         y = idist.all_gather(y)
@@ -342,9 +349,8 @@ def _test(metric_device):
         np_y = to_numpy_multilabel(y.cpu())  # (N, C, L, ...) -> (N * L ..., C)
 
         assert acc._type == "multilabel"
-        n = acc._num_examples
         res = acc.compute()
-        assert n * idist.get_world_size() == acc._num_examples
+        assert n == acc._num_examples
         assert isinstance(res, float)
         assert accuracy_score(np_y, np_y_pred) == pytest.approx(res)
 
diff --git a/tests/ignite/metrics/test_loss.py b/tests/ignite/metrics/test_loss.py
index 882a07501568..cc369371c10c 100644
--- a/tests/ignite/metrics/test_loss.py
+++ b/tests/ignite/metrics/test_loss.py
@@ -148,7 +148,7 @@ def _test(metric_device, y_test_1, y_test_2):
         n = loss._num_examples
         assert n == len(y)
         res = loss.compute()
-        assert n * idist.get_world_size() == loss._num_examples
+        assert n == loss._num_examples
 
         y_pred = idist.all_gather(y_pred)
         y = idist.all_gather(y)
@@ -160,7 +160,7 @@ def _test(metric_device, y_test_1, y_test_2):
         loss.update((y_pred, y))
         n = loss._num_examples
         res = loss.compute()
-        assert n * idist.get_world_size() == loss._num_examples
+        assert n == loss._num_examples
 
         y_pred = idist.all_gather(y_pred)
         y = idist.all_gather(y)
diff --git a/tests/ignite/metrics/test_metric.py b/tests/ignite/metrics/test_metric.py
index a4c9c1c76a3c..10d7a715bdd8 100644
--- a/tests/ignite/metrics/test_metric.py
+++ b/tests/ignite/metrics/test_metric.py
@@ -535,6 +535,29 @@ def update(self, output):
         pass
 
 
+def _test_compute_with_sync_all_reduce_doesnt_change_attributes(device):
+    class DummyMetric3(Metric):
+        @reinit__is_reduced
+        def reset(self):
+            self.a = torch.tensor(0.0, device=self._device)
+            self.b = 0.0
+
+        def update(self, output):
+            self.a += torch.tensor(1.0)
+            self.b += 1.0
+
+        @sync_all_reduce("a", "b")
+        def compute(self):
+            return self.a.item(), self.b
+
+    metric_device = device if torch.device(device).type != "xla" else "cpu"
+    metric = DummyMetric3(device=metric_device)
+    metric.update(None)
+    assert metric.a.item() == metric.b == 1.0
+    metric.compute()
+    assert metric.a.item() == metric.b == 1.0
+
+
 def _test_invalid_sync_all_reduce(device):
     class InvalidMetric(Metric):
         @reinit__is_reduced
@@ -543,6 +566,7 @@ def reset(self):
             self.c = 0.0
             self.n = 0
             self.m = -1
+            self.d = "a string"
 
         def compute(self):
             pass
@@ -566,6 +590,14 @@ def invalid_reduction_op_3(self):
         def invalid_reduction_op_4(self):
             pass
 
+        @sync_all_reduce("missingattr")
+        def invalid_reduction_op_5(self):
+            pass
+
+        @sync_all_reduce("d")
+        def invalid_reduction_op_6(self):
+            pass
+
     metric_device = device if torch.device(device).type != "xla" else "cpu"
     m = InvalidMetric(device=metric_device)
     m.reset()
@@ -583,6 +615,14 @@ def invalid_reduction_op_4(self):
     with pytest.raises(ValueError, match=r"Reduction operation is not valid"):
         m.invalid_reduction_op_4()
 
+    with pytest.raises(ValueError, match=r"has no attribute named `missingattr`."):
+        m.invalid_reduction_op_5()
+
+    with pytest.raises(
+        TypeError, match=r"Attribute provided to sync_all_reduce should be a number or tensor but `d`"
+    ):
+        m.invalid_reduction_op_6()
+
 
 def _test_distrib_sync_all_reduce_decorator(device):
     class DummyMetric(Metric):
@@ -647,7 +687,7 @@ def update(self, output):
         m = DummyMetric(device=metric_device)
         m.update(None)
         m.compute()
-        # check if can call compute multiple times without all reduce invocation
+        # check if attributes are restored to their original values after previous `compute`
         m.compute()
 
 
@@ -664,6 +704,7 @@ def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
     device = idist.device()
     _test_distrib_sync_all_reduce_decorator(device)
     _test_invalid_sync_all_reduce(device)
+    _test_compute_with_sync_all_reduce_doesnt_change_attributes(device)
 
 
 @pytest.mark.distributed
@@ -673,6 +714,7 @@ def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
     device = idist.device()
     _test_distrib_sync_all_reduce_decorator(device)
     _test_invalid_sync_all_reduce(device)
+    _test_compute_with_sync_all_reduce_doesnt_change_attributes(device)
 
 
 @pytest.mark.distributed
@@ -685,6 +727,7 @@ def test_distrib_hvd(gloo_hvd_executor):
 
     gloo_hvd_executor(_test_distrib_sync_all_reduce_decorator, (device,), np=nproc, do_init=True)
     gloo_hvd_executor(_test_invalid_sync_all_reduce, (device,), np=nproc, do_init=True)
+    gloo_hvd_executor(_test_compute_with_sync_all_reduce_doesnt_change_attributes, (device,), np=nproc, do_init=True)
 
 
 @pytest.mark.multinode_distributed
@@ -695,6 +738,7 @@ def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo):
     device = idist.device()
     _test_distrib_sync_all_reduce_decorator(device)
     _test_invalid_sync_all_reduce(device)
+    _test_compute_with_sync_all_reduce_doesnt_change_attributes(device)
 
 
 @pytest.mark.multinode_distributed
@@ -705,6 +749,7 @@ def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl):
     device = idist.device()
     _test_distrib_sync_all_reduce_decorator(device)
     _test_invalid_sync_all_reduce(device)
+    _test_compute_with_sync_all_reduce_doesnt_change_attributes(device)
 
 
 @pytest.mark.tpu
@@ -715,6 +760,7 @@ def test_distrib_single_device_xla():
     _test_distrib_sync_all_reduce_decorator(device)
     _test_creating_on_xla_fails(device)
     _test_invalid_sync_all_reduce(device)
+    _test_compute_with_sync_all_reduce_doesnt_change_attributes(device)
 
 
 def _test_distrib_xla_nprocs(index):
@@ -722,6 +768,7 @@ def _test_distrib_xla_nprocs(index):
     _test_distrib_sync_all_reduce_decorator(device)
     _test_creating_on_xla_fails(device)
     _test_invalid_sync_all_reduce(device)
+    _test_compute_with_sync_all_reduce_doesnt_change_attributes(device)
 
 
 @pytest.mark.tpu
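
For context, a minimal sketch (not part of the patch) of how a custom metric uses the reworked decorator. The `SumMetric` class and its attribute names below are illustrative only; the behavior it documents is the one the new `_test_compute_with_sync_all_reduce_doesnt_change_attributes` test asserts: inside the decorated `compute` the listed attributes hold the all-reduced values, and after it returns they are restored to their local, per-process values.

import torch

import ignite.distributed as idist
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce


class SumMetric(Metric):  # illustrative example, not part of the PR
    @reinit__is_reduced
    def reset(self):
        self._sum = torch.tensor(0.0, device=self._device)
        self._num_examples = 0

    def update(self, output):
        # accumulate per-process partial results
        self._sum += output.sum().to(self._device)
        self._num_examples += output.numel()

    @sync_all_reduce("_sum", "_num_examples")
    def compute(self):
        # Inside this call `_sum` and `_num_examples` hold the all-reduced (global) values;
        # after compute() returns they are restored to the local values, so calling
        # compute() repeatedly no longer inflates counts by the world size.
        return self._sum.item() / max(self._num_examples, 1)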