diff --git a/tests/ignite/metrics/test_mean_absolute_error.py b/tests/ignite/metrics/test_mean_absolute_error.py
index ab9b0da8810f..6e945cae198b 100644
--- a/tests/ignite/metrics/test_mean_absolute_error.py
+++ b/tests/ignite/metrics/test_mean_absolute_error.py
@@ -29,8 +29,9 @@ def test_case(request):
 
 
 @pytest.mark.parametrize("n_times", range(5))
-def test_compute(n_times, test_case):
-    mae = MeanAbsoluteError()
+def test_compute(n_times, test_case, available_device):
+    mae = MeanAbsoluteError(device=available_device)
+    assert mae._device == torch.device(available_device)
 
     y_pred, y, batch_size = test_case
 
diff --git a/tests/ignite/metrics/test_mean_pairwise_distance.py b/tests/ignite/metrics/test_mean_pairwise_distance.py
index 0a53f48193ea..6b38883e3039 100644
--- a/tests/ignite/metrics/test_mean_pairwise_distance.py
+++ b/tests/ignite/metrics/test_mean_pairwise_distance.py
@@ -29,8 +29,9 @@ def test_case(request):
 
 
 @pytest.mark.parametrize("n_times", range(5))
-def test_compute(n_times, test_case):
-    mpd = MeanPairwiseDistance()
+def test_compute(n_times, test_case, available_device):
+    mpd = MeanPairwiseDistance(device=available_device)
+    assert mpd._device == torch.device(available_device)
 
     y_pred, y, batch_size = test_case
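
Both hunks make the same change: `test_compute` now takes the suite's `available_device` fixture, constructs the metric on that device, and asserts the metric's internal `_device` matches. For reference, a minimal sketch of what such a fixture could look like, assuming it simply probes for CUDA; the fixture name `available_device` comes from the diff, but this body is hypothetical and the suite's real conftest may resolve devices differently (e.g. cover more backends):

```python
import pytest
import torch


@pytest.fixture
def available_device():
    # Hypothetical sketch: resolve to CUDA when a GPU is visible,
    # otherwise fall back to CPU.
    return "cuda" if torch.cuda.is_available() else "cpu"
```

Because the fixture returns a device string, the tests run unchanged on CPU-only machines and exercise the GPU path automatically when one is present.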