Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions tests/ignite/metrics/test_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,8 +430,6 @@ def _test_distrib_integration_multiclass(device):

from ignite.engine import Engine

rank = idist.get_rank()

def _test(average, n_epochs, metric_device):
n_iters = 60
batch_size = 16
Expand Down Expand Up @@ -477,6 +475,7 @@ def update(engine, i):
if device.type != "xla":
metric_devices.append(idist.device())
for i in range(2):
rank = idist.get_rank()
torch.manual_seed(12 + rank + i)
for metric_device in metric_devices:
_test(average=False, n_epochs=1, metric_device=metric_device)
Expand All @@ -493,7 +492,6 @@ def _test_distrib_integration_multilabel(device):

from ignite.engine import Engine

rank = idist.get_rank()
torch.manual_seed(12)

def _test(average, n_epochs, metric_device):
Expand Down Expand Up @@ -545,6 +543,7 @@ def update(engine, i):
if device.type != "xla":
metric_devices.append(idist.device())
for i in range(2):
rank = idist.get_rank()
torch.manual_seed(12 + rank + i)
for metric_device in metric_devices:
_test(average=False, n_epochs=1, metric_device=metric_device)
Expand Down