61 changes: 55 additions & 6 deletions tests/ignite/conftest.py
@@ -106,7 +106,15 @@ def _create_dist_context(dist_info, lrank):

 def _destroy_dist_context():
 
+    if dist.get_rank() == 0:
+        # To support Python 3.7; Otherwise we could do `.unlink(missing_ok=True)`
+        try:
+            Path("/tmp/free_port").unlink()
+        except FileNotFoundError:
+            pass
+
+    dist.barrier()
 
     dist.destroy_process_group()
 
     from ignite.distributed.utils import _SerialModel, _set_model
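Note (not part of the diff): the comment above keeps the try/except only for Python 3.7 compatibility. On Python 3.8+ the rank-0 cleanup could be collapsed to a single call, roughly:

    from pathlib import Path

    if dist.get_rank() == 0:
        # Python >= 3.8: missing_ok=True swallows FileNotFoundError for us
        Path("/tmp/free_port").unlink(missing_ok=True)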
@@ -360,7 +368,6 @@ def gloo_hvd_executor():


 @pytest.fixture(
-    scope="module",
     params=[
         pytest.param("nccl", marks=[pytest.mark.distributed, skip_if_has_not_native_dist_support, skip_if_no_gpu]),
         pytest.param("gloo_cpu", marks=[pytest.mark.distributed, skip_if_has_not_native_dist_support]),
@@ -430,17 +437,59 @@ def distributed(request, local_rank, world_size):
     elif request.param in ("single_device_xla", "xla_nprocs"):
         request.node.stash[is_xla_stash_key] = True
         request.node.stash[is_xla_single_device_stash_key] = request.param == "single_device_xla"
-        yield None
+        yield {"xla_index": -1} if request.param == "xla_nprocs" else None
     else:
         raise RuntimeError(f"Invalid parameter value for `distributed` fixture, given {request.param}")


 @pytest.hookimpl
 def pytest_pyfunc_call(pyfuncitem: pytest.Function) -> None:
     if pyfuncitem.stash.get(is_horovod_stash_key, False):
-        nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()
-        pyfuncitem.obj = functools.partial(_gloo_hvd_execute, pyfuncitem.obj, (), np=nproc)
+
+        def testfunc_wrapper(test_func, **kwargs):
+            def hvd_worker():
+                import horovod.torch as hvd
+
+                hvd.init()
+                lrank = hvd.local_rank()
+                if torch.cuda.is_available():
+                    torch.cuda.set_device(lrank)
+
+                test_func(**kwargs)
+
+                hvd.shutdown()
+
+            try:
+                # old API
+                from horovod.run.runner import run
+            except ImportError:
+                # new API: https://github.com/horovod/horovod/pull/2099
+                from horovod import run
+
+            nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()
+            hvd_kwargs = dict(use_gloo=True, num_proc=nproc)
+            run(hvd_worker, **hvd_kwargs)
+
+        pyfuncitem.obj = functools.partial(testfunc_wrapper, pyfuncitem.obj)

     elif pyfuncitem.stash.get(is_xla_stash_key, False) and not pyfuncitem.stash[is_xla_single_device_stash_key]:
-        n = int(os.environ["NUM_TPU_WORKERS"])
-        pyfuncitem.obj = functools.partial(_xla_execute, pyfuncitem.obj, (), n)
+
+        def testfunc_wrapper(testfunc, **kwargs):
+            def xla_worker(index, fn):
+                import torch_xla.core.xla_model as xm
+
+                kwargs["distributed"]["xla_index"] = index
+                xm.rendezvous("init")
+                fn(**kwargs)
+
+            import torch_xla.distributed.xla_multiprocessing as xmp
+
+            spawn_kwargs = {"nprocs": int(os.environ["NUM_TPU_WORKERS"])}
+            if "COLAB_TPU_ADDR" in os.environ:
+                spawn_kwargs["start_method"] = "fork"
+            try:
+                xmp.spawn(xla_worker, args=(testfunc,), **spawn_kwargs)
+            except SystemExit as ex_:
+                assert ex_.code == 0, "Didn't successfully exit in XLA test"
+
+        pyfuncitem.obj = functools.partial(testfunc_wrapper, pyfuncitem.obj)
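Note (not part of the diff): a minimal sketch of how a test might consume the reworked `distributed` fixture; the test name and assertions are illustrative, not taken from this PR. Under the `xla_nprocs` parameterization the fixture yields `{"xla_index": -1}` and `xla_worker` overwrites that value with the per-process index before the test body runs; the native backends yield `None`.

    import ignite.distributed as idist

    def test_illustrative_usage(distributed):
        # Hypothetical test: the fixture (or the wrapped worker spawned by
        # pytest_pyfunc_call) has already set up the backend, so the
        # ignite.distributed helpers reflect the active context.
        assert idist.get_world_size() >= 1
        if distributed is not None:
            # Only populated under the "xla_nprocs" parameterization.
            assert distributed["xla_index"] >= 0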