Skip to content

Commit 8ffc64e

Browse files
committed
test_geometric_mean_relative_absolute_error.py test_kendall_correlation.py test_manhattan_distance.py test_maximum_absolute_error.py
1 parent 7af547f commit 8ffc64e

8 files changed

+305
-301
lines changed

tests/ignite/metrics/regression/test_canberra_metric.py

Lines changed: 46 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -20,83 +20,83 @@ def test_wrong_input_shapes():
2020
m.update((torch.rand(4, 1), torch.rand(4)))
2121

2222

23-
def test_compute(available_device):
24-
a = torch.randn(4)
25-
b = torch.randn(4)
26-
c = torch.randn(4)
27-
d = torch.randn(4)
28-
ground_truth = torch.randn(4)
23+
def test_compute():
24+
a = np.random.randn(4)
25+
b = np.random.randn(4)
26+
c = np.random.randn(4)
27+
d = np.random.randn(4)
28+
ground_truth = np.random.randn(4)
2929

30-
m = CanberraMetric(device=available_device)
31-
assert m._device == torch.device(available_device)
30+
m = CanberraMetric()
3231

3332
canberra = DistanceMetric.get_metric("canberra")
3433

35-
m.update((a, ground_truth))
36-
np_sum = (torch.abs(ground_truth - a) / (torch.abs(a) + torch.abs(ground_truth))).sum()
34+
m.update((torch.from_numpy(a), torch.from_numpy(ground_truth)))
35+
np_sum = (np.abs(ground_truth - a) / (np.abs(a) + np.abs(ground_truth))).sum()
3736
assert m.compute() == pytest.approx(np_sum)
38-
assert canberra.pairwise([a.cpu().numpy(), ground_truth.cpu().numpy()])[0][1] == pytest.approx(np_sum)
37+
assert canberra.pairwise([a, ground_truth])[0][1] == pytest.approx(np_sum)
3938

40-
m.update((b, ground_truth))
41-
np_sum += ((torch.abs(ground_truth - b)) / (torch.abs(b) + torch.abs(ground_truth))).sum()
39+
m.update((torch.from_numpy(b), torch.from_numpy(ground_truth)))
40+
np_sum += ((np.abs(ground_truth - b)) / (np.abs(b) + np.abs(ground_truth))).sum()
4241
assert m.compute() == pytest.approx(np_sum)
4342
v1 = np.hstack([a, b])
4443
v2 = np.hstack([ground_truth, ground_truth])
4544
assert canberra.pairwise([v1, v2])[0][1] == pytest.approx(np_sum)
4645

47-
m.update((c, ground_truth))
48-
np_sum += ((torch.abs(ground_truth - c)) / (torch.abs(c) + torch.abs(ground_truth))).sum()
46+
m.update((torch.from_numpy(c), torch.from_numpy(ground_truth)))
47+
np_sum += ((np.abs(ground_truth - c)) / (np.abs(c) + np.abs(ground_truth))).sum()
4948
assert m.compute() == pytest.approx(np_sum)
5049
v1 = np.hstack([v1, c])
5150
v2 = np.hstack([v2, ground_truth])
5251
assert canberra.pairwise([v1, v2])[0][1] == pytest.approx(np_sum)
5352

54-
m.update((d, ground_truth))
55-
np_sum += (torch.abs(ground_truth - d) / (torch.abs(d) + torch.abs(ground_truth))).sum()
53+
m.update((torch.from_numpy(d), torch.from_numpy(ground_truth)))
54+
np_sum += (np.abs(ground_truth - d) / (np.abs(d) + np.abs(ground_truth))).sum()
5655
assert m.compute() == pytest.approx(np_sum)
5756
v1 = np.hstack([v1, d])
5857
v2 = np.hstack([v2, ground_truth])
5958
assert canberra.pairwise([v1, v2])[0][1] == pytest.approx(np_sum)
6059

6160

62-
@pytest.mark.parametrize("n_times", range(3))
63-
@pytest.mark.parametrize(
64-
"test_cases",
65-
[
66-
(torch.rand(size=(100,)), torch.rand(size=(100,)), 10),
67-
(torch.rand(size=(100, 1)), torch.rand(size=(100, 1)), 20),
68-
],
69-
)
70-
def test_integration(n_times, test_cases, available_device):
71-
y_pred, y, batch_size = test_cases
61+
def test_integration():
62+
def _test(y_pred, y, batch_size):
63+
def update_fn(engine, batch):
64+
idx = (engine.state.iteration - 1) * batch_size
65+
y_true_batch = np_y[idx : idx + batch_size]
66+
y_pred_batch = np_y_pred[idx : idx + batch_size]
67+
return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)
7268

