Skip to content
Merged
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 62 additions & 27 deletions tests/ignite/metrics/test_epoch_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,73 +51,98 @@ def compute_fn(y_preds, y_targets):
em.compute()


def test_epoch_metric():
def test_epoch_metric(available_device):
def compute_fn(y_preds, y_targets):
return 0.0

em = EpochMetric(compute_fn)
device = torch.device(available_device)
em = EpochMetric(compute_fn, device=device)
assert em._device == device

em.reset()
output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
output1 = (torch.rand(4, 3, device=device), torch.randint(0, 2, size=(4, 3), dtype=torch.long, device=device))
em.update(output1)
output2 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))

output2 = (torch.rand(4, 3, device=device), torch.randint(0, 2, size=(4, 3), dtype=torch.long, device=device))
em.update(output2)

assert all([t.device.type == "cpu" for t in em._predictions + em._targets])
if available_device == "cpu":
assert all([t.device.type == "cpu" for t in em._predictions + em._targets])
assert torch.equal(em._predictions[0], output1[0])
assert torch.equal(em._predictions[1], output2[0])
assert torch.equal(em._targets[0], output1[1])
assert torch.equal(em._targets[1], output2[1])
assert em.compute() == 0.0

# test when y and y_pred are (batch_size, 1) that are squeezed to (batch_size, )
# test when y and y_pred are (batch_size, 1) that are squeezed to (batch_size,)
em.reset()
output1 = (torch.rand(4, 1), torch.randint(0, 2, size=(4, 1), dtype=torch.long))
output1 = (torch.rand(4, 1, device=device), torch.randint(0, 2, size=(4, 1), dtype=torch.long, device=device))
em.update(output1)
output2 = (torch.rand(4, 1), torch.randint(0, 2, size=(4, 1), dtype=torch.long))

output2 = (torch.rand(4, 1, device=device), torch.randint(0, 2, size=(4, 1), dtype=torch.long, device=device))
em.update(output2)

assert all([t.device.type == "cpu" for t in em._predictions + em._targets])
if available_device == "cpu":
assert all([t.device.type == "cpu" for t in em._predictions + em._targets])
assert torch.equal(em._predictions[0], output1[0][:, 0])
assert torch.equal(em._predictions[1], output2[0][:, 0])
assert torch.equal(em._targets[0], output1[1][:, 0])
assert torch.equal(em._targets[1], output2[1][:, 0])
assert em.compute() == 0.0


def test_mse_epoch_metric():
def test_mse_epoch_metric(available_device):
def compute_fn(y_preds, y_targets):
return torch.mean(((y_preds - y_targets.type_as(y_preds)) ** 2)).item()

em = EpochMetric(compute_fn)
em = EpochMetric(compute_fn, device=available_device)
assert em._device == torch.device(available_device)

em.reset()
output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
output1 = (
torch.rand(4, 3, device=available_device),
torch.randint(0, 2, size=(4, 3), dtype=torch.long, device=available_device),
)
em.update(output1)
output2 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
output2 = (
torch.rand(4, 3, device=available_device),
torch.randint(0, 2, size=(4, 3), dtype=torch.long, device=available_device),
)
em.update(output2)
output3 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
output3 = (
torch.rand(4, 3, device=available_device),
torch.randint(0, 2, size=(4, 3), dtype=torch.long, device=available_device),
)
em.update(output3)

preds = torch.cat([output1[0], output2[0], output3[0]], dim=0)
targets = torch.cat([output1[1], output2[1], output3[1]], dim=0)

result = em.compute()
assert result == compute_fn(preds, targets)
assert result == pytest.approx(compute_fn(preds, targets), rel=1e-6)

em.reset()
output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
output1 = (
torch.rand(4, 3, device=available_device),
torch.randint(0, 2, size=(4, 3), dtype=torch.long, device=available_device),
)
em.update(output1)
output2 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
output2 = (
torch.rand(4, 3, device=available_device),
torch.randint(0, 2, size=(4, 3), dtype=torch.long, device=available_device),
)
em.update(output2)
output3 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
output3 = (
torch.rand(4, 3, device=available_device),
torch.randint(0, 2, size=(4, 3), dtype=torch.long, device=available_device),
)
em.update(output3)

preds = torch.cat([output1[0], output2[0], output3[0]], dim=0)
targets = torch.cat([output1[1], output2[1], output3[1]], dim=0)

result = em.compute()
assert result == compute_fn(preds, targets)
assert result == pytest.approx(compute_fn(preds, targets), rel=1e-6)


def test_bad_compute_fn():
Expand All @@ -135,18 +160,20 @@ def compute_fn(y_preds, y_targets):
em.update(output1)


def test_check_compute_fn():
def test_check_compute_fn(available_device):
def compute_fn(y_preds, y_targets):
raise Exception

em = EpochMetric(compute_fn, check_compute_fn=True)
em = EpochMetric(compute_fn, check_compute_fn=True, device=available_device)
assert em._device == torch.device(available_device)

em.reset()
output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
with pytest.warns(EpochMetricWarning, match=r"Probably, there can be a problem with `compute_fn`"):
em.update(output1)

em = EpochMetric(compute_fn, check_compute_fn=False)
em = EpochMetric(compute_fn, check_compute_fn=False, device=available_device)
assert em._device == torch.device(available_device)
em.update(output1)


Expand Down Expand Up @@ -188,19 +215,27 @@ def assert_data_fn(all_preds, all_targets):
assert ep_metric.compute() == ep_metric_true


def test_skip_unrolling():
def test_skip_unrolling(available_device):
def compute_fn(y_preds, y_targets):
return 0.0

em = EpochMetric(compute_fn, skip_unrolling=True)
em = EpochMetric(compute_fn, skip_unrolling=True, device=available_device)
assert em._device == torch.device(available_device)

em.reset()
output1 = (torch.rand(4, 2), torch.randint(0, 2, size=(4, 2), dtype=torch.long))
output1 = (
torch.rand(4, 2, device=available_device),
torch.randint(0, 2, size=(4, 2), dtype=torch.long, device=available_device),
)
em.update(output1)
output2 = (torch.rand(4, 2), torch.randint(0, 2, size=(4, 2), dtype=torch.long))
output2 = (
torch.rand(4, 2, device=available_device),
torch.randint(0, 2, size=(4, 2), dtype=torch.long, device=available_device),
)
em.update(output2)

assert all([t.device.type == "cpu" for t in em._predictions + em._targets])
if available_device == "cpu":
assert all([t.device.type == "cpu" for t in em._predictions + em._targets])
assert torch.equal(em._predictions[0], output1[0])
assert torch.equal(em._predictions[1], output2[0])
assert torch.equal(em._targets[0], output1[1])
Expand Down
Loading