diff --git a/tests/ignite/metrics/test_epoch_metric.py b/tests/ignite/metrics/test_epoch_metric.py
index 5c42957cf57d..5bbb2e2307cc 100644
--- a/tests/ignite/metrics/test_epoch_metric.py
+++ b/tests/ignite/metrics/test_epoch_metric.py
@@ -51,11 +51,12 @@ def compute_fn(y_preds, y_targets):
         em.compute()
 
 
-def test_epoch_metric():
+def test_epoch_metric(available_device):
     def compute_fn(y_preds, y_targets):
         return 0.0
 
-    em = EpochMetric(compute_fn)
+    em = EpochMetric(compute_fn, device=available_device)
+    assert em._device == torch.device(available_device)
     em.reset()
 
     output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
@@ -63,11 +64,11 @@ def compute_fn(y_preds, y_targets):
     output2 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
     em.update(output2)
 
-    assert all([t.device.type == "cpu" for t in em._predictions + em._targets])
-    assert torch.equal(em._predictions[0], output1[0])
-    assert torch.equal(em._predictions[1], output2[0])
-    assert torch.equal(em._targets[0], output1[1])
-    assert torch.equal(em._targets[1], output2[1])
+    assert all([t.device.type == available_device for t in em._predictions + em._targets])
+    assert torch.equal(em._predictions[0].cpu(), output1[0].cpu())
+    assert torch.equal(em._predictions[1].cpu(), output2[0].cpu())
+    assert torch.equal(em._targets[0].cpu(), output1[1].cpu())
+    assert torch.equal(em._targets[1].cpu(), output2[1].cpu())
     assert em.compute() == 0.0
 
     # test when y and y_pred are (batch_size, 1) that are squeezed to (batch_size, )
@@ -77,19 +78,20 @@ def compute_fn(y_preds, y_targets):
     output2 = (torch.rand(4, 1), torch.randint(0, 2, size=(4, 1), dtype=torch.long))
     em.update(output2)
 
-    assert all([t.device.type == "cpu" for t in em._predictions + em._targets])
-    assert torch.equal(em._predictions[0], output1[0][:, 0])
-    assert torch.equal(em._predictions[1], output2[0][:, 0])
-    assert torch.equal(em._targets[0], output1[1][:, 0])
-    assert torch.equal(em._targets[1], output2[1][:, 0])
+    assert all([t.device.type == available_device for t in em._predictions + em._targets])
+    assert torch.equal(em._predictions[0].cpu(), output1[0][:, 0].cpu())
+    assert torch.equal(em._predictions[1].cpu(), output2[0][:, 0].cpu())
+    assert torch.equal(em._targets[0].cpu(), output1[1][:, 0].cpu())
+    assert torch.equal(em._targets[1].cpu(), output2[1][:, 0].cpu())
     assert em.compute() == 0.0
 
 
-def test_mse_epoch_metric():
+def test_mse_epoch_metric(available_device):
     def compute_fn(y_preds, y_targets):
         return torch.mean(((y_preds - y_targets.type_as(y_preds)) ** 2)).item()
 
-    em = EpochMetric(compute_fn)
+    em = EpochMetric(compute_fn, device=available_device)
+    assert em._device == torch.device(available_device)
     em.reset()
 
     output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
@@ -103,7 +105,7 @@ def compute_fn(y_preds, y_targets):
     targets = torch.cat([output1[1], output2[1], output3[1]], dim=0)
 
     result = em.compute()
-    assert result == compute_fn(preds, targets)
+    assert result == pytest.approx(compute_fn(preds, targets), rel=1e-6)
 
     em.reset()
     output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
@@ -117,7 +119,7 @@ def compute_fn(y_preds, y_targets):
     targets = torch.cat([output1[1], output2[1], output3[1]], dim=0)
 
     result = em.compute()
-    assert result == compute_fn(preds, targets)
+    assert result == pytest.approx(compute_fn(preds, targets), rel=1e-6)
 
 
 def test_bad_compute_fn():
@@ -135,18 +137,20 @@ def compute_fn(y_preds, y_targets):
         em.update(output1)
 
 
-def test_check_compute_fn():
+def test_check_compute_fn(available_device):
     def compute_fn(y_preds, y_targets):
         raise Exception
 
-    em = EpochMetric(compute_fn, check_compute_fn=True)
+    em = EpochMetric(compute_fn, check_compute_fn=True, device=available_device)
+    assert em._device == torch.device(available_device)
     em.reset()
 
     output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
     with pytest.warns(EpochMetricWarning, match=r"Probably, there can be a problem with `compute_fn`"):
         em.update(output1)
 
-    em = EpochMetric(compute_fn, check_compute_fn=False)
+    em = EpochMetric(compute_fn, check_compute_fn=False, device=available_device)
+    assert em._device == torch.device(available_device)
     em.update(output1)
 
 
@@ -188,11 +192,12 @@ def assert_data_fn(all_preds, all_targets):
     assert ep_metric.compute() == ep_metric_true
 
 
-def test_skip_unrolling():
+def test_skip_unrolling(available_device):
     def compute_fn(y_preds, y_targets):
         return 0.0
 
-    em = EpochMetric(compute_fn, skip_unrolling=True)
+    em = EpochMetric(compute_fn, skip_unrolling=True, device=available_device)
+    assert em._device == torch.device(available_device)
     em.reset()
 
     output1 = (torch.rand(4, 2), torch.randint(0, 2, size=(4, 2), dtype=torch.long))
@@ -200,9 +205,9 @@ def compute_fn(y_preds, y_targets):
     output2 = (torch.rand(4, 2), torch.randint(0, 2, size=(4, 2), dtype=torch.long))
     em.update(output2)
 
-    assert all([t.device.type == "cpu" for t in em._predictions + em._targets])
-    assert torch.equal(em._predictions[0], output1[0])
-    assert torch.equal(em._predictions[1], output2[0])
-    assert torch.equal(em._targets[0], output1[1])
-    assert torch.equal(em._targets[1], output2[1])
+    assert all([t.device.type == available_device for t in em._predictions + em._targets])
+    assert torch.equal(em._predictions[0].cpu(), output1[0].cpu())
+    assert torch.equal(em._predictions[1].cpu(), output2[0].cpu())
+    assert torch.equal(em._targets[0].cpu(), output1[1].cpu())
+    assert torch.equal(em._targets[1].cpu(), output2[1].cpu())
     assert em.compute() == 0.0
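
Note: every test touched by this diff gains an `available_device` parameter, which is a pytest fixture supplied by the test suite's conftest rather than anything defined in this file. The fixture's implementation is not part of this diff; the following is only a minimal sketch of what such a fixture could look like (the parametrization list and skip logic are assumptions, not code from the ignite repository):

import pytest
import torch


# Hypothetical sketch of an `available_device` fixture: parametrize over
# candidate backends so each test that accepts the fixture runs once per
# device, skipping backends the current machine cannot use.
@pytest.fixture(params=["cpu", "cuda", "mps"])
def available_device(request):
    device = request.param
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    if device == "mps" and not (
        hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    ):
        pytest.skip("MPS is not available")
    return device

Because a fixture of this shape returns a plain string such as "cpu" or "cuda", both `EpochMetric(compute_fn, device=available_device)` and the `t.device.type == available_device` comparisons in the diff work with it unchanged.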