Skip to content
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
092da88
adds available_device to test_precision_recall_curve #3335
BanzaiTokyo Mar 28, 2025
4970f2c
forces float32 when converting to tensor on mps
BanzaiTokyo Mar 28, 2025
3c27487
Merge branch 'master' into test_precision_recall_curve_available_device
vfdev-5 Mar 29, 2025
fae2950
creates the data directly with torch tensors instead of numpy arrays
BanzaiTokyo Mar 29, 2025
41aa987
ensures compatibility with MPS by converting to float32
BanzaiTokyo Mar 29, 2025
e207a8f
Merge branch 'master' into test_precision_recall_curve_available_device
BanzaiTokyo Mar 29, 2025
429d803
comments on float32 conversion
BanzaiTokyo Mar 29, 2025
51d2dc3
makes sure that sklearn does not convert float32 to float64
BanzaiTokyo Apr 14, 2025
af3ee49
another attempt of avoiding float64
BanzaiTokyo Apr 14, 2025
0d6f930
avoiding float64 for MPS
BanzaiTokyo Apr 14, 2025
622c2d7
Merge branch 'master' into test_precision_recall_curve_available_device
BanzaiTokyo Apr 14, 2025
1c96abe
avoiding float64 for MPS
BanzaiTokyo Apr 15, 2025
52972de
another attempt at avoiding float64 on MPS
BanzaiTokyo Apr 15, 2025
093e13a
moves conversion to float32 before assertions
BanzaiTokyo Apr 15, 2025
74185e5
conversion to float32
BanzaiTokyo Apr 15, 2025
d0be0a9
more conversion to float32
BanzaiTokyo Apr 15, 2025
80574ad
more conversion to float32
BanzaiTokyo Apr 15, 2025
e0fd412
more conversion to float32
BanzaiTokyo Apr 15, 2025
cf57e09
more conversion to float32
BanzaiTokyo Apr 15, 2025
b30cfcd
in precision_recall_curve.py add dtype when creating tensors for prec…
BanzaiTokyo Apr 23, 2025
7488bdf
Merge branch 'master' into test_precision_recall_curve_available_device
BanzaiTokyo Apr 24, 2025
5ef9215
Merge branch 'master' into test_precision_recall_curve_available_device
BanzaiTokyo Apr 24, 2025
949357f
removes unnecessary conversions
BanzaiTokyo Apr 24, 2025
fc4075c
move tensors to CPU before passing them to precision_recall_curve
BanzaiTokyo Apr 28, 2025
9057d5d
move tensors to CPU before passing them to precision_recall_curve
BanzaiTokyo Apr 28, 2025
fe00e65
move tensors to CPU before passing them to precision_recall_curve
BanzaiTokyo Apr 28, 2025
f3e4ae8
replace np.testing.assert_array_almost_equal with pytest.approx
BanzaiTokyo Apr 28, 2025
1adec2b
Merge branch 'master' into test_precision_recall_curve_available_device
BanzaiTokyo Apr 28, 2025
875b15e
removes manual_seed
BanzaiTokyo Apr 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions ignite/metrics/precision_recall_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,11 @@ def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # type: i
if idist.get_rank() == 0:
# Run compute_fn on zero rank only
precision, recall, thresholds = cast(Tuple, self.compute_fn(_prediction_tensor, _target_tensor))
precision = torch.tensor(precision, device=_prediction_tensor.device)
recall = torch.tensor(recall, device=_prediction_tensor.device)
precision = torch.tensor(precision, device=_prediction_tensor.device, dtype=self._double_dtype)
recall = torch.tensor(recall, device=_prediction_tensor.device, dtype=self._double_dtype)
# thresholds can have negative strides, not compatible with torch tensors
# https://discuss.pytorch.org/t/negative-strides-in-tensor-error/134287/2
thresholds = torch.tensor(thresholds.copy(), device=_prediction_tensor.device)
thresholds = torch.tensor(thresholds.copy(), device=_prediction_tensor.device, dtype=self._double_dtype)
else:
precision, recall, thresholds = None, None, None

