diff --git a/tests/ignite/metrics/test_cohen_kappa.py b/tests/ignite/metrics/test_cohen_kappa.py index beb25585dc75..1649c21668e4 100644 --- a/tests/ignite/metrics/test_cohen_kappa.py +++ b/tests/ignite/metrics/test_cohen_kappa.py @@ -85,9 +85,11 @@ def test_data_binary(request): @pytest.mark.parametrize("n_times", range(5)) @pytest.mark.parametrize("weights", [None, "linear", "quadratic"]) -def test_binary_input(n_times, weights, test_data_binary): +def test_binary_input(n_times, weights, test_data_binary, available_device): y_pred, y, batch_size = test_data_binary - ck = CohenKappa(weights) + ck = CohenKappa(weights=weights, device=available_device) + assert ck._device == torch.device(available_device) + ck.reset() if batch_size > 1: n_iters = y.shape[0] // batch_size + 1 @@ -135,7 +137,7 @@ def test_data_integration_binary(request): @pytest.mark.parametrize("n_times", range(5)) @pytest.mark.parametrize("weights", [None, "linear", "quadratic"]) -def test_integration_binary_input(n_times, weights, test_data_integration_binary): +def test_integration_binary_input(n_times, weights, test_data_integration_binary, available_device): y_pred, y, batch_size = test_data_integration_binary def update_fn(engine, batch): @@ -146,7 +148,9 @@ def update_fn(engine, batch): engine = Engine(update_fn) - ck_metric = CohenKappa(weights=weights) + ck_metric = CohenKappa(weights=weights, device=available_device) + assert ck_metric._device == torch.device(available_device) + ck_metric.attach(engine, "ck") np_y = y.numpy()