
Commit 79c4235

Fix some bugs in conftest.py, one in _destroy_dist_context and the other in distributed (#2788)
* Fix bug related to _destroy_dist_context
* Fix bug in new distributed fixture
* Fix a bug related to distributed fixture
* Fix a tiny bug in distributed fixture's XLA setting
* Attempt to fix a bug
* Update tests/ignite/conftest.py
* Update tests/ignite/conftest.py

Co-authored-by: vfdev <[email protected]>
1 parent e0a2879 commit 79c4235

1 file changed: +55, -6 lines


tests/ignite/conftest.py

Lines changed: 55 additions & 6 deletions
@@ -106,7 +106,15 @@ def _create_dist_context(dist_info, lrank):

 def _destroy_dist_context():

+    if dist.get_rank() == 0:
+        # To support Python 3.7; Otherwise we could do `.unlink(missing_ok=True)`
+        try:
+            Path("/tmp/free_port").unlink()
+        except FileNotFoundError:
+            pass
+
     dist.barrier()
+
     dist.destroy_process_group()

     from ignite.distributed.utils import _SerialModel, _set_model
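
For context on the added comment: on Python 3.8 and newer, pathlib's unlink accepts missing_ok, so the try/except above collapses to a single call. A minimal sketch of that 3.8+ equivalent (the commit itself keeps the try/except to preserve 3.7 support):

    from pathlib import Path

    # Python 3.8+ only: missing_ok=True swallows FileNotFoundError
    Path("/tmp/free_port").unlink(missing_ok=True)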
@@ -360,7 +368,6 @@ def gloo_hvd_executor():


 @pytest.fixture(
-    scope="module",
     params=[
         pytest.param("nccl", marks=[pytest.mark.distributed, skip_if_has_not_native_dist_support, skip_if_no_gpu]),
         pytest.param("gloo_cpu", marks=[pytest.mark.distributed, skip_if_has_not_native_dist_support]),
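
Dropping scope="module" makes the fixture fall back to pytest's default function scope, so its setup and teardown run once per test rather than once per module. A minimal sketch of what that default implies, using a hypothetical fixture name:

    import pytest

    # Hypothetical fixture for illustration: with no explicit scope argument,
    # pytest uses scope="function", so this setup/teardown pair runs for every test.
    @pytest.fixture(params=["gloo_cpu"])
    def dist_sketch(request):
        ctx = {"backend": request.param}  # per-test setup
        yield ctx
        ctx.clear()  # per-test teardown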
@@ -430,17 +437,59 @@ def distributed(request, local_rank, world_size):
     elif request.param in ("single_device_xla", "xla_nprocs"):
         request.node.stash[is_xla_stash_key] = True
         request.node.stash[is_xla_single_device_stash_key] = request.param == "single_device_xla"
-        yield None
+        yield {"xla_index": -1} if request.param == "xla_nprocs" else None
     else:
         raise RuntimeError(f"Invalid parameter value for `distributed` fixture, given {request.param}")


 @pytest.hookimpl
 def pytest_pyfunc_call(pyfuncitem: pytest.Function) -> None:
     if pyfuncitem.stash.get(is_horovod_stash_key, False):
-        nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()
-        pyfuncitem.obj = functools.partial(_gloo_hvd_execute, pyfuncitem.obj, (), np=nproc)
+
+        def testfunc_wrapper(test_func, **kwargs):
+            def hvd_worker():
+                import horovod.torch as hvd
+
+                hvd.init()
+                lrank = hvd.local_rank()
+                if torch.cuda.is_available():
+                    torch.cuda.set_device(lrank)
+
+                test_func(**kwargs)
+
+                hvd.shutdown()
+
+            try:
+                # old API
+                from horovod.run.runner import run
+            except ImportError:
+                # new API: https://github.com/horovod/horovod/pull/2099
+                from horovod import run
+
+            nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()
+            hvd_kwargs = dict(use_gloo=True, num_proc=nproc)
+            run(hvd_worker, **hvd_kwargs)
+
+        pyfuncitem.obj = functools.partial(testfunc_wrapper, pyfuncitem.obj)

     elif pyfuncitem.stash.get(is_xla_stash_key, False) and not pyfuncitem.stash[is_xla_single_device_stash_key]:
-        n = int(os.environ["NUM_TPU_WORKERS"])
-        pyfuncitem.obj = functools.partial(_xla_execute, pyfuncitem.obj, (), n)
+
+        def testfunc_wrapper(testfunc, **kwargs):
+            def xla_worker(index, fn):
+                import torch_xla.core.xla_model as xm
+
+                kwargs["distributed"]["xla_index"] = index
+                xm.rendezvous("init")
+                fn(**kwargs)
+
+            import torch_xla.distributed.xla_multiprocessing as xmp
+
+            spawn_kwargs = {"nprocs": int(os.environ["NUM_TPU_WORKERS"])}
+            if "COLAB_TPU_ADDR" in os.environ:
+                spawn_kwargs["start_method"] = "fork"
+            try:
+                xmp.spawn(xla_worker, args=(testfunc,), **spawn_kwargs)
+            except SystemExit as ex_:
+                assert ex_.code == 0, "Didn't successfully exit in XLA test"
+
+        pyfuncitem.obj = functools.partial(testfunc_wrapper, pyfuncitem.obj)
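
The new Horovod branch launches the wrapped test through Horovod's run entry point with a Gloo controller instead of the old _gloo_hvd_execute helper. A standalone sketch of the same launch pattern, assuming a Horovod build with PyTorch support whose run function accepts num_proc and use_gloo as in the conftest change above (the worker body and its print are illustrative):

    import torch

    def worker():
        import horovod.torch as hvd

        hvd.init()
        if torch.cuda.is_available():
            torch.cuda.set_device(hvd.local_rank())
        print(f"rank {hvd.rank()} of {hvd.size()}")
        hvd.shutdown()

    if __name__ == "__main__":
        from horovod import run  # newer API, per the PR linked in the diff

        nproc = torch.cuda.device_count() if torch.cuda.is_available() else 4
        run(worker, num_proc=nproc, use_gloo=True)  # Gloo controller, as in the conftest change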

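For the XLA path, the fixture now yields {"xla_index": -1} when parametrized with "xla_nprocs", and the spawn wrapper overwrites that placeholder with the per-process index before the test body runs on each worker. A hypothetical test consuming the fixture under that assumption (the test name and assertion are illustrative):

    def test_on_xla_workers(distributed):
        # Under "xla_nprocs" the fixture value is a dict whose placeholder index (-1)
        # has been replaced by the spawned process index inside xla_worker.
        if isinstance(distributed, dict):
            assert distributed["xla_index"] >= 0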