Skip to content

Commit 8782ca4

Browse files
authored
modify import (#3293)
1 parent a5d3464 commit 8782ca4

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

ignite/metrics/precision_recall_curve.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,7 @@
88

99

1010
def precision_recall_curve_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor) -> Tuple[Any, Any, Any]:
11-
try:
12-
from sklearn.metrics import precision_recall_curve
13-
except ImportError:
14-
raise ModuleNotFoundError("This contrib module requires scikit-learn to be installed.")
11+
from sklearn.metrics import precision_recall_curve
1512

1613
y_true = y_targets.cpu().numpy()
1714
y_pred = y_preds.cpu().numpy()
@@ -83,6 +80,11 @@ def __init__(
8380
device: Union[str, torch.device] = torch.device("cpu"),
8481
skip_unrolling: bool = False,
8582
) -> None:
83+
try:
84+
from sklearn.metrics import precision_recall_curve # noqa: F401
85+
except ImportError:
86+
raise ModuleNotFoundError("This module requires scikit-learn to be installed.")
87+
8688
super(PrecisionRecallCurve, self).__init__(
8789
precision_recall_curve_compute_fn, # type: ignore[arg-type]
8890
output_transform=output_transform,

tests/ignite/metrics/test_precision_recall_curve.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def mock_no_sklearn():
2121

2222

2323
def test_no_sklearn(mock_no_sklearn):
24-
with pytest.raises(ModuleNotFoundError, match=r"This contrib module requires scikit-learn to be installed."):
24+
with pytest.raises(ModuleNotFoundError, match=r"This module requires scikit-learn to be installed."):
2525
y = torch.tensor([1, 1])
2626
pr_curve = PrecisionRecallCurve()
2727
pr_curve.update((y, y))

0 commit comments

Comments
 (0)