Skip to content

Commit ba3ea5c

Browse files
committed
Remove KMeansMG
I _think_ we can fully remove `KMeansMG`. As is, KMeansMG is a thin wrapper around `KMeans` itself, with just the `fit` method reimplemented. Looking at the implementation though, all it does is call `cuvs::cluster::kmeans::fit` (with much less input validation than it should) followed by `cuvs::cluster::kmeans::predict` instead of a single call to `cuvs::cluster::kmeans::fit_predict` (like `KMeans` does). Reading through the cuvs docs, I don't see a strong reason why we can't just use `fit_predict` everywhere. Ripping out `KMeansMG` does lead all tests to pass.
1 parent 87d7bb7 commit ba3ea5c

3 files changed

Lines changed: 2 additions & 233 deletions

File tree

python/cuml/cuml/cluster/CMakeLists.txt

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,6 @@ add_module_gpu_default("agglomerative.pyx" ${agglomerative_algo} ${cluster_algo}
1717
add_module_gpu_default("dbscan.pyx" ${dbscan_algo} ${cluster_algo})
1818
add_module_gpu_default("kmeans.pyx" ${kmeans_algo} ${cluster_algo})
1919

20-
if(NOT SINGLEGPU)
21-
list(APPEND cython_sources
22-
kmeans_mg.pyx
23-
)
24-
endif()
25-
2620
add_subdirectory(hdbscan)
2721

2822
rapids_cython_create_modules(

python/cuml/cuml/cluster/kmeans_mg.pyx

Lines changed: 0 additions & 225 deletions
This file was deleted.

python/cuml/cuml/dask/cluster/kmeans.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def __init__(self, *, client=None, verbose=False, **kwargs):
9898
@staticmethod
9999
@mnmg_import
100100
def _func_fit(sessionId, objs, datatype, has_weights, **kwargs):
101-
from cuml.cluster.kmeans_mg import KMeansMG as cumlKMeans
101+
from cuml.cluster.kmeans import KMeans
102102

103103
handle = get_raft_comm_state(sessionId, get_worker())["handle"]
104104

@@ -109,7 +109,7 @@ def _func_fit(sessionId, objs, datatype, has_weights, **kwargs):
109109
inp_data = concatenate([X for X, weights in objs])
110110
inp_weights = concatenate([weights for X, weights in objs])
111111

112-
return cumlKMeans(handle=handle, output_type=datatype, **kwargs).fit(
112+
return KMeans(handle=handle, output_type=datatype, **kwargs).fit(
113113
inp_data, sample_weight=inp_weights
114114
)
115115

0 commit comments

Comments
 (0)