Skip to content

Commit 37a97ed

Browse files
committed
Don't setup managed memory if rmm isn't using default settings
Previously `cuml.accel` would always install a new `ManagedMemoryResource` on for the current device when installed. We now only do this if rmm is using the default resource. If a user or library has already configured rmm, we continue using their settings. This lets `cuml.accel` compose better with `cudf.pandas`, since we won't clobber their memory resource setup after loading ours. Also added tests for the `cuml.accel` IPython magics, including tests loading `cudf.pandas` before and after `cuml.accel`.
1 parent 30b3971 commit 37a97ed

2 files changed

Lines changed: 115 additions & 3 deletions

File tree

python/cuml/cuml/accel/core.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,24 @@ def install(disable_uvm=False):
8686
if _is_concurrent_managed_access_supported():
8787
import rmm
8888

89-
logger.debug("cuML: Enabling managed memory...")
90-
rmm.mr.set_current_device_resource(rmm.mr.ManagedMemoryResource())
89+
mr = rmm.mr.get_current_device_resource()
90+
if isinstance(mr, rmm.mr.ManagedMemoryResource):
91+
# Nothing to do
92+
pass
93+
elif not isinstance(mr, rmm.mr.CudaMemoryResource):
94+
logger.debug(
95+
"cuML: A non-default memory resource is already configured, "
96+
"skipping enabling managed memory."
97+
)
98+
else:
99+
rmm.mr.set_current_device_resource(
100+
rmm.mr.ManagedMemoryResource()
101+
)
102+
logger.debug("cuML: Enabled managed memory.")
91103
else:
92-
logger.warn("cuML: Could not enable managed memory.")
104+
logger.debug(
105+
"cuML: Could not enable managed memory on this platform."
106+
)
93107

94108
ACCEL.install()
95109
set_global_output_type("numpy")
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import subprocess
16+
import sys
17+
from textwrap import dedent
18+
19+
import pytest
20+
21+
pytest.importorskip("IPython")
22+
23+
24+
SCRIPT_HEADER = """
25+
from IPython.core.interactiveshell import InteractiveShell
26+
from traitlets.config import Config
27+
c = Config()
28+
c.HistoryManager.hist_file = ":memory:"
29+
ip = InteractiveShell(config=c)
30+
"""
31+
32+
33+
def run_script(body):
34+
script = SCRIPT_HEADER + dedent(body)
35+
36+
res = subprocess.run(
37+
[sys.executable, "-c", script],
38+
stderr=subprocess.STDOUT,
39+
stdout=subprocess.PIPE,
40+
text=True,
41+
)
42+
# Pull out attributes before assert for nicer error reporting on failure
43+
returncode = res.returncode
44+
stdout = res.stdout
45+
assert returncode == 0, stdout
46+
47+
48+
def test_magic():
49+
run_script(
50+
"""
51+
ip.run_line_magic("load_ext", "cuml.accel")
52+
53+
# cuml.accel proxies setup properly
54+
ip.run_cell("from sklearn.linear_model import LinearRegression")
55+
ip.run_cell("from cuml.accel import is_proxy")
56+
ip.run_cell("assert is_proxy(LinearRegression)").raise_error()
57+
"""
58+
)
59+
60+
61+
def test_magic_cudf_pandas_before():
62+
run_script(
63+
"""
64+
ip.run_line_magic("load_ext", "cudf.pandas")
65+
ip.run_cell("import rmm; mr = rmm.mr.get_current_device_resource();")
66+
67+
ip.run_line_magic("load_ext", "cuml.accel")
68+
ip.run_cell("mr2 = rmm.mr.get_current_device_resource();")
69+
70+
# cuml doesn't change the mr setup by cudf.pandas
71+
ip.run_cell("assert mr is mr2").raise_error()
72+
73+
# cuml.accel proxies setup properly
74+
ip.run_cell("from sklearn.linear_model import LinearRegression")
75+
ip.run_cell("from cuml.accel import is_proxy")
76+
result = ip.run_cell("assert is_proxy(LinearRegression)").raise_error()
77+
"""
78+
)
79+
80+
81+
def test_magic_cudf_pandas_after():
82+
run_script(
83+
"""
84+
ip.run_line_magic("load_ext", "cuml.accel")
85+
ip.run_cell("import rmm; mr = rmm.mr.get_current_device_resource();")
86+
87+
ip.run_line_magic("load_ext", "cudf.pandas")
88+
ip.run_cell("mr2 = rmm.mr.get_current_device_resource();")
89+
90+
# cudf.pandas doesn't change the mr setup by cuml.accel
91+
ip.run_cell("assert mr is mr2").raise_error()
92+
93+
# cuml.accel proxies setup properly
94+
ip.run_cell("from sklearn.linear_model import LinearRegression")
95+
ip.run_cell("from cuml.accel import is_proxy")
96+
result = ip.run_cell("assert is_proxy(LinearRegression)").raise_error()
97+
"""
98+
)

0 commit comments

Comments
 (0)