Merged
Changes from 9 commits

29 commits
092da88
adds available_device to test_precision_recall_curve #3335
BanzaiTokyo Mar 28, 2025
4970f2c
forces float32 when converting to tensor on mps
BanzaiTokyo Mar 28, 2025
3c27487
Merge branch 'master' into test_precision_recall_curve_available_device
vfdev-5 Mar 29, 2025
fae2950
creates the data directly with torch tensors instead of numpy arrays
BanzaiTokyo Mar 29, 2025
41aa987
ensures compatibility with MPS by converting to float32
BanzaiTokyo Mar 29, 2025
e207a8f
Merge branch 'master' into test_precision_recall_curve_available_device
BanzaiTokyo Mar 29, 2025
429d803
comments on float32 conversion
BanzaiTokyo Mar 29, 2025
51d2dc3
makes sure that sklearn does not convert float32 to float64
BanzaiTokyo Apr 14, 2025
af3ee49
another attempt at avoiding float64
BanzaiTokyo Apr 14, 2025
0d6f930
avoiding float64 for MPS
BanzaiTokyo Apr 14, 2025
622c2d7
Merge branch 'master' into test_precision_recall_curve_available_device
BanzaiTokyo Apr 14, 2025
1c96abe
avoiding float64 for MPS
BanzaiTokyo Apr 15, 2025
52972de
another attempt at avoiding float64 on MPS
BanzaiTokyo Apr 15, 2025
093e13a
moves conversion to float32 before assertions
BanzaiTokyo Apr 15, 2025
74185e5
conversion to float32
BanzaiTokyo Apr 15, 2025
d0be0a9
more conversion to float32
BanzaiTokyo Apr 15, 2025
80574ad
more conversion to float32
BanzaiTokyo Apr 15, 2025
e0fd412
more conversion to float32
BanzaiTokyo Apr 15, 2025
cf57e09
more conversion to float32
BanzaiTokyo Apr 15, 2025
b30cfcd
in precision_recall_curve.py add dtype when creating tensors for prec…
BanzaiTokyo Apr 23, 2025
7488bdf
Merge branch 'master' into test_precision_recall_curve_available_device
BanzaiTokyo Apr 24, 2025
5ef9215
Merge branch 'master' into test_precision_recall_curve_available_device
BanzaiTokyo Apr 24, 2025
949357f
removes unnecessary conversions
BanzaiTokyo Apr 24, 2025
fc4075c
move tensors to CPU before passing them to precision_recall_curve
BanzaiTokyo Apr 28, 2025
9057d5d
move tensors to CPU before passing them to precision_recall_curve
BanzaiTokyo Apr 28, 2025
fe00e65
move tensors to CPU before passing them to precision_recall_curve
BanzaiTokyo Apr 28, 2025
f3e4ae8
replace np.testing.assert_array_almost_equal with pytest.approx
BanzaiTokyo Apr 28, 2025
1adec2b
Merge branch 'master' into test_precision_recall_curve_available_device
BanzaiTokyo Apr 28, 2025
875b15e
removes manual_seed
BanzaiTokyo Apr 28, 2025
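
The common thread in these commits is keeping everything in float32, since Apple's MPS backend does not support float64 tensors. A minimal sketch of that pattern (illustration only, not code from this PR; the variable names are made up):

    import numpy as np
    import torch

    # pick MPS when available, otherwise fall back to CPU
    device = "mps" if torch.backends.mps.is_available() else "cpu"

    np_y_pred = np.random.rand(100, 1)  # numpy arrays default to float64
    # force float32 on conversion: MPS tensors cannot hold float64 values
    y_pred = torch.tensor(np_y_pred, dtype=torch.float32, device=device)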
129 changes: 80 additions & 49 deletions tests/ignite/metrics/test_precision_recall_curve.py
@@ -28,112 +28,139 @@ def test_no_sklearn(mock_no_sklearn):
         pr_curve.compute()
 
 
-def test_precision_recall_curve():
+def test_precision_recall_curve(available_device):
     size = 100
-    np_y_pred = np.random.rand(size, 1)
-    np_y = np.zeros((size,))
-    np_y[size // 2 :] = 1
-    sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y, np_y_pred)
+    y_pred = torch.rand(size, 1, dtype=torch.float32, device=available_device)
+    y_true = torch.zeros(size, dtype=torch.float32, device=available_device)
+    y_true[size // 2 :] = 1.0
+    sk_precision, sk_recall, sk_thresholds = precision_recall_curve(y_true.cpu().numpy(), y_pred.cpu().numpy())
 
-    precision_recall_curve_metric = PrecisionRecallCurve()
-    y_pred = torch.from_numpy(np_y_pred)
-    y = torch.from_numpy(np_y)
+    precision_recall_curve_metric = PrecisionRecallCurve(device=available_device)
+    assert precision_recall_curve_metric._device == torch.device(available_device)
 
-    precision_recall_curve_metric.update((y_pred, y))
+    precision_recall_curve_metric.update((y_pred, y_true))
     precision, recall, thresholds = precision_recall_curve_metric.compute()
-    precision = precision.numpy()
-    recall = recall.numpy()
-    thresholds = thresholds.numpy()
+    # float32 ensures compatibility with MPS
+    precision = precision.cpu().float().numpy()
+    recall = recall.cpu().float().numpy()
+    thresholds = thresholds.cpu().float().numpy()
 
-    assert pytest.approx(precision) == sk_precision
-    assert pytest.approx(recall) == sk_recall
+    sk_precision = sk_precision.astype(np.float32)
+    sk_recall = sk_recall.astype(np.float32)
+    sk_thresholds = sk_thresholds.astype(np.float32)
+
+    assert np.allclose(precision, sk_precision, rtol=1e-6)
+    assert np.allclose(recall, sk_recall, rtol=1e-6)
     # assert thresholds almost equal, due to numpy->torch->numpy conversion
     np.testing.assert_array_almost_equal(thresholds, sk_thresholds)


-def test_integration_precision_recall_curve_with_output_transform():
-    np.random.seed(1)
+def test_integration_precision_recall_curve_with_output_transform(available_device):
+    torch.manual_seed(1)
     size = 100
-    np_y_pred = np.random.rand(size, 1)
-    np_y = np.zeros((size,))
-    np_y[size // 2 :] = 1
-    np.random.shuffle(np_y)
+    y_pred = torch.rand(size, 1, dtype=torch.float32, device=available_device)
+    y_true = torch.zeros(size, dtype=torch.float32, device=available_device)
+    y_true[size // 2 :] = 1.0
+    perm = torch.randperm(size)
+    y_pred = y_pred[perm]
+    y_true = y_true[perm]
 
-    sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y, np_y_pred)
+    sk_precision, sk_recall, sk_thresholds = precision_recall_curve(y_true.cpu().numpy(), y_pred.cpu().numpy())
 
     batch_size = 10
 
     def update_fn(engine, batch):
         idx = (engine.state.iteration - 1) * batch_size
-        y_true_batch = np_y[idx : idx + batch_size]
-        y_pred_batch = np_y_pred[idx : idx + batch_size]
-        return idx, torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)
+        y_true_batch = y_true[idx : idx + batch_size]
+        y_pred_batch = y_pred[idx : idx + batch_size]
+        return idx, y_pred_batch, y_true_batch
 
     engine = Engine(update_fn)
 
-    precision_recall_curve_metric = PrecisionRecallCurve(output_transform=lambda x: (x[1], x[2]))
+    precision_recall_curve_metric = PrecisionRecallCurve(
+        output_transform=lambda x: (x[1], x[2]), device=available_device
+    )
+    assert precision_recall_curve_metric._device == torch.device(available_device)
     precision_recall_curve_metric.attach(engine, "precision_recall_curve")
 
     data = list(range(size // batch_size))
     precision, recall, thresholds = engine.run(data, max_epochs=1).metrics["precision_recall_curve"]
-    precision = precision.numpy()
-    recall = recall.numpy()
-    thresholds = thresholds.numpy()
-    assert pytest.approx(precision) == sk_precision
-    assert pytest.approx(recall) == sk_recall
+    # float32 ensures compatibility with MPS
+    precision = precision.cpu().float().numpy()
+    recall = recall.cpu().float().numpy()
+    thresholds = thresholds.cpu().float().numpy()
+
+    sk_precision = sk_precision.astype(np.float32)
+    sk_recall = sk_recall.astype(np.float32)
+    sk_thresholds = sk_thresholds.astype(np.float32)
+
+    assert np.allclose(precision, sk_precision, rtol=1e-6)
+    assert np.allclose(recall, sk_recall, rtol=1e-6)
     # assert thresholds almost equal, due to numpy->torch->numpy conversion
     np.testing.assert_array_almost_equal(thresholds, sk_thresholds)


-def test_integration_precision_recall_curve_with_activated_output_transform():
+def test_integration_precision_recall_curve_with_activated_output_transform(available_device):
     np.random.seed(1)
     size = 100
-    np_y_pred = np.random.rand(size, 1)
-    np_y_pred_sigmoid = torch.sigmoid(torch.from_numpy(np_y_pred)).numpy()
-    np_y = np.zeros((size,))
-    np_y[size // 2 :] = 1
-    np.random.shuffle(np_y)
+    y_pred = torch.rand(size, 1, dtype=torch.float32, device=available_device)
+    y_true = torch.zeros(size, dtype=torch.float32, device=available_device)
+    y_true[size // 2 :] = 1.0
+    perm = torch.randperm(size)
+    y_pred = y_pred[perm]
+    y_true = y_true[perm]
 
-    sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y, np_y_pred_sigmoid)
+    sigmoid_y_pred = torch.sigmoid(y_pred).cpu().numpy()
+    sk_precision, sk_recall, sk_thresholds = precision_recall_curve(y_true.cpu().numpy(), sigmoid_y_pred)
 
     batch_size = 10
 
     def update_fn(engine, batch):
         idx = (engine.state.iteration - 1) * batch_size
-        y_true_batch = np_y[idx : idx + batch_size]
-        y_pred_batch = np_y_pred[idx : idx + batch_size]
-        return idx, torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)
+        y_true_batch = y_true[idx : idx + batch_size]
+        y_pred_batch = y_pred[idx : idx + batch_size]
+        return idx, y_pred_batch, y_true_batch
 
     engine = Engine(update_fn)
 
-    precision_recall_curve_metric = PrecisionRecallCurve(output_transform=lambda x: (torch.sigmoid(x[1]), x[2]))
+    precision_recall_curve_metric = PrecisionRecallCurve(
+        output_transform=lambda x: (torch.sigmoid(x[1]), x[2]), device=available_device
+    )
+    assert precision_recall_curve_metric._device == torch.device(available_device)
     precision_recall_curve_metric.attach(engine, "precision_recall_curve")
 
     data = list(range(size // batch_size))
     precision, recall, thresholds = engine.run(data, max_epochs=1).metrics["precision_recall_curve"]
-    precision = precision.cpu().numpy()
-    recall = recall.cpu().numpy()
-    thresholds = thresholds.cpu().numpy()
+    # float32 ensures compatibility with MPS
+    precision = precision.cpu().float().numpy()
+    recall = recall.cpu().float().numpy()
+    thresholds = thresholds.cpu().float().numpy()
+
+    sk_precision = sk_precision.astype(np.float32)
+    sk_recall = sk_recall.astype(np.float32)
+    sk_thresholds = sk_thresholds.astype(np.float32)
 
-    assert pytest.approx(precision) == sk_precision
-    assert pytest.approx(recall) == sk_recall
+    assert np.allclose(precision, sk_precision, rtol=1e-6)
Collaborator:
Why this change?

Contributor Author:
From what I understand, pytest.approx may convert a float32 argument to float64, and that would break on MPS.
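
A minimal sketch (not part of the diff) of the comparison pattern the test settles on, assuming only that sklearn's precision_recall_curve returns float64 arrays by default; the example values are made up:

    import numpy as np

    # metric output computed in float32 (e.g. copied back from an MPS tensor)
    precision = np.array([0.5, 0.75, 1.0], dtype=np.float32)

    # sklearn reference values default to float64
    sk_precision = np.array([0.5, 0.75, 1.0], dtype=np.float64)

    # cast the reference to float32 and compare both sides with np.allclose
    assert np.allclose(precision, sk_precision.astype(np.float32), rtol=1e-6)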

+    assert np.allclose(recall, sk_recall, rtol=1e-6)
     # assert thresholds almost equal, due to numpy->torch->numpy conversion
     np.testing.assert_array_almost_equal(thresholds, sk_thresholds)
 
 
-def test_check_compute_fn():
+def test_check_compute_fn(available_device):
     y_pred = torch.zeros((8, 13))
     y_pred[:, 1] = 1
     y_true = torch.zeros_like(y_pred)
     output = (y_pred, y_true)
 
-    em = PrecisionRecallCurve(check_compute_fn=True)
+    em = PrecisionRecallCurve(check_compute_fn=True, device=available_device)
+    assert em._device == torch.device(available_device)
 
     em.reset()
     with pytest.warns(EpochMetricWarning, match=r"Probably, there can be a problem with `compute_fn`"):
         em.update(output)
 
-    em = PrecisionRecallCurve(check_compute_fn=False)
+    em = PrecisionRecallCurve(check_compute_fn=False, device=available_device)
+    assert em._device == torch.device(available_device)
     em.update(output)


@@ -227,6 +254,10 @@ def update(engine, i):
 
     sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y_true, np_y_preds)
 
+    sk_precision = sk_precision.astype(np.float32)
+    sk_recall = sk_recall.astype(np.float32)
+    sk_thresholds = sk_thresholds.astype(np.float32)
+
     assert precision.shape == sk_precision.shape
     assert recall.shape == sk_recall.shape
     assert thresholds.shape == sk_thresholds.shape