115 changes: 113 additions & 2 deletions tests/ignite/conftest.py
@@ -1,3 +1,4 @@
import functools
import os
import shutil
import sys
@@ -9,6 +10,8 @@
import torch
import torch.distributed as dist

import ignite.distributed as idist


@pytest.fixture()
def dirname():
@@ -50,7 +53,7 @@ def func():
yield func


@pytest.fixture()
@pytest.fixture(scope="module")
def local_rank(worker_id):
"""use a different account in each xdist worker"""

@@ -68,7 +71,7 @@ def local_rank(worker_id):
del os.environ["LOCAL_RANK"]


@pytest.fixture()
@pytest.fixture(scope="module")
def world_size():

remove_env_var = False
@@ -333,3 +336,111 @@ def _gloo_hvd_execute(func, args, np=1, do_init=False):
@pytest.fixture()
def gloo_hvd_executor():
yield _gloo_hvd_execute


skip_if_no_gpu = pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
skip_if_has_not_native_dist_support = pytest.mark.skipif(
not idist.has_native_dist_support, reason="Skip if no native dist support"
)
skip_if_has_not_xla_support = pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
skip_if_has_not_horovod_support = pytest.mark.skipif(
not idist.has_hvd_support, reason="Skip if no Horovod dist support"
)


# Unlike other backends, Horovod and multi-process XLA run user code through a
# utility function that accepts the user code as a callable argument.
# To keep distributed tests backend-agnostic, we mark Horovod and multi-process XLA
# tests during fixture preparation and replace their test function with the proper
# wrapper just before the test runs. Pytest's stash is a safe way to share state
# between different stages of the test run, and we use it to mark these tests.
is_horovod_stash_key = pytest.StashKey[bool]()
is_xla_stash_key = pytest.StashKey[bool]()
is_xla_single_device_stash_key = pytest.StashKey[bool]()


@pytest.fixture(
scope="module",
params=[
pytest.param("nccl", marks=[pytest.mark.distributed, skip_if_has_not_native_dist_support, skip_if_no_gpu]),
pytest.param("gloo_cpu", marks=[pytest.mark.distributed, skip_if_has_not_native_dist_support]),
pytest.param("gloo", marks=[pytest.mark.distributed, skip_if_has_not_native_dist_support, skip_if_no_gpu]),
pytest.param(
"horovod",
marks=[
pytest.mark.distributed,
skip_if_has_not_horovod_support,
pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc"),
],
),
pytest.param(
"single_device_xla",
marks=[
pytest.mark.tpu,
skip_if_has_not_xla_support,
pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars"),
],
),
pytest.param(
"xla_nprocs",
marks=[
pytest.mark.tpu,
skip_if_has_not_xla_support,
pytest.mark.skipif(
"NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars"
),
],
),
],
)
def distributed(request, local_rank, world_size):
if request.param in ("nccl", "gloo_cpu", "gloo"):
if "gloo" in request.param and sys.platform.startswith("win"):
temp_file = tempfile.NamedTemporaryFile(delete=False)
# can't use backslashes in f-strings
backslash = "\\"
init_method = f'file:///{temp_file.name.replace(backslash, "/")}'
else:
temp_file = None
free_port = _setup_free_port(local_rank)
init_method = f"tcp://localhost:{free_port}"

dist_info = {
"world_size": world_size,
"rank": local_rank,
"init_method": init_method,
}

if request.param == "nccl":
dist_info["backend"] = "nccl"
else:
dist_info["backend"] = "gloo"
from datetime import timedelta

dist_info["timeout"] = timedelta(seconds=60)
yield _create_dist_context(dist_info, local_rank)
_destroy_dist_context()
if temp_file:
temp_file.close()

elif request.param == "horovod":
request.node.stash[is_horovod_stash_key] = True
yield None

elif request.param in ("single_device_xla", "xla_nprocs"):
request.node.stash[is_xla_stash_key] = True
request.node.stash[is_xla_single_device_stash_key] = request.param == "single_device_xla"
yield None
else:
raise RuntimeError(f"Invalid parameter value for `distributed` fixture, given {request.param}")


@pytest.hookimpl
def pytest_pyfunc_call(pyfuncitem: pytest.Function) -> None:
if pyfuncitem.stash.get(is_horovod_stash_key, False):
nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()
pyfuncitem.obj = functools.partial(_gloo_hvd_execute, pyfuncitem.obj, (), np=nproc)

elif pyfuncitem.stash.get(is_xla_stash_key, False) and not pyfuncitem.stash[is_xla_single_device_stash_key]:
n = int(os.environ["NUM_TPU_WORKERS"])
pyfuncitem.obj = functools.partial(_xla_execute, pyfuncitem.obj, (), n)
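For reference, a minimal sketch of how a test consumes the new `distributed` fixture (the test name `test_all_reduce_sum` is illustrative and not part of this PR): the fixture parametrizes and initializes the backend, and the test resolves rank, world size, and device through `ignite.distributed` at run time, just as the updated `test_precision.py` below does.

import torch

import ignite.distributed as idist


def test_all_reduce_sum(distributed):
    # The fixture has already set up the parametrized backend
    # (gloo/nccl/horovod/xla); idist resolves rank, world size and device.
    device = idist.device()
    rank = idist.get_rank()
    ws = idist.get_world_size()
    # Sum of (rank + 1) over all workers equals ws * (ws + 1) / 2.
    total = idist.all_reduce(torch.tensor([rank + 1.0], device=device))
    assert total.item() == ws * (ws + 1) / 2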
101 changes: 8 additions & 93 deletions tests/ignite/metrics/test_precision.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
import warnings

import pytest
@@ -424,7 +423,7 @@ def test_incorrect_y_classes(average):
assert pr._updated is False


def _test_distrib_integration_multiclass(device):
def test_distrib_integration_multiclass(distributed):
from ignite.engine import Engine

rank = idist.get_rank()
@@ -470,6 +469,7 @@ def update(engine, i):
assert pytest.approx(res) == true_res

metric_devices = [torch.device("cpu")]
device = idist.device()
if device.type != "xla":
metric_devices.append(idist.device())
for _ in range(2):
@@ -484,7 +484,7 @@ def update(engine, i):
_test(average="micro", n_epochs=2, metric_device=metric_device)


def _test_distrib_integration_multilabel(device):
def test_distrib_integration_multilabel(distributed):

from ignite.engine import Engine

@@ -535,6 +535,7 @@ def update(engine, i):
assert precision_score(np_y_true, np_y_preds, average=sk_average_parameter) == pytest.approx(res)

metric_devices = ["cpu"]
device = idist.device()
if device.type != "xla":
metric_devices.append(idist.device())
for _ in range(2):
@@ -551,7 +552,7 @@ def update(engine, i):
_test(average="samples", n_epochs=2, metric_device=metric_device)


def _test_distrib_accumulator_device(device):
def test_distrib_accumulator_device(distributed):
# Binary accuracy on input of shape (N, 1) or (N, )

def _test(average, metric_device):
@@ -582,6 +583,7 @@ def _test(average, metric_device):
f"{type(metric_device)}:{metric_device}"

metric_devices = [torch.device("cpu")]
device = idist.device()
if device.type != "xla":
metric_devices.append(idist.device())
for metric_device in metric_devices:
@@ -591,7 +593,7 @@ def _test(average, metric_device):
_test("weighted", metric_device=metric_device)


def _test_distrib_multilabel_accumulator_device(device):
def test_distrib_multilabel_accumulator_device(distributed):
# Multiclass input data of shape (N, ) and (N, C)

def _test(average, metric_device):
@@ -621,6 +623,7 @@ def _test(average, metric_device):
f"{type(metric_device)}:{metric_device}"

metric_devices = [torch.device("cpu")]
device = idist.device()
if device.type != "xla":
metric_devices.append(idist.device())
for metric_device in metric_devices:
@@ -629,91 +632,3 @@ def _test(average, metric_device):
_test("micro", metric_device=metric_device)
_test("weighted", metric_device=metric_device)
_test("samples", metric_device=metric_device)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
def test_distrib_nccl_gpu(distributed_context_single_node_nccl):

device = idist.device()
_test_distrib_integration_multiclass(device)
_test_distrib_integration_multilabel(device)
_test_distrib_accumulator_device(device)
_test_distrib_multilabel_accumulator_device(device)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):

device = idist.device()
_test_distrib_integration_multiclass(device)
_test_distrib_integration_multilabel(device)
_test_distrib_accumulator_device(device)
_test_distrib_multilabel_accumulator_device(device)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support")
@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
def test_distrib_hvd(gloo_hvd_executor):

device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()

gloo_hvd_executor(_test_distrib_integration_multiclass, (device,), np=nproc, do_init=True)
gloo_hvd_executor(_test_distrib_integration_multilabel, (device,), np=nproc, do_init=True)
gloo_hvd_executor(_test_distrib_accumulator_device, (device,), np=nproc, do_init=True)
gloo_hvd_executor(_test_distrib_integration_multilabel, (device,), np=nproc, do_init=True)


@pytest.mark.multinode_distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo):

device = idist.device()
_test_distrib_integration_multiclass(device)
_test_distrib_integration_multilabel(device)
_test_distrib_accumulator_device(device)
_test_distrib_multilabel_accumulator_device(device)


@pytest.mark.multinode_distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl):

device = idist.device()
_test_distrib_integration_multiclass(device)
_test_distrib_integration_multilabel(device)
_test_distrib_accumulator_device(device)
_test_distrib_multilabel_accumulator_device(device)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_single_device_xla():
device = idist.device()
_test_distrib_integration_multiclass(device)
_test_distrib_integration_multilabel(device)
_test_distrib_accumulator_device(device)
_test_distrib_multilabel_accumulator_device(device)


def _test_distrib_xla_nprocs(index):
device = idist.device()
_test_distrib_integration_multiclass(device)
_test_distrib_integration_multilabel(device)
_test_distrib_accumulator_device(device)
_test_distrib_multilabel_accumulator_device(device)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_xla_nprocs(xmp_executor):
n = int(os.environ["NUM_TPU_WORKERS"])
xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n)