diff --git a/src/accelerate/test_utils/testing.py b/src/accelerate/test_utils/testing.py
index c68a94c91d3..ae7ffc7a8fb 100644
--- a/src/accelerate/test_utils/testing.py
+++ b/src/accelerate/test_utils/testing.py
@@ -33,6 +33,7 @@
     is_bnb_available,
     is_clearml_available,
     is_comet_ml_available,
+    is_cuda_available,
     is_datasets_available,
     is_deepspeed_available,
     is_dvclive_available,
@@ -51,7 +52,7 @@


 def get_backend():
-    if torch.cuda.is_available():
+    if is_cuda_available():
         return "cuda", torch.cuda.device_count()
     elif is_mps_available():
         return "mps", 1
@@ -117,7 +118,7 @@ def require_cuda(test_case):
     """
     Decorator marking a test that requires CUDA. These tests are skipped when there are no GPU available.
     """
-    return unittest.skipUnless(torch.cuda.is_available(), "test requires a GPU")(test_case)
+    return unittest.skipUnless(is_cuda_available(), "test requires a GPU")(test_case)


 def require_xpu(test_case):
diff --git a/src/accelerate/utils/imports.py b/src/accelerate/utils/imports.py
index 14ed9f7328c..ef64395ccfc 100644
--- a/src/accelerate/utils/imports.py
+++ b/src/accelerate/utils/imports.py
@@ -130,7 +130,7 @@
 def is_bf16_available(ignore_tpu=False):
     "Checks if bf16 is supported, optionally ignoring the TPU"
     if is_tpu_available():
         return not ignore_tpu
-    if torch.cuda.is_available():
+    if is_cuda_available():
         return torch.cuda.is_bf16_supported()
     return True
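The pattern here is to route every direct `torch.cuda.is_available()` call through a single `is_cuda_available()` wrapper in `accelerate.utils.imports`. The diff only shows the call sites, not the wrapper itself. Below is a minimal sketch of what such a wrapper might look like, assuming it leans on PyTorch's NVML-based availability check (the `PYTORCH_NVML_BASED_CUDA_CHECK` environment variable, available in recent PyTorch releases), which probes the driver via NVML instead of initializing the CUDA runtime:

```python
import os

import torch


def is_cuda_available():
    """
    Sketch of a CUDA-availability check that avoids initializing CUDA.

    Assumed implementation: with PYTORCH_NVML_BASED_CUDA_CHECK=1 set,
    torch.cuda.is_available() uses an NVML-based probe, so CUDA stays
    uninitialized in this process and forked subprocesses can still
    set up their own CUDA contexts safely.
    """
    previous = os.environ.get("PYTORCH_NVML_BASED_CUDA_CHECK")
    try:
        os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = "1"
        available = torch.cuda.is_available()
    finally:
        # Restore the caller's environment regardless of the outcome.
        if previous is None:
            os.environ.pop("PYTORCH_NVML_BASED_CUDA_CHECK", None)
        else:
            os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = previous
    return available
```

Centralizing the check this way matters most in helpers like `get_backend()` and the `require_cuda` test decorator, where merely collecting tests should not leave a CUDA context behind in the parent process.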