From 237d8deed59746ffa563e3fa2cdf6978765e2f7d Mon Sep 17 00:00:00 2001 From: Marc Bresson Date: Thu, 24 Aug 2023 18:19:47 +0200 Subject: [PATCH 1/4] feat: add compatibility with uint8 --- ignite/metrics/ssim.py | 9 +++++++ tests/ignite/metrics/test_ssim.py | 44 +++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/ignite/metrics/ssim.py b/ignite/metrics/ssim.py index 03ff1ce17162..34701715a49a 100644 --- a/ignite/metrics/ssim.py +++ b/ignite/metrics/ssim.py @@ -1,3 +1,4 @@ +import warnings from typing import Callable, Sequence, Union import torch @@ -98,6 +99,7 @@ def __init__( super(SSIM, self).__init__(output_transform=output_transform, device=device) self.gaussian = gaussian + self.data_range = data_range self.c1 = (k1 * data_range) ** 2 self.c2 = (k2 * data_range) ** 2 self.pad_h = (self.kernel_size[0] - 1) // 2 @@ -157,6 +159,13 @@ def update(self, output: Sequence[torch.Tensor]) -> None: f"Expected y_pred and y to have BxCxHxW shape. Got y_pred: {y_pred.shape} and y: {y.shape}." ) + if y_pred.dtype == torch.uint8 or y.dtype == torch.uint8: + if self.data_range != 255: + warnings.warn("dtypes of the input tensors are torch.uint8 but data range is not set to 255.", RuntimeWarning) + + y_pred = y_pred.to(dtype=torch.float32) + y = y.to(dtype=torch.float32) + channel = y_pred.size(1) if len(self._kernel.shape) < 4: self._kernel = self._kernel.expand(channel, 1, -1, -1).to(device=y_pred.device) diff --git a/tests/ignite/metrics/test_ssim.py b/tests/ignite/metrics/test_ssim.py index 45ddccbd1c68..3c6fa562274f 100644 --- a/tests/ignite/metrics/test_ssim.py +++ b/tests/ignite/metrics/test_ssim.py @@ -139,6 +139,50 @@ def test_cuda_ssim_dtypes(available_device, dtype, precision): test_ssim(available_device, (12, 3, 28, 28), 11, True, False, dtype=dtype, precision=precision) +@pytest.mark.parametrize( + "shape, kernel_size, gaussian, use_sample_covariance", + [[(8, 3, 224, 224), 7, False, True], [(12, 3, 28, 28), 11, True, False]], +) +def test_ssim_uint8(available_device, shape, kernel_size, gaussian, use_sample_covariance): + y_pred = torch.randint(0, 255, shape, device=available_device, dtype=torch.uint8) + y = (y_pred * 0.8).to(dtype=torch.uint8) + + sigma = 1.5 + data_range = 255 + ssim = SSIM(data_range=data_range, sigma=sigma, device=available_device) + ssim.update((y_pred, y)) + ignite_ssim = ssim.compute() + + skimg_pred = y_pred.cpu().numpy() + skimg_y = (skimg_pred * 0.8).astype(np.uint8) + skimg_ssim = ski_ssim( + skimg_pred, + skimg_y, + win_size=kernel_size, + sigma=sigma, + channel_axis=1, + gaussian_weights=gaussian, + data_range=data_range, + use_sample_covariance=use_sample_covariance, + ) + + assert isinstance(ignite_ssim, float) + assert np.allclose(ignite_ssim, skimg_ssim, atol=1e-5) + + +def test_ssim_uint8_warning(available_device): + shape = (7, 3, 256, 256) + y_pred = torch.randint(0, 255, shape, device=available_device, dtype=torch.uint8) + y = (y_pred * 0.8).to(dtype=torch.uint8) + + sigma = 1.5 + data_range = 1.0 + ssim = SSIM(data_range=data_range, sigma=sigma, device=available_device) + + with pytest.warns(RuntimeWarning, match=r"dtypes of the input tensors are torch.uint8 but data range is not set to 255."): + ssim.update((y_pred, y)) + + @pytest.mark.parametrize("metric_device", ["cpu", "process_device"]) def test_distrib_integration(distributed, metric_device): from ignite.engine import Engine From e04cdbb4a06f4427a244e69eec6998daf43e3714 Mon Sep 17 00:00:00 2001 From: Marc Bresson Date: Fri, 25 Aug 2023 09:21:11 +0200 Subject: [PATCH 2/4] style: format using the run_code_style script --- ignite/metrics/ssim.py | 4 +++- tests/ignite/metrics/test_ssim.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/ignite/metrics/ssim.py b/ignite/metrics/ssim.py index 34701715a49a..ee99fa744339 100644 --- a/ignite/metrics/ssim.py +++ b/ignite/metrics/ssim.py @@ -161,7 +161,9 @@ def update(self, output: Sequence[torch.Tensor]) -> None: if y_pred.dtype == torch.uint8 or y.dtype == torch.uint8: if self.data_range != 255: - warnings.warn("dtypes of the input tensors are torch.uint8 but data range is not set to 255.", RuntimeWarning) + warnings.warn( + "dtypes of the input tensors are torch.uint8 but data range is not set to 255.", RuntimeWarning + ) y_pred = y_pred.to(dtype=torch.float32) y = y.to(dtype=torch.float32) diff --git a/tests/ignite/metrics/test_ssim.py b/tests/ignite/metrics/test_ssim.py index 3c6fa562274f..32c01a6f1e1a 100644 --- a/tests/ignite/metrics/test_ssim.py +++ b/tests/ignite/metrics/test_ssim.py @@ -179,7 +179,9 @@ def test_ssim_uint8_warning(available_device): data_range = 1.0 ssim = SSIM(data_range=data_range, sigma=sigma, device=available_device) - with pytest.warns(RuntimeWarning, match=r"dtypes of the input tensors are torch.uint8 but data range is not set to 255."): + with pytest.warns( + RuntimeWarning, match=r"dtypes of the input tensors are torch.uint8 but data range is not set to 255." + ): ssim.update((y_pred, y)) From 67edd84d2b85376e975fce1cdc043be00923eaa2 Mon Sep 17 00:00:00 2001 From: Marc Bresson Date: Wed, 13 Sep 2023 09:17:23 +0200 Subject: [PATCH 3/4] refactor: delete warning and independantly convert y and y_pred to fp --- ignite/metrics/ssim.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/ignite/metrics/ssim.py b/ignite/metrics/ssim.py index bbc096db74b2..6824c0b3f374 100644 --- a/ignite/metrics/ssim.py +++ b/ignite/metrics/ssim.py @@ -158,14 +158,11 @@ def update(self, output: Sequence[torch.Tensor]) -> None: f"Expected y_pred and y to have BxCxHxW shape. Got y_pred: {y_pred.shape} and y: {y.shape}." ) - if y_pred.dtype == torch.uint8 or y.dtype == torch.uint8: - if self.data_range != 255: - warnings.warn( - "dtypes of the input tensors are torch.uint8 but data range is not set to 255.", RuntimeWarning - ) - - y_pred = y_pred.to(dtype=torch.float32) - y = y.to(dtype=torch.float32) + # converts potential integer tensor to fp + if not y.is_floating_point(): + y = y.float() + if not y_pred.is_floating_point(): + y_pred = y_pred.float() nb_channel = y_pred.size(1) if self._kernel is None or self._kernel.shape[0] != nb_channel: From 2cf4fef8874846343ba3e4634958fa5a638642e3 Mon Sep 17 00:00:00 2001 From: Marc Bresson Date: Wed, 13 Sep 2023 09:19:59 +0200 Subject: [PATCH 4/4] feat: remove uint8 warning test --- tests/ignite/metrics/test_ssim.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/tests/ignite/metrics/test_ssim.py b/tests/ignite/metrics/test_ssim.py index 85e5283efadb..e81d9abf962b 100644 --- a/tests/ignite/metrics/test_ssim.py +++ b/tests/ignite/metrics/test_ssim.py @@ -253,21 +253,6 @@ def test_ssim_uint8(available_device, shape, kernel_size, gaussian, use_sample_c assert np.allclose(ignite_ssim, skimg_ssim, atol=1e-5) -def test_ssim_uint8_warning(available_device): - shape = (7, 3, 256, 256) - y_pred = torch.randint(0, 255, shape, device=available_device, dtype=torch.uint8) - y = (y_pred * 0.8).to(dtype=torch.uint8) - - sigma = 1.5 - data_range = 1.0 - ssim = SSIM(data_range=data_range, sigma=sigma, device=available_device) - - with pytest.warns( - RuntimeWarning, match=r"dtypes of the input tensors are torch.uint8 but data range is not set to 255." - ): - ssim.update((y_pred, y)) - - @pytest.mark.parametrize("metric_device", ["cpu", "process_device"]) def test_distrib_integration(distributed, metric_device): from ignite.engine import Engine