Skip to content
Merged
91 changes: 59 additions & 32 deletions python/cuml/cuml/manifold/umap.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@

import warnings

_DATA_ON_HOST_DEPRECATED_MESSAGE = (
"The data_on_host option is deprecated and will be removed in release 25.10. "
"Whether data is on host or device is now determined by the build_algo."
)

import cupy
import cupyx.scipy.sparse
import joblib
Expand Down Expand Up @@ -303,8 +308,8 @@ class UMAP(Base,
but longer runtime.

- `nnd_n_clusters` (int, default=1): Number of clusters for data partitioning.
Higher values reduce memory usage at the cost of accuracy. When `nnd_n_clusters > 1`, data must be on host memory.
Refer to data_on_host argument for fit_transform function.
Higher values reduce memory usage at the cost of accuracy. When `nnd_n_clusters > 1`,
UMAP can process data larger than device memory.

- `nnd_overlap_factor` (int, default=2): Number of clusters each data point belongs to.
Valid only when `nnd_n_clusters > 1`. Must be < 'nnd_n_clusters'.
Expand Down Expand Up @@ -628,6 +633,8 @@ class UMAP(Base,
raise Exception("Invalid build algo: {}. Only support auto, brute_force_knn and nn_descent" % build_algo)

self.build_kwds = build_kwds
if self.build_kwds and self.build_kwds.get("nnd_n_clusters", 1) < 1:
raise ValueError("nnd_n_clusters must be >= 1")

def validate_hyperparams(self):

Expand Down Expand Up @@ -721,7 +728,7 @@ class UMAP(Base,
@generate_docstring(convert_dtype_cast='np.float32',
X='dense_sparse',
skip_parameters_heading=True)
def fit(self, X, y=None, *, convert_dtype=True, knn_graph=None, data_on_host=False) -> "UMAP":
def fit(self, X, y=None, *, convert_dtype=True, knn_graph=None, data_on_host="auto") -> "UMAP":
"""
Fit X into an embedded space.

Expand All @@ -737,9 +744,11 @@ class UMAP(Base,
should match the metric used to train the UMAP embeedings.
Takes precedence over the precomputed_knn parameter.

.. deprecated:: 25.06
Using `nnd_n_clusters>1` with data on device is deprecated in version 25.06
and will be removed in 25.08. Set `data_on_host=True` when `nnd_n_clusters>1`."
.. deprecated:: 25.08
The `data_on_host` parameter is deprecated and will be removed in release 25.10.
Whether data is on host or device is now determined by the `nnd_n_clusters` parameter.
When `build_algo == nn_descent`, data will automatically be placed on host memory.
When `build_algo == brute_force_knn`, data will automatically be placed on device memory.
"""
if len(X.shape) != 2:
raise ValueError("data should be two dimensional")
Expand All @@ -749,6 +758,16 @@ class UMAP(Base,
raise ValueError("Cannot provide a KNN graph when in \
semi-supervised mode with categorical target_metric for now.")

# Set build_algo based on n_rows
if self.build_algo == "auto":
if X.shape[0] <= 50000 or self.sparse_fit:
# brute force is faster for small datasets
logger.info("Building knn graph using brute force (configured from build_algo == 'auto')")
self.build_algo = "brute_force_knn"
else:
logger.info("Building knn graph using nn descent (configured from build_algo == 'auto')")
self.build_algo = "nn_descent"

# Handle sparse inputs
if is_sparse(X):

Expand All @@ -763,19 +782,12 @@ class UMAP(Base,
# Handle dense inputs
else:
self._sparse_data = False
if data_on_host:

# automatically put data on host for nn descent regardless of nnd_n_clusters
if self.build_algo == "nn_descent":
convert_to_mem_type = MemoryType.host
Comment thread
jcrist marked this conversation as resolved.
else:
build_kwds = self.build_kwds or {}
if build_kwds.get("nnd_n_clusters", 1) > 1:
warnings.warn(
("Using nnd_n_clusters>1 with data on device is deprecated in version 25.06"
" and will be removed in 25.08. Set data_on_host=True when nnd_n_clusters>1."),
FutureWarning,
)
convert_to_mem_type = MemoryType.host
else:
convert_to_mem_type = MemoryType.device
convert_to_mem_type = MemoryType.device

self._raw_data, self.n_rows, self.n_dims, _ = \
input_to_cuml_array(X, order='C', check_dtype=np.float32,
Expand All @@ -784,17 +796,30 @@ class UMAP(Base,
else None),
convert_to_mem_type=convert_to_mem_type)

if self.build_algo == "auto":
if self.n_rows <= 50000 or self.sparse_fit:
# brute force is faster for small datasets
logger.info("Building knn graph using brute force")
self.build_algo = "brute_force_knn"
else:
logger.info("Building knn graph using nn descent")
self.build_algo = "nn_descent"

if self.build_algo == "brute_force_knn" and data_on_host:
raise ValueError("Data cannot be on host for building with brute force knn")
# Get nnd_n_clusters value for validation
build_kwds = self.build_kwds or {}
nnd_n_clusters = build_kwds.get("nnd_n_clusters", 1)

# deprecation notice and raising error for data_on_host parameter
if data_on_host is True:
if self.build_algo == "brute_force_knn":
raise ValueError(
f"build_algo = 'brute_force_knn' is not supported when data_on_host is True; "
f"{_DATA_ON_HOST_DEPRECATED_MESSAGE}"
)
warnings.warn(_DATA_ON_HOST_DEPRECATED_MESSAGE, FutureWarning)
elif data_on_host is False:
if self.build_algo == "nn_descent" and nnd_n_clusters > 1:
raise ValueError(
f"nnd_n_clusters > 1 is not supported for nn_descent build when data_on_host is False; "
f"{_DATA_ON_HOST_DEPRECATED_MESSAGE}"
)
warnings.warn(_DATA_ON_HOST_DEPRECATED_MESSAGE, FutureWarning)
elif data_on_host != "auto":
raise ValueError(
f"data_on_host must be True, False, or 'auto'; "
f"{_DATA_ON_HOST_DEPRECATED_MESSAGE}"
)

if self.n_rows <= 1:
raise ValueError("There needs to be more than 1 sample to "
Expand Down Expand Up @@ -903,7 +928,7 @@ class UMAP(Base,
*,
convert_dtype=True,
knn_graph=None,
data_on_host=False,
data_on_host="auto",
) -> CumlArray:
"""
Fit X into an embedded space and return that transformed
Expand Down Expand Up @@ -936,9 +961,11 @@ class UMAP(Base,
Acceptable formats: sparse SciPy ndarray, CuPy device ndarray,
CSR/COO preferred other formats will go through conversion to CSR

.. deprecated:: 25.06
Using `nnd_n_clusters>1` with data on device is deprecated in version 25.06
and will be removed in 25.08. Set `data_on_host=True` when `nnd_n_clusters>1`."
.. deprecated:: 25.08
The `data_on_host` parameter is deprecated and will be removed in release 25.10.
Whether data is on host or device is now determined by the `nnd_n_clusters` parameter.
When `build_algo == nn_descent`, data will automatically be placed on host memory.
When `build_algo == brute_force_knn`, data will automatically be placed on device memory.
"""
self.fit(X, y, convert_dtype=convert_dtype, knn_graph=knn_graph, data_on_host=data_on_host)

Expand Down
89 changes: 70 additions & 19 deletions python/cuml/cuml/tests/test_umap.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,13 +832,12 @@ def test_umap_distance_metrics_fit_transform_trust_on_sparse_input(
assert array_equal(umap_trust, cuml_trust, 0.05, with_sign=True)


@pytest.mark.parametrize("data_on_host", [True, False])
@pytest.mark.parametrize("num_clusters", [0, 3, 5])
@pytest.mark.parametrize("num_clusters", [3, 5])
@pytest.mark.parametrize("fit_then_transform", [False, True])
@pytest.mark.parametrize("metric", ["l2", "sqeuclidean", "cosine"])
@pytest.mark.parametrize("do_snmg", [True, False])
def test_umap_trustworthiness_on_batch_nnd(
data_on_host, num_clusters, fit_then_transform, metric, do_snmg
num_clusters, fit_then_transform, metric, do_snmg
):

digits = datasets.load_digits()
Expand All @@ -856,34 +855,86 @@ def test_umap_trustworthiness_on_batch_nnd(
metric=metric,
)

if fit_then_transform:
cuml_model.fit(digits.data, convert_dtype=True)
cuml_embedding = cuml_model.transform(digits.data)
else:
cuml_embedding = cuml_model.fit_transform(
digits.data, convert_dtype=True
)

cuml_trust = trustworthiness(
digits.data, cuml_embedding, n_neighbors=10, metric=metric
)

assert cuml_trust > 0.9


@pytest.mark.parametrize("data_on_host", [True, False, "auto", None])
@pytest.mark.parametrize("num_clusters", [0, 1, 5])
@pytest.mark.parametrize(
"build_algo,n_rows",
[
("brute_force_knn", 5000),
("nn_descent", 5000),
("auto", 5000), # results in brute_force_knn
# ("auto", 51000), # results in nn_descent, passes tests but trustworthiness takes long to run
],
)
def test_umap_param_handling(data_on_host, num_clusters, build_algo, n_rows):

data, _ = make_blobs(
n_samples=n_rows, n_features=64, centers=5, random_state=0
)

def run_umap():
if fit_then_transform:
cuml_model.fit(
digits.data, convert_dtype=True, data_on_host=data_on_host
)
cuml_embedding = cuml_model.transform(digits.data)
cuml_model = cuUMAP(
n_neighbors=10,
min_dist=0.01,
build_algo=build_algo,
build_kwds={"nnd_n_clusters": num_clusters},
metric="l2",
)

if data_on_host is None:
cuml_embedding = cuml_model.fit_transform(data, convert_dtype=True)
else:
cuml_embedding = cuml_model.fit_transform(
digits.data, convert_dtype=True, data_on_host=data_on_host
data, convert_dtype=True, data_on_host=data_on_host
)

return cuml_embedding

# num clusters should be >= 1
if num_clusters == 0:
# eventual build_algo when given auto
configured_build_algo = build_algo
if configured_build_algo in ["auto", None]:
if n_rows <= 50000:
configured_build_algo = "brute_force_knn"
else:
configured_build_algo = "nn_descent"

if (
num_clusters == 0
or (
configured_build_algo == "brute_force_knn" and data_on_host is True
)
or (
configured_build_algo == "nn_descent"
and num_clusters > 1
and data_on_host is False
)
):
with pytest.raises(ValueError):
run_umap()
return

# data should be on host if batching (num_clusters > 1)
if num_clusters > 1 and not data_on_host:
with pytest.raises(Exception):
run_umap()
return
if data_on_host in [True, False]:
with pytest.warns(FutureWarning):
cuml_embedding = run_umap()
else:
cuml_embedding = run_umap()

cuml_embedding = run_umap()
cuml_trust = trustworthiness(
digits.data, cuml_embedding, n_neighbors=10, metric=metric
data, cuml_embedding, n_neighbors=10, metric="l2"
)

assert cuml_trust > 0.9
Expand Down
Loading