Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions ignite/metrics/roc_auc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, cast, Tuple, Union
from typing import Any, Callable, cast

import torch

Expand All @@ -15,7 +15,7 @@ def roc_auc_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor) -> float:
return roc_auc_score(y_true, y_pred)


def roc_auc_curve_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor) -> Tuple[Any, Any, Any]:
def roc_auc_curve_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor) -> tuple[Any, Any, Any]:
from sklearn.metrics import roc_curve

y_true = y_targets.cpu().numpy()
Expand Down Expand Up @@ -83,7 +83,7 @@ def __init__(
self,
output_transform: Callable = lambda x: x,
check_compute_fn: bool = False,
device: Union[str, torch.device] = torch.device("cpu"),
device: str | torch.device = torch.device("cpu"),
skip_unrolling: bool = False,
):
try:
Expand Down Expand Up @@ -166,7 +166,7 @@ def __init__(
self,
output_transform: Callable = lambda x: x,
check_compute_fn: bool = False,
device: Union[str, torch.device] = torch.device("cpu"),
device: str | torch.device = torch.device("cpu"),
skip_unrolling: bool = False,
) -> None:
try:
Expand All @@ -182,7 +182,7 @@ def __init__(
skip_unrolling=skip_unrolling,
)

def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # type: ignore[override]
def compute(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # type: ignore[override]
if len(self._predictions) < 1 or len(self._targets) < 1:
raise NotComputableError("RocCurve must have at least one example before it can be computed.")

Expand All @@ -197,7 +197,7 @@ def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # type: i

if idist.get_rank() == 0:
# Run compute_fn on zero rank only
fpr, tpr, thresholds = cast(Tuple, self.compute_fn(_prediction_tensor, _target_tensor))
fpr, tpr, thresholds = cast(tuple, self.compute_fn(_prediction_tensor, _target_tensor))
fpr = torch.tensor(fpr, dtype=self._double_dtype, device=_prediction_tensor.device)
tpr = torch.tensor(tpr, dtype=self._double_dtype, device=_prediction_tensor.device)
thresholds = torch.tensor(thresholds, dtype=self._double_dtype, device=_prediction_tensor.device)
Expand Down
Loading