Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions python/cuml/cuml/manifold/simpl_set.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ cupyx = gpu_only_import('cupyx')

from cuml.manifold.umap_utils cimport *
from cuml.manifold.umap_utils import GraphHolder, find_ab_params, \
metric_parsing
coerce_metric

from cuml.internals.input_utils import input_to_cuml_array, is_array_like
from cuml.internals.array import CumlArray
Expand Down Expand Up @@ -158,10 +158,7 @@ def fuzzy_simplicial_set(X,
umap_params.deterministic = <bool> deterministic
umap_params.set_op_mix_ratio = <float> set_op_mix_ratio
umap_params.local_connectivity = <float> local_connectivity
try:
umap_params.metric = metric_parsing[metric.lower()]
except KeyError:
raise ValueError(f"Invalid value for metric: {metric}")
umap_params.metric = coerce_metric(metric)
if metric_kwds is None:
umap_params.p = <float> 2.0
else:
Expand Down Expand Up @@ -353,10 +350,8 @@ def simplicial_set_embedding(
umap_params.init = <int> 1
else:
raise ValueError("Invalid initialization strategy")
try:
umap_params.metric = metric_parsing[metric.lower()]
except KeyError:
raise ValueError(f"Invalid value for metric: {metric}")

umap_params.metric = coerce_metric(metric)
if metric_kwds is None:
umap_params.p = <float> 2.0
else:
Expand Down
66 changes: 26 additions & 40 deletions python/cuml/cuml/manifold/umap.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,8 @@ IF GPUBUILD == 1:
from libc.stdlib cimport free
from cuml.manifold.umap_utils cimport *
from pylibraft.common.handle cimport handle_t
from cuml.manifold.umap_utils import GraphHolder, find_ab_params, \
metric_parsing, DENSE_SUPPORTED_METRICS, SPARSE_SUPPORTED_METRICS

from cuml.manifold.simpl_set import fuzzy_simplicial_set, \
simplicial_set_embedding
from cuml.manifold.umap_utils import GraphHolder, find_ab_params, coerce_metric
from cuml.manifold.simpl_set import fuzzy_simplicial_set, simplicial_set_embedding

cdef extern from "cuml/manifold/umap.hpp" namespace "ML::UMAP":

Expand Down Expand Up @@ -483,56 +480,45 @@ class UMAP(UniversalBase,
umap_params.verbosity = <level_enum> self.verbose
umap_params.a = <float> self.a
umap_params.b = <float> self.b
umap_params.target_n_neighbors = <int> self.target_n_neighbors
umap_params.target_weight = <float> self.target_weight
umap_params.random_state = <uint64_t> check_random_seed(self.random_state)
umap_params.deterministic = <bool> self.deterministic

if self.init == "spectral":
umap_params.init = <int> 1
else: # self.init == "random"
umap_params.init = <int> 0
umap_params.target_n_neighbors = <int> self.target_n_neighbors

if self.target_metric == "euclidean":
umap_params.target_metric = MetricType.EUCLIDEAN
else: # self.target_metric == "categorical"
umap_params.target_metric = MetricType.CATEGORICAL
if self.build_algo == "brute_force_knn":
umap_params.build_algo = graph_build_algo.BRUTE_FORCE_KNN
else: # self.init == "nn_descent"
umap_params.build_algo = graph_build_algo.NN_DESCENT
if self.build_kwds is None:
umap_params.nn_descent_params.graph_degree = <uint64_t> 64
umap_params.nn_descent_params.intermediate_graph_degree = <uint64_t> 128
umap_params.nn_descent_params.max_iterations = <uint64_t> 20
umap_params.nn_descent_params.termination_threshold = <float> 0.0001
umap_params.nn_descent_params.return_distances = <bool> True
umap_params.nn_descent_params.n_clusters = <uint64_t> 1
else:
umap_params.nn_descent_params.graph_degree = <uint64_t> self.build_kwds.get("nnd_graph_degree", 64)
umap_params.nn_descent_params.intermediate_graph_degree = <uint64_t> self.build_kwds.get("nnd_intermediate_graph_degree", 128)
umap_params.nn_descent_params.max_iterations = <uint64_t> self.build_kwds.get("nnd_max_iterations", 20)
umap_params.nn_descent_params.termination_threshold = <float> self.build_kwds.get("nnd_termination_threshold", 0.0001)
umap_params.nn_descent_params.return_distances = <bool> self.build_kwds.get("nnd_return_distances", True)
if self.build_kwds.get("nnd_n_clusters", 1) < 1:
logger.info("Negative number of nnd_n_clusters not allowed. Changing nnd_n_clusters to 1")
umap_params.nn_descent_params.n_clusters = <uint64_t> self.build_kwds.get("nnd_n_clusters", 1)

umap_params.target_weight = <float> self.target_weight
umap_params.random_state = <uint64_t> check_random_seed(self.random_state)
umap_params.deterministic = <bool> self.deterministic

try:
umap_params.metric = metric_parsing[self.metric.lower()]
if sparse:
if umap_params.metric not in SPARSE_SUPPORTED_METRICS:
raise NotImplementedError(f"Metric '{self.metric}' not supported for sparse inputs.")
elif umap_params.metric not in DENSE_SUPPORTED_METRICS:
raise NotImplementedError(f"Metric '{self.metric}' not supported for dense inputs.")

except KeyError:
raise ValueError(f"Invalid value for metric: {self.metric}")
umap_params.metric = coerce_metric(
self.metric, sparse=sparse, build_algo=self.build_algo
)

if self.metric_kwds is None:
umap_params.p = <float> 2.0
else:
umap_params.p = <float>self.metric_kwds.get('p')

if self.build_algo == "brute_force_knn":
umap_params.build_algo = graph_build_algo.BRUTE_FORCE_KNN
else:
Comment thread
jcrist marked this conversation as resolved.
umap_params.build_algo = graph_build_algo.NN_DESCENT
build_kwds = self.build_kwds or {}
umap_params.nn_descent_params.graph_degree = <uint64_t> build_kwds.get("nnd_graph_degree", 64)
umap_params.nn_descent_params.intermediate_graph_degree = <uint64_t> build_kwds.get("nnd_intermediate_graph_degree", 128)
umap_params.nn_descent_params.max_iterations = <uint64_t> build_kwds.get("nnd_max_iterations", 20)
umap_params.nn_descent_params.termination_threshold = <float> build_kwds.get("nnd_termination_threshold", 0.0001)
umap_params.nn_descent_params.return_distances = <bool> build_kwds.get("nnd_return_distances", True)
umap_params.nn_descent_params.n_clusters = <uint64_t> build_kwds.get("nnd_n_clusters", 1)
# Forward metric & metric_kwds to nn_descent
umap_params.nn_descent_params.metric = <RaftDistanceType> umap_params.metric
umap_params.nn_descent_params.metric_arg = umap_params.p
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The actual fix is here (plumbing through the metric options to nn_descent_params). In the long run we should redesign the C++ layer to remove duplicate options - for now just ensuring they're forwarded correctly seems sufficient.

Everything else here is a simplification of the current pre-existing code.


cdef uintptr_t callback_ptr = 0
if self.callback:
callback_ptr = self.callback.get_native_callback()
Expand Down
4 changes: 4 additions & 0 deletions python/cuml/cuml/manifold/umap_utils.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ from libc.stdint cimport uint64_t, uintptr_t, int64_t
from libcpp cimport bool
from libcpp.memory cimport shared_ptr
from cuml.metrics.distance_type cimport DistanceType
from cuml.metrics.raft_distance_type cimport DistanceType as RaftDistanceType
from cuml.internals.logger cimport level_enum

cdef extern from "cuml/manifold/umapparams.h" namespace "ML::UMAPParams":
Expand All @@ -39,6 +40,7 @@ cdef extern from "cuml/common/callback.hpp" namespace "ML::Internals":

cdef cppclass GraphBasedDimRedCallback


cdef extern from "raft/neighbors/nn_descent_types.hpp" namespace "raft::neighbors::experimental::nn_descent":
cdef struct index_params:
uint64_t graph_degree,
Expand All @@ -47,6 +49,8 @@ cdef extern from "raft/neighbors/nn_descent_types.hpp" namespace "raft::neighbor
float termination_threshold,
bool return_distances,
uint64_t n_clusters,
RaftDistanceType metric,
float metric_arg

cdef extern from "cuml/manifold/umapparams.h" namespace "ML":

Expand Down
89 changes: 60 additions & 29 deletions python/cuml/cuml/manifold/umap_utils.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

# distutils: language = c++

from typing import Literal

from rmm.pylibrmm.memory_resource cimport get_current_device_resource
from pylibraft.common.handle cimport handle_t
from cuml.manifold.umap_utils cimport *
Expand Down Expand Up @@ -134,7 +136,7 @@ def find_ab_params(spread, min_dist):
return params[0], params[1]


metric_parsing = {
_METRICS = {
"l2": DistanceType.L2SqrtExpanded,
"euclidean": DistanceType.L2SqrtExpanded,
"sqeuclidean": DistanceType.L2Expanded,
Expand All @@ -153,32 +155,61 @@ metric_parsing = {
"canberra": DistanceType.Canberra
}

_SUPPORTED_METRICS = {
"nn_descent": {
"sparse": frozenset(),
"dense": frozenset((DistanceType.L2SqrtExpanded,))
},
"brute_force_knn": {
"sparse": frozenset((
DistanceType.Canberra,
DistanceType.CorrelationExpanded,
DistanceType.CosineExpanded,
DistanceType.HammingUnexpanded,
DistanceType.HellingerExpanded,
DistanceType.JaccardExpanded,
DistanceType.L1,
DistanceType.L2SqrtExpanded,
DistanceType.L2Expanded,
DistanceType.Linf,
DistanceType.LpUnexpanded,
)),
"dense": frozenset((
DistanceType.Canberra,
DistanceType.CorrelationExpanded,
DistanceType.CosineExpanded,
DistanceType.HammingUnexpanded,
DistanceType.HellingerExpanded,
# DistanceType.JaccardExpanded, # not supported
DistanceType.L1,
DistanceType.L2SqrtExpanded,
DistanceType.L2Expanded,
DistanceType.Linf,
DistanceType.LpUnexpanded,
))
}
}


DENSE_SUPPORTED_METRICS = [
DistanceType.Canberra,
DistanceType.CorrelationExpanded,
DistanceType.CosineExpanded,
DistanceType.HammingUnexpanded,
DistanceType.HellingerExpanded,
# DistanceType.JaccardExpanded, # not supported
DistanceType.L1,
DistanceType.L2SqrtExpanded,
DistanceType.L2Expanded,
DistanceType.Linf,
DistanceType.LpUnexpanded,
]


SPARSE_SUPPORTED_METRICS = [
DistanceType.Canberra,
DistanceType.CorrelationExpanded,
DistanceType.CosineExpanded,
DistanceType.HammingUnexpanded,
DistanceType.HellingerExpanded,
DistanceType.JaccardExpanded,
DistanceType.L1,
DistanceType.L2SqrtExpanded,
DistanceType.L2Expanded,
DistanceType.Linf,
DistanceType.LpUnexpanded,
]
def coerce_metric(
metric: str,
sparse: bool = False,
build_algo: Literal["brute_force_knn", "nn_descent"] = "brute_force_knn",
) -> DistanceType:
"""Coerce a metric string to a `DistanceType`.

Also checks that the metric is valid and supported.
"""
try:
out = _METRICS[metric.lower()]
except KeyError:
raise ValueError(f"Invalid value for metric: {metric!r}")

kind = "sparse" if sparse else "dense"
supported = _SUPPORTED_METRICS[build_algo][kind]
if out not in supported:
raise NotImplementedError(
f"Metric {metric!r} not supported for {kind} inputs with {build_algo=}"
)

return out
Loading