4 changes: 2 additions & 2 deletions ignite/metrics/psnr.py
@@ -118,7 +118,7 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
self._num_examples += y.shape[0]

@sync_all_reduce("_sum_of_batchwise_psnr", "_num_examples")
def compute(self) -> torch.Tensor:
def compute(self) -> float:
if self._num_examples == 0:
raise NotComputableError("PSNR must have at least one example before it can be computed.")
return self._sum_of_batchwise_psnr / self._num_examples
return (self._sum_of_batchwise_psnr / self._num_examples).item()
4 changes: 2 additions & 2 deletions ignite/metrics/ssim.py
@@ -180,7 +180,7 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
self._num_examples += y.shape[0]

@sync_all_reduce("_sum_of_ssim", "_num_examples")
def compute(self) -> torch.Tensor:
def compute(self) -> float:
if self._num_examples == 0:
raise NotComputableError("SSIM must have at least one example before it can be computed.")
return self._sum_of_ssim / self._num_examples
return (self._sum_of_ssim / self._num_examples).item()
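Both `PSNR.compute()` and `SSIM.compute()` now return a plain Python `float` (via `.item()`) instead of a 0-d tensor. A minimal sketch of the observable effect, assuming an illustrative input batch and `data_range` that are not part of this diff:

```python
import torch

from ignite.metrics import PSNR

# Hypothetical usage only: exercise the new return type of compute().
psnr = PSNR(data_range=1.0)
y_pred = torch.rand(4, 3, 16, 16)
y = y_pred * 0.8
psnr.update((y_pred, y))

value = psnr.compute()
assert isinstance(value, float)  # previously a 0-d torch.Tensor with dtype float64
```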
10 changes: 10 additions & 0 deletions tests/ignite/conftest.py
@@ -13,6 +13,16 @@
import ignite.distributed as idist


@pytest.fixture(
params=[
"cpu",
pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no CUDA support")),
]
)
def available_device(request):
return request.param


@pytest.fixture()
def dirname():
path = Path(tempfile.mkdtemp())
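The new parametrized `available_device` fixture runs a consuming test once on `"cpu"` and once on `"cuda"`, with the CUDA case skipped when no GPU is available. A hedged sketch of a consumer, with an illustrative test body that is not part of this PR:

```python
import torch

# Illustrative only: any test that names the fixture is collected once per device.
def test_tensor_on_available_device(available_device):
    x = torch.ones(3, device=available_device)
    assert x.device.type == available_device
```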
207 changes: 52 additions & 155 deletions tests/ignite/metrics/test_psnr.py
@@ -1,5 +1,3 @@
import os

import numpy as np
import pytest
import torch
@@ -11,8 +9,6 @@
from ignite.metrics import PSNR
from ignite.utils import manual_seed

from tests.ignite import cpu_and_maybe_cuda


def test_zero_div():
psnr = PSNR(1.0)
@@ -32,9 +28,32 @@ def test_invalid_psnr():
psnr.update((y_pred, y.squeeze(dim=0)))


def _test_psnr(y_pred, y, data_range, device):
psnr = PSNR(data_range=data_range, device=device)
psnr.update((y_pred, y))
@pytest.fixture(params=["float", "YCbCr", "uint8", "NHW shape"])
def test_data(request, available_device):
manual_seed(42)
if request.param == "float":
y_pred = torch.rand(8, 3, 28, 28, device=available_device)
y = y_pred * 0.8
elif request.param == "YCbCr":
y_pred = torch.randint(16, 236, (4, 1, 12, 12), dtype=torch.uint8, device=available_device)
y = torch.randint(16, 236, (4, 1, 12, 12), dtype=torch.uint8, device=available_device)
elif request.param == "uint8":
y_pred = torch.randint(0, 256, (4, 3, 16, 16), dtype=torch.uint8, device=available_device)
y = (y_pred * 0.8).to(torch.uint8)
elif request.param == "NHW shape":
y_pred = torch.rand(8, 28, 28, device=available_device)
y = y_pred * 0.8
else:
raise ValueError(f"Wrong fixture parameter, given {request.param}")
return (y_pred, y)


def test_psnr(test_data, available_device):
y_pred, y = test_data
data_range = (y.max() - y.min()).cpu().item()

psnr = PSNR(data_range=data_range, device=available_device)
psnr.update(test_data)
psnr_compute = psnr.compute()

np_y_pred = y_pred.cpu().numpy()
@@ -43,43 +62,9 @@ def _test_psnr(y_pred, y, data_range, device):
for np_y_pred_, np_y_ in zip(np_y_pred, np_y):
np_psnr += ski_psnr(np_y_, np_y_pred_, data_range=data_range)

assert torch.gt(psnr_compute, 0.0)
assert isinstance(psnr_compute, torch.Tensor)
assert psnr_compute.dtype == torch.float64
assert psnr_compute.device.type == torch.device(device).type
assert np.allclose(psnr_compute.cpu().numpy(), np_psnr / np_y.shape[0])


@pytest.mark.parametrize("device", cpu_and_maybe_cuda())
def test_psnr(device):

# test for float
manual_seed(42)
y_pred = torch.rand(8, 3, 28, 28, device=device)
y = y_pred * 0.8
data_range = (y.max() - y.min()).cpu().item()
_test_psnr(y_pred, y, data_range, device)

# test for YCbCr
manual_seed(42)
y_pred = torch.randint(16, 236, (4, 1, 12, 12), dtype=torch.uint8, device=device)
y = torch.randint(16, 236, (4, 1, 12, 12), dtype=torch.uint8, device=device)
data_range = (y.max() - y.min()).cpu().item()
_test_psnr(y_pred, y, data_range, device)

# test for uint8
manual_seed(42)
y_pred = torch.randint(0, 256, (4, 3, 16, 16), dtype=torch.uint8, device=device)
y = (y_pred * 0.8).to(torch.uint8)
data_range = (y.max() - y.min()).cpu().item()
_test_psnr(y_pred, y, data_range, device)

# test with NHW shape
manual_seed(42)
y_pred = torch.rand(8, 28, 28, device=device)
y = y_pred * 0.8
data_range = (y.max() - y.min()).cpu().item()
_test_psnr(y_pred, y, data_range, device)
assert psnr_compute > 0.0
assert isinstance(psnr_compute, float)
assert np.allclose(psnr_compute, np_psnr / np_y.shape[0])
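The reference value mirrors what `ski_psnr` (skimage's `peak_signal_noise_ratio`) computes per image, PSNR = 10 · log10(data_range² / MSE), averaged over the batch. A standalone NumPy sketch of that reference, with illustrative names and assuming float-convertible inputs:

```python
import numpy as np

def reference_batch_psnr(np_y_pred: np.ndarray, np_y: np.ndarray, data_range: float) -> float:
    # Per-image PSNR = 10 * log10(data_range ** 2 / MSE), then a plain mean over the batch dim.
    values = []
    for pred, target in zip(np_y_pred, np_y):
        mse = np.mean((pred.astype(np.float64) - target.astype(np.float64)) ** 2)
        values.append(10.0 * np.log10(data_range ** 2 / mse))
    return float(np.mean(values))
```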


def _test(
@@ -109,9 +94,9 @@ def update(engine, i):
y = idist.all_gather(y)
y_pred = idist.all_gather(y_pred)

assert "psnr" in engine.state.metrics
result = engine.state.metrics["psnr"]
assert result > 0.0
assert "psnr" in engine.state.metrics

if compute_y_channel:
np_y_pred = y_pred[:, 0, ...].cpu().numpy()
@@ -127,7 +112,9 @@ def update(engine, i):
assert np.allclose(result, np_psnr / np_y.shape[0], atol=atol)


def _test_distrib_input_float(device, atol=1e-8):
def test_distrib_input_float(distributed):
device = idist.device()

def get_test_cases():

y_pred = torch.rand(n_iters * batch_size, 2, 2, device=device)
@@ -143,12 +130,14 @@ def get_test_cases():
# check multiple random inputs as random exact occurrences are rare
torch.manual_seed(42 + rank + i)
y_pred, y = get_test_cases()
_test(y_pred, y, 1, "cpu", n_iters, batch_size, atol=atol)
_test(y_pred, y, 1, "cpu", n_iters, batch_size, atol=1e-8)
if device.type != "xla":
_test(y_pred, y, 1, idist.device(), n_iters, batch_size, atol=atol)
_test(y_pred, y, 1, idist.device(), n_iters, batch_size, atol=1e-8)


def test_distrib_multilabel_input_YCbCr(distributed):
device = idist.device()

def _test_distrib_multilabel_input_YCbCr(device, atol=1e-8):
def get_test_cases():

y_pred = torch.randint(16, 236, (n_iters * batch_size, 1, 12, 12), dtype=torch.uint8, device=device)
@@ -171,13 +160,15 @@ def out_fn(x):
# check multiple random inputs as random exact occurrences are rare
torch.manual_seed(42 + rank + i)
y_pred, y = get_test_cases()
_test(y_pred, y, 220, "cpu", n_iters, batch_size, atol, output_transform=out_fn, compute_y_channel=True)
_test(y_pred, y, 220, "cpu", n_iters, batch_size, atol=1e-8, output_transform=out_fn, compute_y_channel=True)
if device.type != "xla":
dev = idist.device()
_test(y_pred, y, 220, dev, n_iters, batch_size, atol, output_transform=out_fn, compute_y_channel=True)
_test(y_pred, y, 220, dev, n_iters, batch_size, atol=1e-8, output_transform=out_fn, compute_y_channel=True)


def _test_distrib_multilabel_input_uint8(device, atol=1e-8):
def test_distrib_multilabel_input_uint8(distributed):
device = idist.device()

def get_test_cases():

y_pred = torch.randint(0, 256, (n_iters * batch_size, 3, 16, 16), device=device, dtype=torch.uint8)
@@ -193,12 +184,14 @@ def get_test_cases():
# check multiple random inputs as random exact occurrences are rare
torch.manual_seed(42 + rank + i)
y_pred, y = get_test_cases()
_test(y_pred, y, 100, "cpu", n_iters, batch_size, atol)
_test(y_pred, y, 100, "cpu", n_iters, batch_size, atol=1e-8)
if device.type != "xla":
_test(y_pred, y, 100, idist.device(), n_iters, batch_size, atol)
_test(y_pred, y, 100, idist.device(), n_iters, batch_size, atol=1e-8)


def _test_distrib_multilabel_input_NHW(device, atol=1e-8):
def test_distrib_multilabel_input_NHW(distributed):
device = idist.device()

def get_test_cases():

y_pred = torch.rand(n_iters * batch_size, 28, 28, device=device)
@@ -214,13 +207,13 @@ def get_test_cases():
# check multiple random inputs as random exact occurrences are rare
torch.manual_seed(42 + rank + i)
y_pred, y = get_test_cases()
_test(y_pred, y, 10, "cpu", n_iters, batch_size, atol)
_test(y_pred, y, 10, "cpu", n_iters, batch_size, atol=1e-8)
if device.type != "xla":
_test(y_pred, y, 10, idist.device(), n_iters, batch_size, atol)
_test(y_pred, y, 10, idist.device(), n_iters, batch_size, atol=1e-8)


def _test_distrib_accumulator_device(device):

def test_distrib_accumulator_device(distributed):
device = idist.device()
metric_devices = [torch.device("cpu")]
if torch.device(device).type != "xla":
metric_devices.append(idist.device())
@@ -235,99 +228,3 @@ def _test_distrib_accumulator_device(device):
psnr.update((y_pred, y))
dev = psnr._sum_of_batchwise_psnr.device
assert dev == metric_device, f"{dev} vs {metric_device}"


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):

device = idist.device()
_test_distrib_input_float(device)
_test_distrib_multilabel_input_YCbCr(device)
_test_distrib_multilabel_input_uint8(device)
_test_distrib_multilabel_input_NHW(device)
_test_distrib_accumulator_device(device)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
def test_distrib_nccl_gpu(distributed_context_single_node_nccl):

device = idist.device()
_test_distrib_input_float(device)
_test_distrib_multilabel_input_YCbCr(device)
_test_distrib_multilabel_input_uint8(device)
_test_distrib_multilabel_input_NHW(device)
_test_distrib_accumulator_device(device)


@pytest.mark.multinode_distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo):

device = idist.device()
_test_distrib_input_float(device)
_test_distrib_multilabel_input_YCbCr(device)
_test_distrib_multilabel_input_uint8(device)
_test_distrib_multilabel_input_NHW(device)
_test_distrib_accumulator_device(device)


@pytest.mark.multinode_distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl):

device = idist.device()
_test_distrib_input_float(device)
_test_distrib_multilabel_input_YCbCr(device)
_test_distrib_multilabel_input_uint8(device)
_test_distrib_multilabel_input_NHW(device)
_test_distrib_accumulator_device(device)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_single_device_xla():

device = idist.device()
_test_distrib_input_float(device)
_test_distrib_multilabel_input_YCbCr(device)
_test_distrib_multilabel_input_uint8(device)
_test_distrib_multilabel_input_NHW(device)
_test_distrib_accumulator_device(device)


def _test_distrib_xla_nprocs(index):
device = idist.device()
_test_distrib_input_float(device)
_test_distrib_multilabel_input_YCbCr(device)
_test_distrib_multilabel_input_uint8(device)
_test_distrib_multilabel_input_NHW(device)
_test_distrib_accumulator_device(device)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_xla_nprocs(xmp_executor):
n = int(os.environ["NUM_TPU_WORKERS"])
xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support")
@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
def test_distrib_hvd(gloo_hvd_executor):

device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()

gloo_hvd_executor(_test_distrib_input_float, (device,), np=nproc, do_init=True)
gloo_hvd_executor(_test_distrib_multilabel_input_YCbCr, (device,), np=nproc, do_init=True)
gloo_hvd_executor(_test_distrib_multilabel_input_uint8, (device,), np=nproc, do_init=True)
gloo_hvd_executor(_test_distrib_multilabel_input_NHW, (device,), np=nproc, do_init=True)
gloo_hvd_executor(_test_distrib_accumulator_device, (device,), np=nproc, do_init=True)
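Each remaining `test_distrib_*` function takes a `distributed` fixture (assumed to be provided by the shared conftest) and queries `idist.device()` itself, so explicit per-backend wrappers like the ones removed here are no longer needed. A hedged sketch of that shape, with a placeholder body:

```python
import ignite.distributed as idist

# Illustrative shape only: the fixture is expected to set up and tear down the
# distributed context, so the test no longer receives a device argument.
def test_distrib_example(distributed):
    device = idist.device()
    assert device is not None  # placeholder standing in for the real metric checks
```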