Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 32 additions & 30 deletions ignite/contrib/metrics/precision_recall_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,41 +78,43 @@ def __init__(
device: Union[str, torch.device] = torch.device("cpu"),
) -> None:
super(PrecisionRecallCurve, self).__init__(
precision_recall_curve_compute_fn,
precision_recall_curve_compute_fn, # type: ignore[arg-type]
output_transform=output_transform,
check_compute_fn=check_compute_fn,
device=device,
)

def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # type: ignore[override]
if len(self._predictions) < 1 or len(self._targets) < 1:
raise NotComputableError("PrecisionRecallCurve must have at least one example before it can be computed.")

_prediction_tensor = torch.cat(self._predictions, dim=0)
_target_tensor = torch.cat(self._targets, dim=0)

ws = idist.get_world_size()
if ws > 1 and not self._is_reduced:
# All gather across all processes
_prediction_tensor = cast(torch.Tensor, idist.all_gather(_prediction_tensor))
_target_tensor = cast(torch.Tensor, idist.all_gather(_target_tensor))
self._is_reduced = True

if idist.get_rank() == 0:
# Run compute_fn on zero rank only
precision, recall, thresholds = self.compute_fn(_prediction_tensor, _target_tensor)
precision = torch.tensor(precision)
recall = torch.tensor(recall)
# thresholds can have negative strides, not compatible with torch tensors
# https://discuss.pytorch.org/t/negative-strides-in-tensor-error/134287/2
thresholds = torch.tensor(thresholds.copy())
else:
precision, recall, thresholds = None, None, None

if ws > 1:
# broadcast result to all processes
precision = idist.broadcast(precision, src=0, safe_mode=True)
recall = idist.broadcast(recall, src=0, safe_mode=True)
thresholds = idist.broadcast(thresholds, src=0, safe_mode=True)

return precision, recall, thresholds
if self._result is None:
_prediction_tensor = torch.cat(self._predictions, dim=0)
_target_tensor = torch.cat(self._targets, dim=0)

ws = idist.get_world_size()
if ws > 1:
# All gather across all processes
_prediction_tensor = cast(torch.Tensor, idist.all_gather(_prediction_tensor))
_target_tensor = cast(torch.Tensor, idist.all_gather(_target_tensor))

if idist.get_rank() == 0:
# Run compute_fn on zero rank only
precision, recall, thresholds = cast(Tuple, self.compute_fn(_prediction_tensor, _target_tensor))
precision = torch.tensor(precision, device=_prediction_tensor.device)
recall = torch.tensor(recall, device=_prediction_tensor.device)
# thresholds can have negative strides, not compatible with torch tensors
# https://discuss.pytorch.org/t/negative-strides-in-tensor-error/134287/2
thresholds = torch.tensor(thresholds.copy(), device=_prediction_tensor.device)
else:
precision, recall, thresholds = None, None, None

if ws > 1:
# broadcast result to all processes
precision = idist.broadcast(precision, src=0, safe_mode=True)
recall = idist.broadcast(recall, src=0, safe_mode=True)
thresholds = idist.broadcast(thresholds, src=0, safe_mode=True)

self._result = (precision, recall, thresholds) # type: ignore[assignment]