73-
def update_fn(engine, batch):
74-
idx = (engine.state.iteration - 1) * batch_size
75-
y_true_batch = y[idx : idx + batch_size]
76-
y_pred_batch = y_pred[idx : idx + batch_size]
77-
return y_pred_batch, y_true_batch
69+
engine = Engine(update_fn)
7870

79-
engine = Engine(update_fn)
71+
m = CanberraMetric()
72+
m.attach(engine, "cm")
8073

81-
m = CanberraMetric(device=available_device)
82-
assert m._device == torch.device(available_device)
74+
np_y = y.numpy().ravel()
75+
np_y_pred = y_pred.numpy().ravel()
8376

84-
m.attach(engine, "cm")
77+
canberra = DistanceMetric.get_metric("canberra")
8578

86-
canberra = DistanceMetric.get_metric("canberra")
79+
data = list(range(y_pred.shape[0] // batch_size))
80+
cm = engine.run(data, max_epochs=1).metrics["cm"]
8781

88-
data = list(range(y_pred.shape[0] // batch_size))
89-
cm = engine.run(data, max_epochs=1).metrics["cm"]
82+
assert canberra.pairwise([np_y_pred, np_y])[0][1] == pytest.approx(cm)
9083

91-
pred_np = y_pred.cpu().numpy().reshape(len(y_pred), -1)
92-
true_np = y.cpu().numpy().reshape(len(y), -1)
93-
expected = np.sum(canberra.pairwise(pred_np, true_np).diagonal())
94-
assert expected == pytest.approx(cm)
84+
def get_test_cases():
85+
test_cases = [
86+
(torch.rand(size=(100,)), torch.rand(size=(100,)), 10),
87+
(torch.rand(size=(100, 1)), torch.rand(size=(100, 1)), 20),
88+
]
89+
return test_cases
9590

91+
for _ in range(5):
92+
# check multiple random inputs as random exact occurrences are rare
93+
test_cases = get_test_cases()
94+
for y_pred, y, batch_size in test_cases:
95+
_test(y_pred, y, batch_size)
9696

97-
def test_error_is_not_nan(available_device):
98-
m = CanberraMetric(device=available_device)
99-
assert m._device == torch.device(available_device)
97+
98+
def test_error_is_not_nan():
99+
m = CanberraMetric()
100100
m.update((torch.zeros(4), torch.zeros(4)))
101101
assert not (torch.isnan(m._sum_of_errors).any() or torch.isinf(m._sum_of_errors).any()), m._sum_of_errors
102102

tests/ignite/metrics/regression/test_fractional_absolute_error.py

Lines changed: 56 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -28,65 +28,77 @@ def test_wrong_input_shapes():
2828
m.update((torch.rand(4, 1), torch.rand(4)))
2929

3030

31-
def test_compute(available_device):
32-
a = torch.randn(4)
33-
b = torch.randn(4)
34-
c = torch.randn(4)
35-
d = torch.randn(4)
36-
ground_truth = torch.randn(4)
31+
def test_compute():
32+
a = np.random.randn(4)
33+
b = np.random.randn(4)
34+
c = np.random.randn(4)
35+
d = np.random.randn(4)
36+
ground_truth = np.random.randn(4)
3737

38-
m = FractionalAbsoluteError(device=available_device)
39-
assert m._device == torch.device(available_device)
38+
m = FractionalAbsoluteError()
4039

41-
total_error = 0.0
42-
total_len = 0
40+
m.update((torch.from_numpy(a), torch.from_numpy(ground_truth)))
41+
np_sum = (2 * np.abs((a - ground_truth)) / (np.abs(a) + np.abs(ground_truth))).sum()
42+
np_len = len(a)
43+
np_ans = np_sum / np_len
44+
assert m.compute() == pytest.approx(np_ans)
4345

44-
for pred in [a, b, c, d]:
45-
m.update((pred, ground_truth))
46+
m.update((torch.from_numpy(b), torch.from_numpy(ground_truth)))
47+
np_sum += (2 * np.abs((b - ground_truth)) / (np.abs(b) + np.abs(ground_truth))).sum()
48+
np_len += len(b)
49+
np_ans = np_sum / np_len
50+
assert m.compute() == pytest.approx(np_ans)
4651

47-
# Compute fractional absolute error in PyTorch
48-
error = 2 * torch.abs(pred - ground_truth) / (torch.abs(pred) + torch.abs(ground_truth))
49-
total_error += error.sum().item()
50-
total_len += len(pred)
52+
m.update((torch.from_numpy(c), torch.from_numpy(ground_truth)))
53+
np_sum += (2 * np.abs((c - ground_truth)) / (np.abs(c) + np.abs(ground_truth))).sum()
54+
np_len += len(c)
55+
np_ans = np_sum / np_len
56+
assert m.compute() == pytest.approx(np_ans)
5157

52-
expected = total_error / total_len
53-
assert m.compute() == pytest.approx(expected)
58+
m.update((torch.from_numpy(d), torch.from_numpy(ground_truth)))
59+
np_sum += (2 * np.abs((d - ground_truth)) / (np.abs(d) + np.abs(ground_truth))).sum()
60+
np_len += len(d)
61+
np_ans = np_sum / np_len
62+
assert m.compute() == pytest.approx(np_ans)
5463

5564

56-
@pytest.mark.parametrize("n_times", range(5))
57-
@pytest.mark.parametrize(
58-
"test_cases",
59-
[
60-
(torch.rand(size=(100,)), torch.rand(size=(100,)), 10),
61-
(torch.rand(size=(100, 1)), torch.rand(size=(100, 1)), 20),
62-
],
63-
)
64-
def test_integration_fractional_absolute_error(n_times, test_cases, available_device):
65-
y_pred, y, batch_size = test_cases
65+
def test_integration():
66+
def _test(y_pred, y, batch_size):
67+
def update_fn(engine, batch):
68+
idx = (engine.state.iteration - 1) * batch_size
69+
y_true_batch = np_y[idx : idx + batch_size]
70+
y_pred_batch = np_y_pred[idx : idx + batch_size]
71+
return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)
6672

67-
np_y = y.numpy().ravel()
68-
np_y_pred = y_pred.numpy().ravel()
73+
engine = Engine(update_fn)
6974

70-
def update_fn(engine, batch):
71-
idx = (engine.state.iteration - 1) * batch_size
72-
y_true_batch = y[idx : idx + batch_size]
73-
y_pred_batch = y_pred[idx : idx + batch_size]
74-
return y_pred_batch, y_true_batch
75+
m = FractionalAbsoluteError()
76+
m.attach(engine, "fab")
7577

76-
engine = Engine(update_fn)
78+
np_y = y.numpy().ravel()
79+
np_y_pred = y_pred.numpy().ravel()
7780

78-
metric = FractionalAbsoluteError(device=available_device)
79-
assert metric._device == torch.device(available_device)
81+
data = list(range(y_pred.shape[0] // batch_size))
82+
fab = engine.run(data, max_epochs=1).metrics["fab"]
8083

81-
metric.attach(engine, "fab")
84+
np_sum = (2 * np.abs((np_y_pred - np_y)) / (np.abs(np_y_pred) + np.abs(np_y))).sum()
85+
np_len = len(y_pred)
86+
np_ans = np_sum / np_len
8287

83-
data = list(range(y_pred.shape[0] // batch_size))
84-
fab = engine.run(data, max_epochs=1).metrics["fab"]
88+
assert np_ans == pytest.approx(fab)
8589

86-
np_sum = (2 * np.abs(np_y_pred - np_y) / (np.abs(np_y_pred) + np.abs(np_y))).sum()
87-
expected = np_sum / len(np_y)
90+
def get_test_cases():
91+
test_cases = [
92+
(torch.rand(size=(100,)), torch.rand(size=(100,)), 10),
93+
(torch.rand(size=(100, 1)), torch.rand(size=(100, 1)), 20),
94+
]
95+
return test_cases
8896

89-
assert expected == pytest.approx(fab)
97+
for _ in range(5):
98+
# check multiple random inputs as random exact occurrences are rare
99+
test_cases = get_test_cases()
100+
for y_pred, y, batch_size in test_cases:
101+
_test(y_pred, y, batch_size)
90102

91103

92104
def _test_distrib_compute(device):

0 commit comments

Comments
 (0)