-
Notifications
You must be signed in to change notification settings - Fork 623
Plumb metric and metric_kwds through to UMAP with nn_descent
#6304
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7130db3
2e99294
e0542df
e65b152
dfd6c65
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -74,11 +74,8 @@ IF GPUBUILD == 1: | |
| from libc.stdlib cimport free | ||
| from cuml.manifold.umap_utils cimport * | ||
| from pylibraft.common.handle cimport handle_t | ||
| from cuml.manifold.umap_utils import GraphHolder, find_ab_params, \ | ||
| metric_parsing, DENSE_SUPPORTED_METRICS, SPARSE_SUPPORTED_METRICS | ||
|
|
||
| from cuml.manifold.simpl_set import fuzzy_simplicial_set, \ | ||
| simplicial_set_embedding | ||
| from cuml.manifold.umap_utils import GraphHolder, find_ab_params, coerce_metric | ||
| from cuml.manifold.simpl_set import fuzzy_simplicial_set, simplicial_set_embedding | ||
|
|
||
| cdef extern from "cuml/manifold/umap.hpp" namespace "ML::UMAP": | ||
|
|
||
|
|
@@ -483,56 +480,45 @@ class UMAP(UniversalBase, | |
| umap_params.verbosity = <level_enum> self.verbose | ||
| umap_params.a = <float> self.a | ||
| umap_params.b = <float> self.b | ||
| umap_params.target_n_neighbors = <int> self.target_n_neighbors | ||
| umap_params.target_weight = <float> self.target_weight | ||
| umap_params.random_state = <uint64_t> check_random_seed(self.random_state) | ||
| umap_params.deterministic = <bool> self.deterministic | ||
|
|
||
| if self.init == "spectral": | ||
| umap_params.init = <int> 1 | ||
| else: # self.init == "random" | ||
| umap_params.init = <int> 0 | ||
| umap_params.target_n_neighbors = <int> self.target_n_neighbors | ||
|
|
||
| if self.target_metric == "euclidean": | ||
| umap_params.target_metric = MetricType.EUCLIDEAN | ||
| else: # self.target_metric == "categorical" | ||
| umap_params.target_metric = MetricType.CATEGORICAL | ||
| if self.build_algo == "brute_force_knn": | ||
| umap_params.build_algo = graph_build_algo.BRUTE_FORCE_KNN | ||
| else: # self.init == "nn_descent" | ||
| umap_params.build_algo = graph_build_algo.NN_DESCENT | ||
| if self.build_kwds is None: | ||
| umap_params.nn_descent_params.graph_degree = <uint64_t> 64 | ||
| umap_params.nn_descent_params.intermediate_graph_degree = <uint64_t> 128 | ||
| umap_params.nn_descent_params.max_iterations = <uint64_t> 20 | ||
| umap_params.nn_descent_params.termination_threshold = <float> 0.0001 | ||
| umap_params.nn_descent_params.return_distances = <bool> True | ||
| umap_params.nn_descent_params.n_clusters = <uint64_t> 1 | ||
| else: | ||
| umap_params.nn_descent_params.graph_degree = <uint64_t> self.build_kwds.get("nnd_graph_degree", 64) | ||
| umap_params.nn_descent_params.intermediate_graph_degree = <uint64_t> self.build_kwds.get("nnd_intermediate_graph_degree", 128) | ||
| umap_params.nn_descent_params.max_iterations = <uint64_t> self.build_kwds.get("nnd_max_iterations", 20) | ||
| umap_params.nn_descent_params.termination_threshold = <float> self.build_kwds.get("nnd_termination_threshold", 0.0001) | ||
| umap_params.nn_descent_params.return_distances = <bool> self.build_kwds.get("nnd_return_distances", True) | ||
| if self.build_kwds.get("nnd_n_clusters", 1) < 1: | ||
| logger.info("Negative number of nnd_n_clusters not allowed. Changing nnd_n_clusters to 1") | ||
| umap_params.nn_descent_params.n_clusters = <uint64_t> self.build_kwds.get("nnd_n_clusters", 1) | ||
|
|
||
| umap_params.target_weight = <float> self.target_weight | ||
| umap_params.random_state = <uint64_t> check_random_seed(self.random_state) | ||
| umap_params.deterministic = <bool> self.deterministic | ||
|
|
||
| try: | ||
| umap_params.metric = metric_parsing[self.metric.lower()] | ||
| if sparse: | ||
| if umap_params.metric not in SPARSE_SUPPORTED_METRICS: | ||
| raise NotImplementedError(f"Metric '{self.metric}' not supported for sparse inputs.") | ||
| elif umap_params.metric not in DENSE_SUPPORTED_METRICS: | ||
| raise NotImplementedError(f"Metric '{self.metric}' not supported for dense inputs.") | ||
|
|
||
| except KeyError: | ||
| raise ValueError(f"Invalid value for metric: {self.metric}") | ||
| umap_params.metric = coerce_metric( | ||
| self.metric, sparse=sparse, build_algo=self.build_algo | ||
| ) | ||
|
|
||
| if self.metric_kwds is None: | ||
| umap_params.p = <float> 2.0 | ||
| else: | ||
| umap_params.p = <float>self.metric_kwds.get('p') | ||
|
|
||
| if self.build_algo == "brute_force_knn": | ||
| umap_params.build_algo = graph_build_algo.BRUTE_FORCE_KNN | ||
| else: | ||
| umap_params.build_algo = graph_build_algo.NN_DESCENT | ||
| build_kwds = self.build_kwds or {} | ||
| umap_params.nn_descent_params.graph_degree = <uint64_t> build_kwds.get("nnd_graph_degree", 64) | ||
| umap_params.nn_descent_params.intermediate_graph_degree = <uint64_t> build_kwds.get("nnd_intermediate_graph_degree", 128) | ||
| umap_params.nn_descent_params.max_iterations = <uint64_t> build_kwds.get("nnd_max_iterations", 20) | ||
| umap_params.nn_descent_params.termination_threshold = <float> build_kwds.get("nnd_termination_threshold", 0.0001) | ||
| umap_params.nn_descent_params.return_distances = <bool> build_kwds.get("nnd_return_distances", True) | ||
| umap_params.nn_descent_params.n_clusters = <uint64_t> build_kwds.get("nnd_n_clusters", 1) | ||
| # Forward metric & metric_kwds to nn_descent | ||
| umap_params.nn_descent_params.metric = <RaftDistanceType> umap_params.metric | ||
| umap_params.nn_descent_params.metric_arg = umap_params.p | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The actual fix is here (plumbing through the metric options to Everything else here is a simplification of the current pre-existing code. |
||
|
|
||
| cdef uintptr_t callback_ptr = 0 | ||
| if self.callback: | ||
| callback_ptr = self.callback.get_native_callback() | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.