diff --git a/ignite/metrics/ssim.py b/ignite/metrics/ssim.py index 30025be843f3..be1306ba938f 100644 --- a/ignite/metrics/ssim.py +++ b/ignite/metrics/ssim.py @@ -98,7 +98,7 @@ def __init__( @reinit__is_reduced def reset(self) -> None: - self._sum_of_ssim = torch.tensor(0.0, device=self._device) + self._sum_of_ssim = torch.tensor(0.0, dtype=torch.float64, device=self._device) self._num_examples = 0 self._kernel = self._gaussian_or_uniform_kernel(kernel_size=self.kernel_size, sigma=self.sigma) @@ -180,7 +180,7 @@ def update(self, output: Sequence[torch.Tensor]) -> None: self._num_examples += y.shape[0] @sync_all_reduce("_sum_of_ssim", "_num_examples") - def compute(self) -> torch.Tensor: + def compute(self) -> float: if self._num_examples == 0: raise NotComputableError("SSIM must have at least one example before it can be computed.") - return self._sum_of_ssim / self._num_examples + return (self._sum_of_ssim / self._num_examples).item() diff --git a/tests/ignite/metrics/test_ssim.py b/tests/ignite/metrics/test_ssim.py index 570f2c57534a..fe00d6c36e48 100644 --- a/tests/ignite/metrics/test_ssim.py +++ b/tests/ignite/metrics/test_ssim.py @@ -96,10 +96,8 @@ def test_ssim(device, shape, kernel_size, gaussian, use_sample_covariance): use_sample_covariance=use_sample_covariance, ) - assert isinstance(ignite_ssim, torch.Tensor) - assert ignite_ssim.dtype == torch.float64 - assert ignite_ssim.device.type == torch.device(device).type - assert np.allclose(ignite_ssim.cpu().numpy(), skimg_ssim, atol=7e-5) + assert isinstance(ignite_ssim, float) + assert np.allclose(ignite_ssim, skimg_ssim, atol=7e-5) def test_ssim_variable_batchsize(): @@ -125,7 +123,7 @@ def test_ssim_variable_batchsize(): ssim.reset() ssim.update((torch.cat(y_preds), torch.cat(y_true))) expected = ssim.compute() - assert torch.allclose(out, expected) + assert np.allclose(out, expected) def _test_distrib_integration(device, tol=1e-4):