diff --git a/tests/ignite/metrics/test_average_precision.py b/tests/ignite/metrics/test_average_precision.py
index b84f578111e5..aca1cd7474b2 100644
--- a/tests/ignite/metrics/test_average_precision.py
+++ b/tests/ignite/metrics/test_average_precision.py
@@ -82,9 +82,9 @@ def test_data_binary_and_multilabel(request):
 
 
 @pytest.mark.parametrize("n_times", range(5))
-def test_binary_and_multilabel_inputs(n_times, test_data_binary_and_multilabel):
+def test_binary_and_multilabel_inputs(n_times, available_device, test_data_binary_and_multilabel):
     y_pred, y, batch_size = test_data_binary_and_multilabel
-    ap = AveragePrecision()
+    ap = AveragePrecision(device=available_device)
     ap.reset()
     if batch_size > 1:
         n_iters = y.shape[0] // batch_size + 1
@@ -115,7 +115,9 @@ def test_data_integration_binary_and_multilabel(request):
 
 
 @pytest.mark.parametrize("n_times", range(5))
-def test_integration_binary_and_mulitlabel_inputs(n_times, test_data_integration_binary_and_multilabel):
+def test_integration_binary_and_mulitlabel_inputs(
+    n_times, available_device, test_data_integration_binary_and_multilabel
+):
     y_pred, y, batch_size = test_data_integration_binary_and_multilabel
 
     def update_fn(engine, batch):
@@ -126,7 +128,7 @@ def update_fn(engine, batch):
 
     engine = Engine(update_fn)
 
-    ap_metric = AveragePrecision()
+    ap_metric = AveragePrecision(device=available_device)
     ap_metric.attach(engine, "ap")
 
     np_y = y.numpy()
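
For context, the diff assumes an available_device pytest fixture that parametrizes each test over the devices present on the test machine, so AveragePrecision gets exercised on GPU as well as CPU where possible. Below is a minimal sketch of what such a fixture could look like; this is an illustrative assumption, not the actual fixture (in pytorch/ignite it is defined in the test suite's shared conftest):

# Hypothetical sketch of an "available_device"-style fixture; the real
# fixture lives in ignite's shared test conftest and may differ.
import pytest
import torch


@pytest.fixture(params=["cpu", "cuda"])
def available_device(request):
    # Skip the CUDA case on machines without a GPU so the
    # parametrized tests still pass on CPU-only runners.
    if request.param == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    return request.param

Parametrizing at the fixture level keeps the device loop out of each test body: every test that accepts available_device runs once per device, and unavailable devices are skipped rather than failed.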