diff --git a/tests/ignite/metrics/test_ssim.py b/tests/ignite/metrics/test_ssim.py index f477892222cb..335c2082cb08 100644 --- a/tests/ignite/metrics/test_ssim.py +++ b/tests/ignite/metrics/test_ssim.py @@ -96,8 +96,8 @@ def test_ssim(device, shape, kernel_size, gaussian, use_sample_covariance): use_sample_covariance=use_sample_covariance, ) - assert isinstance(ignite_ssim, float) - assert np.allclose(ignite_ssim, skimg_ssim, atol=7e-5) + assert isinstance(ignite_ssim, torch.Tensor) + assert np.allclose(ignite_ssim.item(), skimg_ssim, atol=7e-5) def test_ssim_variable_batchsize(): @@ -123,7 +123,7 @@ def test_ssim_variable_batchsize(): ssim.reset() ssim.update((torch.cat(y_preds), torch.cat(y_true))) expected = ssim.compute() - assert np.allclose(out, expected) + assert np.allclose(out.item(), expected.item()) def _test_distrib_integration(device, tol=1e-4):