From 78c30ddcabeb9e5557f434dc03af6965fe43f8ab Mon Sep 17 00:00:00 2001 From: crj1998 Date: Wed, 7 Dec 2022 05:16:55 +0000 Subject: [PATCH 1/2] align ssim with psnr --- ignite/metrics/ssim.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ignite/metrics/ssim.py b/ignite/metrics/ssim.py index 5a3a52fd8232..0c9a2351d521 100644 --- a/ignite/metrics/ssim.py +++ b/ignite/metrics/ssim.py @@ -180,7 +180,7 @@ def update(self, output: Sequence[torch.Tensor]) -> None: self._num_examples += y.shape[0] @sync_all_reduce("_sum_of_ssim", "_num_examples") - def compute(self) -> float: + def compute(self) -> torch.Tensor: if self._num_examples == 0: raise NotComputableError("SSIM must have at least one example before it can be computed.") - return (self._sum_of_ssim / self._num_examples).item() + return self._sum_of_ssim / self._num_examples From be893ce7fb3593d34361657844dd7f76920753ce Mon Sep 17 00:00:00 2001 From: crj1998 Date: Thu, 8 Dec 2022 08:16:55 +0000 Subject: [PATCH 2/2] add unit test --- tests/ignite/metrics/test_ssim.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/ignite/metrics/test_ssim.py b/tests/ignite/metrics/test_ssim.py index f477892222cb..335c2082cb08 100644 --- a/tests/ignite/metrics/test_ssim.py +++ b/tests/ignite/metrics/test_ssim.py @@ -96,8 +96,8 @@ def test_ssim(device, shape, kernel_size, gaussian, use_sample_covariance): use_sample_covariance=use_sample_covariance, ) - assert isinstance(ignite_ssim, float) - assert np.allclose(ignite_ssim, skimg_ssim, atol=7e-5) + assert isinstance(ignite_ssim, torch.Tensor) + assert np.allclose(ignite_ssim.item(), skimg_ssim, atol=7e-5) def test_ssim_variable_batchsize(): @@ -123,7 +123,7 @@ def test_ssim_variable_batchsize(): ssim.reset() ssim.update((torch.cat(y_preds), torch.cat(y_true))) expected = ssim.compute() - assert np.allclose(out, expected) + assert np.allclose(out.item(), expected.item()) def _test_distrib_integration(device, tol=1e-4):