Expand Down
185 changes: 125 additions & 60 deletions tests/ignite/metrics/test_precision_recall_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@
from ignite.metrics.precision_recall_curve import PrecisionRecallCurve


def to_numpy_float32(x):
    """Convert a tensor or array to a float32 NumPy array for comparisons.

    Args:
        x: a ``torch.Tensor`` (on any device), a ``np.ndarray``, or any
            other object.

    Returns:
        A float32 ``np.ndarray`` when ``x`` is a tensor or an ndarray;
        otherwise ``x`` itself, unchanged.
    """
    if isinstance(x, torch.Tensor):
        # ``Tensor.numpy()`` only works on host memory, so always move the
        # tensor to CPU first (a no-op when it already lives there). Handling
        # only MPS would raise a TypeError for CUDA tensors.
        # The cast happens after the move, so no dtype conversion is ever
        # executed on the accelerator itself.
        return x.detach().cpu().to(dtype=torch.float32).numpy()
    if isinstance(x, np.ndarray):
        return x.astype(np.float32)
    return x


@pytest.fixture()
def mock_no_sklearn():
with patch.dict("sys.modules", {"sklearn.metrics": None}):
Expand All @@ -28,112 +38,144 @@ def test_no_sklearn(mock_no_sklearn):
pr_curve.compute()


