diff --git a/tests/ignite/metrics/test_epoch_metric.py b/tests/ignite/metrics/test_epoch_metric.py index 70296d484180..bb066d448513 100644 --- a/tests/ignite/metrics/test_epoch_metric.py +++ b/tests/ignite/metrics/test_epoch_metric.py @@ -159,34 +159,36 @@ def _test_distrib_integration(device=None): device = idist.device() if idist.device().type != "xla" else "cpu" rank = idist.get_rank() - torch.manual_seed(12) + torch.manual_seed(12 + rank) - n_iters = 60 - s = 16 + n_iters = 3 + batch_size = 2 n_classes = 7 - offset = n_iters * s - y_true = torch.randint(0, n_classes, size=(offset * idist.get_world_size(),), device=device) - y_preds = torch.rand(offset * idist.get_world_size(), n_classes, device=device) + y_true = torch.randint(0, n_classes, size=(n_iters * batch_size,), device=device) + y_preds = torch.rand(n_iters * batch_size, n_classes, device=device) def update(engine, i): return ( - y_preds[i * s + rank * offset : (i + 1) * s + rank * offset, :], - y_true[i * s + rank * offset : (i + 1) * s + rank * offset], + y_preds[i * batch_size : (i + 1) * batch_size, :], + y_true[i * batch_size : (i + 1) * batch_size], ) engine = Engine(update) def assert_data_fn(all_preds, all_targets): - assert all_preds.equal(y_preds), f"{all_preds.shape} vs {y_preds.shape}" - assert all_targets.equal(y_true), f"{all_targets.shape} vs {y_true.shape}" return (all_preds.argmax(dim=1) == all_targets).sum().item() ep_metric = EpochMetric(assert_data_fn, check_compute_fn=False, device=device) ep_metric.attach(engine, "epm") data = list(range(n_iters)) + engine.run(data=data, max_epochs=3) + + y_preds = idist.all_gather(y_preds) + y_true = idist.all_gather(y_true) + assert engine.state.metrics["epm"] == (y_preds.argmax(dim=1) == y_true).sum().item() diff --git a/tests/ignite/metrics/test_recall.py b/tests/ignite/metrics/test_recall.py index 860c2820db7f..94f4fc160ef2 100644 --- a/tests/ignite/metrics/test_recall.py +++ b/tests/ignite/metrics/test_recall.py @@ -430,22 +430,18 @@ def _test_distrib_integration_multiclass(device): from ignite.engine import Engine - rank = idist.get_rank() - torch.manual_seed(12) - def _test(average, n_epochs, metric_device): n_iters = 60 - s = 16 + batch_size = 16 n_classes = 7 - offset = n_iters * s - y_true = torch.randint(0, n_classes, size=(offset * idist.get_world_size(),)).to(device) - y_preds = torch.rand(offset * idist.get_world_size(), n_classes).to(device) + y_true = torch.randint(0, n_classes, size=(n_iters * batch_size,)).to(device) + y_preds = torch.rand(n_iters * batch_size, n_classes).to(device) def update(engine, i): return ( - y_preds[i * s + rank * offset : (i + 1) * s + rank * offset, :], - y_true[i * s + rank * offset : (i + 1) * s + rank * offset], + y_preds[i * batch_size : (i + 1) * batch_size, :], + y_true[i * batch_size : (i + 1) * batch_size], ) engine = Engine(update) @@ -457,6 +453,9 @@ def update(engine, i): data = list(range(n_iters)) engine.run(data=data, max_epochs=n_epochs) + y_preds = idist.all_gather(y_preds) + y_true = idist.all_gather(y_true) + assert "re" in engine.state.metrics assert re._updated is True res = engine.state.metrics["re"] @@ -475,7 +474,9 @@ def update(engine, i): metric_devices = [torch.device("cpu")] if device.type != "xla": metric_devices.append(idist.device()) - for _ in range(2): + rank = idist.get_rank() + for i in range(2): + torch.manual_seed(12 + rank + i) for metric_device in metric_devices: _test(average=False, n_epochs=1, metric_device=metric_device) _test(average=False, n_epochs=2, metric_device=metric_device) @@ -491,22 +492,20 @@ def _test_distrib_integration_multilabel(device): from ignite.engine import Engine - rank = idist.get_rank() torch.manual_seed(12) def _test(average, n_epochs, metric_device): n_iters = 60 - s = 16 + batch_size = 16 n_classes = 7 - offset = n_iters * s - y_true = torch.randint(0, 2, size=(offset * idist.get_world_size(), n_classes, 6, 8)).to(device) - y_preds = torch.randint(0, 2, size=(offset * idist.get_world_size(), n_classes, 6, 8)).to(device) + y_true = torch.randint(0, 2, size=(n_iters * batch_size, n_classes, 6, 8)).to(device) + y_preds = torch.randint(0, 2, size=(n_iters * batch_size, n_classes, 6, 8)).to(device) def update(engine, i): return ( - y_preds[i * s + rank * offset : (i + 1) * s + rank * offset, ...], - y_true[i * s + rank * offset : (i + 1) * s + rank * offset, ...], + y_preds[i * batch_size : (i + 1) * batch_size, ...], + y_true[i * batch_size : (i + 1) * batch_size, ...], ) engine = Engine(update) @@ -518,6 +517,9 @@ def update(engine, i): data = list(range(n_iters)) engine.run(data=data, max_epochs=n_epochs) + y_preds = idist.all_gather(y_preds) + y_true = idist.all_gather(y_true) + assert "re" in engine.state.metrics assert re._updated is True res = engine.state.metrics["re"] @@ -540,7 +542,9 @@ def update(engine, i): metric_devices = ["cpu"] if device.type != "xla": metric_devices.append(idist.device()) - for _ in range(2): + rank = idist.get_rank() + for i in range(2): + torch.manual_seed(12 + rank + i) for metric_device in metric_devices: _test(average=False, n_epochs=1, metric_device=metric_device) _test(average=False, n_epochs=2, metric_device=metric_device)