diff --git a/ignite/contrib/metrics/roc_auc.py b/ignite/contrib/metrics/roc_auc.py index 79238de4d4f2..51c2c3779de1 100644 --- a/ignite/contrib/metrics/roc_auc.py +++ b/ignite/contrib/metrics/roc_auc.py @@ -1,7 +1,9 @@ -from typing import Any, Callable, Tuple, Union +from typing import Any, Callable, cast, Tuple, Union import torch +from ignite import distributed as idist +from ignite.exceptions import NotComputableError from ignite.metrics import EpochMetric @@ -103,6 +105,8 @@ class RocCurve(EpochMetric): `_ is run on the first batch of data to ensure there are no issues. User will be warned in case there are any issues computing the function. + device: optional device specification for internal storage. + Note: RocCurve expects y to be comprised of 0's and 1's. y_pred must either be probability estimates or confidence values. To apply an activation to y_pred, use output_transform as shown below: @@ -137,9 +141,17 @@ def sigmoid_output_transform(output): FPR [0.0, 0.333, 0.333, 1.0] TPR [0.0, 0.0, 1.0, 1.0] Thresholds [2.0, 1.0, 0.711, 0.047] + + .. 
def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the ROC curve from every accumulated prediction and target.

    The stored batches are concatenated; when more than one process is
    running they are gathered from all processes, ``compute_fn`` is evaluated
    on rank 0 only, and the three results are broadcast from rank 0 to the
    other processes.

    Returns:
        A ``(fpr, tpr, thresholds)`` tuple of tensors.

    Raises:
        NotComputableError: if no predictions or no targets were accumulated.
    """
    if len(self._predictions) < 1 or len(self._targets) < 1:
        raise NotComputableError("RocCurve must have at least one example before it can be computed.")

    _prediction_tensor = torch.cat(self._predictions, dim=0)
    _target_tensor = torch.cat(self._targets, dim=0)

    ws = idist.get_world_size()
    if ws > 1:
        # All gather across all processes
        _prediction_tensor = cast(torch.Tensor, idist.all_gather(_prediction_tensor))
        _target_tensor = cast(torch.Tensor, idist.all_gather(_target_tensor))

    if idist.get_rank() == 0:
        # Run compute_fn on zero rank only
        fpr, tpr, thresholds = cast(Tuple, self.compute_fn(_prediction_tensor, _target_tensor))
        # Build the results on the device that holds the accumulated data
        # rather than letting torch.tensor() default to CPU.
        # NOTE(review): presumably this also keeps rank 0 on the same device
        # as the tensors other ranks receive from broadcast - confirm.
        result_device = _prediction_tensor.device
        fpr = torch.tensor(fpr, device=result_device)
        tpr = torch.tensor(tpr, device=result_device)
        thresholds = torch.tensor(thresholds, device=result_device)
    else:
        # Placeholders; replaced by the broadcast below.
        fpr, tpr, thresholds = None, None, None

    if ws > 1:
        # broadcast result to all processes; safe_mode allows the None
        # placeholders passed by the non-source ranks
        fpr = idist.broadcast(fpr, src=0, safe_mode=True)
        tpr = idist.broadcast(tpr, src=0, safe_mode=True)
        thresholds = idist.broadcast(thresholds, src=0, safe_mode=True)

    return fpr, tpr, thresholds
def test_wrong_setup():
    """Computing a RocCurve that has received no data raises NotComputableError."""
    # Build the metric outside the raises-context so that only compute() is
    # expected to raise.
    metric = RocCurve()

    with pytest.raises(NotComputableError, match="RocCurve must have at least one example before it can be computed"):
        metric.compute()


def test_distrib_integration(distributed):
    """RocCurve attached to an Engine matches sklearn's roc_curve on the gathered data."""
    rank = idist.get_rank()
    # Seed differently on each rank so every process contributes different data.
    torch.manual_seed(41 + rank)
    n_batches, batch_size = 5, 10
    y = torch.randint(0, 2, size=(n_batches * batch_size,))
    y_pred = torch.rand((n_batches * batch_size,))

    def update(engine, i):
        # Return the i-th batch as (y_pred, y).
        batch = slice(i * batch_size, (i + 1) * batch_size)
        return y_pred[batch], y[batch]

    engine = Engine(update)

    device = "cpu" if idist.device().type == "xla" else idist.device()
    metric = RocCurve(device=device)
    metric.attach(engine, "roc_curve")

    data = list(range(n_batches))

    engine.run(data=data, max_epochs=1)

    fpr, tpr, thresholds = engine.state.metrics["roc_curve"]

    # Reference values: gather every rank's data and run sklearn on it.
    # Convert explicitly to NumPy on CPU, since neither sklearn nor NumPy
    # can consume tensors that live on another device.
    y = idist.all_gather(y)
    y_pred = idist.all_gather(y_pred)
    sk_fpr, sk_tpr, sk_thresholds = roc_curve(y.cpu().numpy(), y_pred.cpu().numpy())

    assert np.array_equal(fpr.cpu().numpy(), sk_fpr)
    assert np.array_equal(tpr.cpu().numpy(), sk_tpr)
    np.testing.assert_array_almost_equal(thresholds.cpu().numpy(), sk_thresholds)