diff --git a/tests/ignite/metrics/test_mean_average_precision.py b/tests/ignite/metrics/test_mean_average_precision.py index f24f33abb9db..4f8009b8898a 100644 --- a/tests/ignite/metrics/test_mean_average_precision.py +++ b/tests/ignite/metrics/test_mean_average_precision.py @@ -166,7 +166,12 @@ def test_distrib_integration(distributed, data_type): world_size = idist.get_world_size() device = idist.device() - def _test(metric_device): + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(device) + + for metric_device in metric_devices: + def update(_, i): return ( y_preds[(2 * rank + i) * 10 : (2 * rank + i + 1) * 10], @@ -195,11 +200,5 @@ def update(_, i): y_true = y_true.transpose(1, -1).reshape(-1, 4) y_preds = y_preds.transpose(1, -1).reshape(-1, 4) - sklearn_mAP = average_precision_score(y_true.numpy(), y_preds.numpy()) + sklearn_mAP = average_precision_score(y_true.cpu().numpy(), y_preds.cpu().numpy()) assert np.allclose(sklearn_mAP, engine.state.metrics["mAP"]) - - metric_devices = [torch.device("cpu")] - if device.type != "xla": - metric_devices.append(idist.device()) - for metric_device in metric_devices: - _test(metric_device)