diff --git a/ignite/metrics/gan/utils.py b/ignite/metrics/gan/utils.py index 4f7e4de8e553..efb8f7597087 100644 --- a/ignite/metrics/gan/utils.py +++ b/ignite/metrics/gan/utils.py @@ -19,12 +19,13 @@ class InceptionModel(torch.nn.Module): def __init__(self, return_features: bool, device: Union[str, torch.device] = "cpu") -> None: try: + import torchvision from torchvision import models except ImportError: raise RuntimeError("This module requires torchvision to be installed.") super(InceptionModel, self).__init__() self._device = device - if Version(torch.__version__) <= Version("1.7.0"): + if Version(torchvision.__version__) < Version("0.13.0"): model_kwargs = {"pretrained": True} else: model_kwargs = {"weights": models.Inception_V3_Weights.DEFAULT} diff --git a/tests/ignite/handlers/test_state_param_scheduler.py b/tests/ignite/handlers/test_state_param_scheduler.py index 7b2ec23a7024..a63338d989b1 100644 --- a/tests/ignite/handlers/test_state_param_scheduler.py +++ b/tests/ignite/handlers/test_state_param_scheduler.py @@ -38,7 +38,7 @@ }, ) -if Version(torch.__version__) <= Version("1.7.0"): +if Version(torch.__version__) < Version("1.9.0"): torch_testing_assert_close = torch.testing.assert_allclose else: torch_testing_assert_close = torch.testing.assert_close