@@ -106,7 +106,15 @@ def _create_dist_context(dist_info, lrank):
 
 def _destroy_dist_context():
 
+    if dist.get_rank() == 0:
+        # To support Python 3.7; otherwise we could do `.unlink(missing_ok=True)`
+        try:
+            Path("/tmp/free_port").unlink()
+        except FileNotFoundError:
+            pass
+
     dist.barrier()
+
     dist.destroy_process_group()
 
     from ignite.distributed.utils import _SerialModel, _set_model
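For readers wondering about the comment in the hunk above: `Path.unlink(missing_ok=True)` only appeared in Python 3.8, so on 3.7 the `FileNotFoundError` has to be swallowed by hand. A minimal standalone sketch of the same pattern (the helper name `_remove_if_exists` is made up for illustration, not part of the change):

from pathlib import Path


def _remove_if_exists(path: str) -> None:
    # On Python 3.8+ this is simply: Path(path).unlink(missing_ok=True)
    # Python 3.7-compatible fallback, as used in the hunk above:
    try:
        Path(path).unlink()
    except FileNotFoundError:
        pass  # nothing to clean up
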
@@ -360,7 +368,6 @@ def gloo_hvd_executor():
 
 
 @pytest.fixture(
-    scope="module",
     params=[
         pytest.param("nccl", marks=[pytest.mark.distributed, skip_if_has_not_native_dist_support, skip_if_no_gpu]),
         pytest.param("gloo_cpu", marks=[pytest.mark.distributed, skip_if_has_not_native_dist_support]),
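A short aside on the dropped `scope="module"`: with no explicit `scope`, pytest falls back to its default function scope, so the parametrized fixture is now set up and torn down for each requesting test rather than once per module. A minimal sketch (the fixture name `backend` is hypothetical, not from this diff):

import pytest


@pytest.fixture(params=["nccl", "gloo_cpu"])  # no scope= -> pytest default scope="function"
def backend(request):
    # Setup/teardown now runs once per requesting test and per param,
    # instead of once per module as with the removed scope="module".
    yield request.param
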
@@ -430,17 +437,59 @@ def distributed(request, local_rank, world_size):
     elif request.param in ("single_device_xla", "xla_nprocs"):
         request.node.stash[is_xla_stash_key] = True
         request.node.stash[is_xla_single_device_stash_key] = request.param == "single_device_xla"
-        yield None
+        yield {"xla_index": -1} if request.param == "xla_nprocs" else None
     else:
         raise RuntimeError(f"Invalid parameter value for `distributed` fixture, given {request.param}")
 
 
 @pytest.hookimpl
 def pytest_pyfunc_call(pyfuncitem: pytest.Function) -> None:
     if pyfuncitem.stash.get(is_horovod_stash_key, False):
-        nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()
-        pyfuncitem.obj = functools.partial(_gloo_hvd_execute, pyfuncitem.obj, (), np=nproc)
+
+        def testfunc_wrapper(test_func, **kwargs):
+            def hvd_worker():
+                import horovod.torch as hvd
+
+                hvd.init()
+                lrank = hvd.local_rank()
+                if torch.cuda.is_available():
+                    torch.cuda.set_device(lrank)
+
+                test_func(**kwargs)
+
+                hvd.shutdown()
+
+            try:
+                # old API
+                from horovod.run.runner import run
+            except ImportError:
+                # new API: https://github.com/horovod/horovod/pull/2099
+                from horovod import run
+
+            nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()
+            hvd_kwargs = dict(use_gloo=True, num_proc=nproc)
+            run(hvd_worker, **hvd_kwargs)
+
+        pyfuncitem.obj = functools.partial(testfunc_wrapper, pyfuncitem.obj)
 
     elif pyfuncitem.stash.get(is_xla_stash_key, False) and not pyfuncitem.stash[is_xla_single_device_stash_key]:
-        n = int(os.environ["NUM_TPU_WORKERS"])
-        pyfuncitem.obj = functools.partial(_xla_execute, pyfuncitem.obj, (), n)
+
+        def testfunc_wrapper(testfunc, **kwargs):
+            def xla_worker(index, fn):
+                import torch_xla.core.xla_model as xm
+
+                kwargs["distributed"]["xla_index"] = index
+                xm.rendezvous("init")
+                fn(**kwargs)
+
+            import torch_xla.distributed.xla_multiprocessing as xmp
+
+            spawn_kwargs = {"nprocs": int(os.environ["NUM_TPU_WORKERS"])}
+            if "COLAB_TPU_ADDR" in os.environ:
+                spawn_kwargs["start_method"] = "fork"
+            try:
+                xmp.spawn(xla_worker, args=(testfunc,), **spawn_kwargs)
+            except SystemExit as ex_:
+                assert ex_.code == 0, "Didn't successfully exit in XLA test"
+
+        pyfuncitem.obj = functools.partial(testfunc_wrapper, pyfuncitem.obj)
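To tie the pieces together, here is a hypothetical test (not part of the diff) that requests the reworked `distributed` fixture: under the `xla_nprocs` parameter the fixture yields `{"xla_index": -1}`, and the `xla_worker` defined above replaces that placeholder with the per-process index before invoking the test; for `single_device_xla` the fixture yields None (what the other branches yield is outside this hunk).

def test_uses_distributed_fixture(distributed):
    # Hypothetical test for illustration only. For "xla_nprocs", `distributed` is the
    # dict yielded by the fixture, mutated by xla_worker in each spawned process.
    if isinstance(distributed, dict):
        xla_index = distributed["xla_index"]
        assert xla_index >= 0  # the -1 placeholder is overwritten before the test runs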