diff --git a/tests/ignite/metrics/test_hsic.py b/tests/ignite/metrics/test_hsic.py index 7eb4071b9a5a..c1b246eafdc9 100644 --- a/tests/ignite/metrics/test_hsic.py +++ b/tests/ignite/metrics/test_hsic.py @@ -87,10 +87,11 @@ def test_case(request) -> Tuple[Tensor, Tensor, int]: @pytest.mark.parametrize("n_times", range(3)) @pytest.mark.parametrize("sigma_x", [-1.0, 1.0]) @pytest.mark.parametrize("sigma_y", [-1.0, 1.0]) -def test_compute(n_times, sigma_x: float, sigma_y: float, test_case: Tuple[Tensor, Tensor, int]): +def test_compute(n_times, sigma_x: float, sigma_y: float, test_case: Tuple[Tensor, Tensor, int], available_device): x, y, batch_size = test_case - hsic = HSIC(sigma_x=sigma_x, sigma_y=sigma_y) + hsic = HSIC(sigma_x=sigma_x, sigma_y=sigma_y, device=available_device) + assert hsic._device == torch.device(available_device) hsic.reset() @@ -109,8 +110,9 @@ def test_compute(n_times, sigma_x: float, sigma_y: float, test_case: Tuple[Tenso assert pytest.approx(expected_hsic, abs=2e-5) == hsic.compute() -def test_accumulator_detached(): - hsic = HSIC() +def test_accumulator_detached(available_device): + hsic = HSIC(device=available_device) + assert hsic._device == torch.device(available_device) x = torch.rand(10, 10, dtype=torch.float) y = torch.rand(10, 10, dtype=torch.float)