return cast(Tuple[torch.Tensor, torch.Tensor, torch.Tensor], self._result)
6 changes: 3 additions & 3 deletions ignite/contrib/metrics/roc_auc.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,13 @@ def __init__(
raise ModuleNotFoundError("This contrib module requires scikit-learn to be installed.")

super(RocCurve, self).__init__(
roc_auc_curve_compute_fn,
roc_auc_curve_compute_fn, # type: ignore[arg-type]
output_transform=output_transform,
check_compute_fn=check_compute_fn,
device=device,
)

def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # type: ignore[override]
if len(self._predictions) < 1 or len(self._targets) < 1:
raise NotComputableError("RocCurve must have at least one example before it can be computed.")

Expand All @@ -180,7 +180,7 @@ def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

if idist.get_rank() == 0:
# Run compute_fn on zero rank only
fpr, tpr, thresholds = self.compute_fn(_prediction_tensor, _target_tensor)
fpr, tpr, thresholds = cast(Tuple, self.compute_fn(_prediction_tensor, _target_tensor))
fpr = torch.tensor(fpr, device=_prediction_tensor.device)
tpr = torch.tensor(tpr, device=_prediction_tensor.device)
thresholds = torch.tensor(thresholds, device=_prediction_tensor.device)
Expand Down
45 changes: 22 additions & 23 deletions ignite/metrics/epoch_metric.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Any, Callable, cast, List, Tuple, Union
from typing import Callable, cast, List, Optional, Tuple, Union

import torch

Expand Down Expand Up @@ -28,9 +28,8 @@ class EpochMetric(Metric):
- ``update`` must receive output of the form ``(y_pred, y)``.

Args:
compute_fn: a callable with the signature (`torch.tensor`, `torch.tensor`) takes as the input
`predictions` and `targets` and returns a scalar. Input tensors will be on specified ``device``
(see arg below).
compute_fn: a callable which receives two tensors as the `predictions` and `targets`
and returns a scalar. Input tensors will be on specified ``device`` (see arg below).
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
Expand Down Expand Up @@ -70,7 +69,7 @@ def mse_fn(y_preds, y_targets):

def __init__(
self,
compute_fn: Callable,
compute_fn: Callable[[torch.Tensor, torch.Tensor], float],
output_transform: Callable = lambda x: x,
check_compute_fn: bool = True,
device: Union[str, torch.device] = torch.device("cpu"),
Expand All @@ -88,6 +87,7 @@ def __init__(
def reset(self) -> None:
    """Drop everything accumulated so far and start from a clean state.

    Clears the stored result and replaces the prediction/target buffers
    with brand-new empty lists.
    """
    # ``None`` means no result is stored; ``compute`` checks for it before recomputing.
    self._result: Optional[float] = None
    # Fresh buffers; ``compute`` concatenates their contents along dim 0.
    self._targets: List[torch.Tensor] = []
    self._predictions: List[torch.Tensor] = []

def _check_shape(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
y_pred, y = output
Expand Down Expand Up @@ -136,31 +136,30 @@ def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
except Exception as e:
warnings.warn(f"Probably, there can be a problem with `compute_fn`:\n {e}.", EpochMetricWarning)

def compute(self) -> Any:
def compute(self) -> float:
if len(self._predictions) < 1 or len(self._targets) < 1:
raise NotComputableError("EpochMetric must have at least one example before it can be computed.")

_prediction_tensor = torch.cat(self._predictions, dim=0)
_target_tensor = torch.cat(self._targets, dim=0)
if self._result is None:
_prediction_tensor = torch.cat(self._predictions, dim=0)
_target_tensor = torch.cat(self._targets, dim=0)

ws = idist.get_world_size()
ws = idist.get_world_size()
if ws > 1:
# All gather across all processes
_prediction_tensor = cast(torch.Tensor, idist.all_gather(_prediction_tensor))
_target_tensor = cast(torch.Tensor, idist.all_gather(_target_tensor))

if ws > 1 and not self._is_reduced:
# All gather across all processes
_prediction_tensor = cast(torch.Tensor, idist.all_gather(_prediction_tensor))
_target_tensor = cast(torch.Tensor, idist.all_gather(_target_tensor))
self._is_reduced = True
self._result = 0.0
if idist.get_rank() == 0:
# Run compute_fn on zero rank only
self._result = self.compute_fn(_prediction_tensor, _target_tensor)

result = 0.0
if idist.get_rank() == 0:
# Run compute_fn on zero rank only
result = self.compute_fn(_prediction_tensor, _target_tensor)
if ws > 1:
# broadcast result to all processes
self._result = cast(float, idist.broadcast(self._result, src=0))

if ws > 1:
# broadcast result to all processes
result = cast(float, idist.broadcast(result, src=0))

return result
return self._result


class EpochMetricWarning(UserWarning):
Expand Down
2 changes: 2 additions & 0 deletions ignite/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,8 @@ def reinit__is_reduced(func: Callable) -> Callable:
def wrapper(self: Metric, *args: Any, **kwargs: Any) -> None:
    """Run the wrapped ``func`` and then flag the metric as not reduced.

    After ``func`` returns, ``_is_reduced`` is set to ``False``. If the
    instance already carries a ``_result`` attribute, it is reset to
    ``None`` as well. The return value of ``func`` is discarded.
    """
    func(self, *args, **kwargs)
    self._is_reduced = False
    # Only instances that already define ``_result`` get it cleared;
    # the attribute is never created here.
    instance_attrs = vars(self)
    if "_result" in instance_attrs:
        self._result = None  # type: ignore[attr-defined]

setattr(wrapper, "_decorated", True)
return wrapper
Expand Down
60 changes: 6 additions & 54 deletions tests/ignite/metrics/test_epoch_metric.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import os

import pytest
import torch

Expand Down Expand Up @@ -153,13 +151,11 @@ def compute_fn(y_preds, y_targets):
em.update(output1)


def _test_distrib_integration(device=None):

if device is None:
device = idist.device() if idist.device().type != "xla" else "cpu"
def test_distrib_integration(distributed):

device = idist.device() if idist.device().type != "xla" else "cpu"
rank = idist.get_rank()
torch.manual_seed(12 + rank)
torch.manual_seed(40 + rank)

n_iters = 3
batch_size = 2
Expand Down Expand Up @@ -188,51 +184,7 @@ def assert_data_fn(all_preds, all_targets):

y_preds = idist.all_gather(y_preds)
y_true = idist.all_gather(y_true)
ep_metric_true = (y_preds.argmax(dim=1) == y_true).sum().item()

assert engine.state.metrics["epm"] == (y_preds.argmax(dim=1) == y_true).sum().item()


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
def test_distrib_nccl_gpu(distributed_context_single_node_nccl):

device = idist.device()
_test_distrib_integration(device)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):

device = idist.device()
_test_distrib_integration(device)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_single_device_xla():
_test_distrib_integration()


def _test_distrib_xla_nprocs(index):
_test_distrib_integration()


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_xla_nprocs(xmp_executor):
n = int(os.environ["NUM_TPU_WORKERS"])
xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support")
@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
def test_distrib_hvd(gloo_hvd_executor):

nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()

gloo_hvd_executor(_test_distrib_integration, (None,), np=nproc, do_init=True)
assert engine.state.metrics["epm"] == ep_metric_true
assert ep_metric.compute() == ep_metric_true