1919
2020from .magics import load_ipython_extension
2121
22+ from cuda .bindings import runtime
2223from cuml .internals import logger
2324from cuml .internals .global_settings import GlobalSettings
2425from cuml .internals .memory_utils import set_global_output_type
25- from cuml .internals .safe_imports import UnavailableError
26+ from cuml .internals .safe_imports import UnavailableError , gpu_only_import
27+
28+ rmm = gpu_only_import ("rmm" )
2629
2730__all__ = ["load_ipython_extension" , "install" ]
2831
@@ -31,11 +34,38 @@ def _install_for_library(library_name):
3134 importlib .import_module (f"._wrappers.{ library_name } " , __name__ )
3235
3336
34- def install ():
37+ def _is_concurrent_managed_access_supported ():
38+ """Check the availability of concurrent managed access (UVM).
39+ Note that WSL2 does not support managed memory.
40+ """
41+
42+ # Ensure CUDA is initialized before checking cudaDevAttrConcurrentManagedAccess
43+ runtime .cudaFree (0 )
44+
45+ device_id = 0
46+ err , supports_managed_access = runtime .cudaDeviceGetAttribute (
47+ runtime .cudaDeviceAttr .cudaDevAttrConcurrentManagedAccess , device_id
48+ )
49+ if err != runtime .cudaError_t .cudaSuccess :
50+ logger .error (
51+ f"Failed to check cudaDevAttrConcurrentManagedAccess with error { err } "
52+ )
53+ return False
54+ return supports_managed_access != 0
55+
56+
57+ def install (disable_uvm = False ):
3558 """Enable cuML Accelerator Mode."""
3659 logger .set_level (logger .level_enum .info )
3760 logger .set_pattern ("%v" )
3861
62+ if not disable_uvm :
63+ if _is_concurrent_managed_access_supported ():
64+ logger .debug ("cuML: Enabling managed memory..." )
65+ rmm .mr .set_current_device_resource (rmm .mr .ManagedMemoryResource ())
66+ else :
67+ logger .warn ("cuML: Could not enable managed memory." )
68+
3969 logger .debug ("cuML: Installing accelerator..." )
4070 libraries_to_accelerate = ["sklearn" , "umap" , "hdbscan" ]
4171 accelerated_libraries = []
0 commit comments