def test_precision_recall_curve():
def test_precision_recall_curve(available_device):
size = 100
np_y_pred = np.random.rand(size, 1)
np_y = np.zeros((size,))
np_y[size // 2 :] = 1
sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y, np_y_pred)
y_pred = torch.rand(size, 1, dtype=torch.float32, device=available_device)
y_true = torch.zeros(size, dtype=torch.float32, device=available_device)
y_true[size // 2 :] = 1.0
sk_precision, sk_recall, sk_thresholds = precision_recall_curve(to_numpy_float32(y_true), to_numpy_float32(y_pred))

precision_recall_curve_metric = PrecisionRecallCurve()
y_pred = torch.from_numpy(np_y_pred)
y = torch.from_numpy(np_y)
sk_precision = to_numpy_float32(sk_precision)
sk_recall = to_numpy_float32(sk_recall)
sk_thresholds = to_numpy_float32(sk_thresholds)

precision_recall_curve_metric.update((y_pred, y))
precision_recall_curve_metric = PrecisionRecallCurve(device=available_device)
assert precision_recall_curve_metric._device == torch.device(available_device)

precision_recall_curve_metric.update((y_pred, y_true))
precision, recall, thresholds = precision_recall_curve_metric.compute()
precision = precision.numpy()
recall = recall.numpy()
thresholds = thresholds.numpy()

assert pytest.approx(precision) == sk_precision
assert pytest.approx(recall) == sk_recall
precision = to_numpy_float32(precision)
recall = to_numpy_float32(recall)
thresholds = to_numpy_float32(thresholds)

sk_precision = to_numpy_float32(sk_precision)
sk_recall = to_numpy_float32(sk_recall)
sk_thresholds = to_numpy_float32(sk_thresholds)

assert np.allclose(precision, sk_precision, rtol=1e-6)
assert np.allclose(recall, sk_recall, rtol=1e-6)
# assert thresholds almost equal, due to numpy->torch->numpy conversion
np.testing.assert_array_almost_equal(thresholds, sk_thresholds)


def test_integration_precision_recall_curve_with_output_transform():
np.random.seed(1)
def test_integration_precision_recall_curve_with_output_transform(available_device):
torch.manual_seed(1)
size = 100
np_y_pred = np.random.rand(size, 1)
np_y = np.zeros((size,))
np_y[size // 2 :] = 1
np.random.shuffle(np_y)
y_pred = torch.rand(size, 1, dtype=torch.float32, device=available_device)
y_true = torch.zeros(size, dtype=torch.float32, device=available_device)
y_true[size // 2 :] = 1.0
perm = torch.randperm(size)
y_pred = y_pred[perm]
y_true = y_true[perm]

sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y, np_y_pred)
sk_precision, sk_recall, sk_thresholds = precision_recall_curve(to_numpy_float32(y_true), to_numpy_float32(y_pred))

batch_size = 10

def update_fn(engine, batch):
idx = (engine.state.iteration - 1) * batch_size
y_true_batch = np_y[idx : idx + batch_size]
y_pred_batch = np_y_pred[idx : idx + batch_size]
return idx, torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)
y_true_batch = y_true[idx : idx + batch_size]
y_pred_batch = y_pred[idx : idx + batch_size]
return idx, y_pred_batch, y_true_batch

engine = Engine(update_fn)

precision_recall_curve_metric = PrecisionRecallCurve(output_transform=lambda x: (x[1], x[2]))
precision_recall_curve_metric = PrecisionRecallCurve(
output_transform=lambda x: (x[1], x[2]), device=available_device
)
assert precision_recall_curve_metric._device == torch.device(available_device)
precision_recall_curve_metric.attach(engine, "precision_recall_curve")

data = list(range(size // batch_size))
precision, recall, thresholds = engine.run(data, max_epochs=1).metrics["precision_recall_curve"]
precision = precision.numpy()
recall = recall.numpy()
thresholds = thresholds.numpy()
assert pytest.approx(precision) == sk_precision
assert pytest.approx(recall) == sk_recall

precision = to_numpy_float32(precision)
recall = to_numpy_float32(recall)
thresholds = to_numpy_float32(thresholds)

sk_precision = to_numpy_float32(sk_precision)
sk_recall = to_numpy_float32(sk_recall)
sk_thresholds = to_numpy_float32(sk_thresholds)

assert np.allclose(precision, sk_precision, rtol=1e-6)
assert np.allclose(recall, sk_recall, rtol=1e-6)
# assert thresholds almost equal, due to numpy->torch->numpy conversion
np.testing.assert_array_almost_equal(thresholds, sk_thresholds)


def test_integration_precision_recall_curve_with_activated_output_transform():
def test_integration_precision_recall_curve_with_activated_output_transform(available_device):
np.random.seed(1)
size = 100
np_y_pred = np.random.rand(size, 1)
np_y_pred_sigmoid = torch.sigmoid(torch.from_numpy(np_y_pred)).numpy()
np_y = np.zeros((size,))
np_y[size // 2 :] = 1
np.random.shuffle(np_y)

sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y, np_y_pred_sigmoid)
y_pred = torch.rand(size, 1, dtype=torch.float32, device=available_device)
y_true = torch.zeros(size, dtype=torch.float32, device=available_device)
y_true[size // 2 :] = 1.0
perm = torch.randperm(size)
y_pred = y_pred[perm]
y_true = y_true[perm]

sigmoid_y_pred = torch.sigmoid(y_pred).cpu().numpy()
sk_precision, sk_recall, sk_thresholds = precision_recall_curve(
to_numpy_float32(y_true), to_numpy_float32(sigmoid_y_pred)
)

batch_size = 10

def update_fn(engine, batch):
idx = (engine.state.iteration - 1) * batch_size
y_true_batch = np_y[idx : idx + batch_size]
y_pred_batch = np_y_pred[idx : idx + batch_size]
return idx, torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)
y_true_batch = y_true[idx : idx + batch_size]
y_pred_batch = y_pred[idx : idx + batch_size]
return idx, y_pred_batch, y_true_batch

engine = Engine(update_fn)

precision_recall_curve_metric = PrecisionRecallCurve(output_transform=lambda x: (torch.sigmoid(x[1]), x[2]))
precision_recall_curve_metric = PrecisionRecallCurve(
output_transform=lambda x: (torch.sigmoid(x[1]), x[2]), device=available_device
)
assert precision_recall_curve_metric._device == torch.device(available_device)
precision_recall_curve_metric.attach(engine, "precision_recall_curve")

data = list(range(size // batch_size))
precision, recall, thresholds = engine.run(data, max_epochs=1).metrics["precision_recall_curve"]
precision = precision.cpu().numpy()
recall = recall.cpu().numpy()
thresholds = thresholds.cpu().numpy()
precision = to_numpy_float32(precision)
recall = to_numpy_float32(recall)
thresholds = to_numpy_float32(thresholds)

sk_precision = to_numpy_float32(sk_precision)
sk_recall = to_numpy_float32(sk_recall)
sk_thresholds = to_numpy_float32(sk_thresholds)

assert pytest.approx(precision) == sk_precision
assert pytest.approx(recall) == sk_recall
assert np.allclose(precision, sk_precision, rtol=1e-6)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From what I understand, pytest.approx may convert a float32 argument to float64, which would break on MPS.

assert np.allclose(recall, sk_recall, rtol=1e-6)
# assert thresholds almost equal, due to numpy->torch->numpy conversion
np.testing.assert_array_almost_equal(thresholds, sk_thresholds)


def test_check_compute_fn():
def test_check_compute_fn(available_device):
y_pred = torch.zeros((8, 13))
y_pred[:, 1] = 1
y_true = torch.zeros_like(y_pred)
output = (y_pred, y_true)

em = PrecisionRecallCurve(check_compute_fn=True)
em = PrecisionRecallCurve(check_compute_fn=True, device=available_device)
assert em._device == torch.device(available_device)

em.reset()
with pytest.warns(EpochMetricWarning, match=r"Probably, there can be a problem with `compute_fn`"):
em.update(output)

em = PrecisionRecallCurve(check_compute_fn=False)
em = PrecisionRecallCurve(check_compute_fn=False, device=available_device)
assert em._device == torch.device(available_device)
em.update(output)


Expand All @@ -157,15 +199,29 @@ def _test(y_pred, y, batch_size, metric_device):
y_pred = idist.all_gather(y_pred)
y = idist.all_gather(y)

np_y = y.cpu().numpy()
np_y_pred = y_pred.cpu().numpy()
np_y = to_numpy_float32(y)
np_y_pred = to_numpy_float32(y_pred)

res = prc.compute()

assert isinstance(res, Tuple)
assert precision_recall_curve(np_y, np_y_pred)[0] == pytest.approx(res[0].cpu().numpy())
assert precision_recall_curve(np_y, np_y_pred)[1] == pytest.approx(res[1].cpu().numpy())
assert precision_recall_curve(np_y, np_y_pred)[2] == pytest.approx(res[2].cpu().numpy())
sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y, np_y_pred)

assert np.allclose(
to_numpy_float32(res[0]),
to_numpy_float32(sk_precision),
rtol=1e-6,
)
assert np.allclose(
to_numpy_float32(res[1]),
to_numpy_float32(sk_recall),
rtol=1e-6,
)
assert np.allclose(
to_numpy_float32(res[2]),
to_numpy_float32(sk_thresholds),
rtol=1e-6,
)

def get_test_cases():
test_cases = [
Expand Down Expand Up @@ -222,17 +278,26 @@ def update(engine, i):

precision, recall, thresholds = engine.state.metrics["prc"]

np_y_true = y_true.cpu().numpy().ravel()
np_y_preds = y_preds.cpu().numpy().ravel()
np_y_true = to_numpy_float32(y_true).ravel()
np_y_preds = to_numpy_float32(y_preds).ravel()

sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y_true, np_y_preds)

sk_precision = sk_precision.astype(np.float32)
sk_recall = sk_recall.astype(np.float32)
sk_thresholds = sk_thresholds.astype(np.float32)

precision = to_numpy_float32(precision)
recall = to_numpy_float32(recall)
thresholds = to_numpy_float32(thresholds)

assert precision.shape == sk_precision.shape
assert recall.shape == sk_recall.shape
assert thresholds.shape == sk_thresholds.shape
assert pytest.approx(precision.cpu().numpy()) == sk_precision
assert pytest.approx(recall.cpu().numpy()) == sk_recall
assert pytest.approx(thresholds.cpu().numpy()) == sk_thresholds

assert np.allclose(precision, sk_precision, rtol=1e-6)
assert np.allclose(recall, sk_recall, rtol=1e-6)
np.testing.assert_array_almost_equal(thresholds, sk_thresholds)

metric_devices = ["cpu"]
if device.type != "xla":
Expand Down
Loading