From e416630ca9c7b2f0b4e2c35958eb8322361b3e0c Mon Sep 17 00:00:00 2001 From: vic Date: Wed, 18 Mar 2026 15:19:30 +0100 Subject: [PATCH 1/6] Fix OOM in Dask KMeans by collecting only one model after fit --- python/cuml/cuml/dask/cluster/kmeans.py | 28 +++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/python/cuml/cuml/dask/cluster/kmeans.py b/python/cuml/cuml/dask/cluster/kmeans.py index 31148cf952..e90d8532c6 100644 --- a/python/cuml/cuml/dask/cluster/kmeans.py +++ b/python/cuml/cuml/dask/cluster/kmeans.py @@ -159,10 +159,30 @@ def fit(self, X, sample_weight=None): comms.destroy() - models = [res.result() for res in kmeans_fit] - first = models[0] - first.labels_ = cp.concatenate([model.labels_ for model in models]) - first.inertia_ = sum(model.inertia_ for model in models) + # Collect the full model from only the first worker (for + # cluster_centers_ etc). Extract labels_ and inertia_ from the + # remaining workers remotely to avoid pulling N redundant copies + # of cluster_centers_ back to the client. 
+ first = kmeans_fit[0].result() + + remote_labels = [ + self.client.submit(getattr, f, "labels_", workers=[w]) + for f, (w, _) in zip( + kmeans_fit[1:], list(data.worker_to_parts.items())[1:] + ) + ] + remote_inertias = [ + self.client.submit(getattr, f, "inertia_", workers=[w]) + for f, (w, _) in zip( + kmeans_fit[1:], list(data.worker_to_parts.items())[1:] + ) + ] + + all_labels = [first.labels_] + self.client.gather(remote_labels) + all_inertias = [first.inertia_] + self.client.gather(remote_inertias) + + first.labels_ = cp.concatenate(all_labels) + first.inertia_ = sum(all_inertias) self._set_internal_model(first) return self From 55e35638ba645a379ada7f451953273a1eec0ec9 Mon Sep 17 00:00:00 2001 From: vic Date: Thu, 19 Mar 2026 11:45:16 +0100 Subject: [PATCH 2/6] extract labels and inertia as a single unified Dask task --- python/cuml/cuml/dask/cluster/kmeans.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/python/cuml/cuml/dask/cluster/kmeans.py b/python/cuml/cuml/dask/cluster/kmeans.py index e90d8532c6..3fa29a18c1 100644 --- a/python/cuml/cuml/dask/cluster/kmeans.py +++ b/python/cuml/cuml/dask/cluster/kmeans.py @@ -165,21 +165,18 @@ def fit(self, X, sample_weight=None): # of cluster_centers_ back to the client. 
first = kmeans_fit[0].result() - remote_labels = [ - self.client.submit(getattr, f, "labels_", workers=[w]) - for f, (w, _) in zip( - kmeans_fit[1:], list(data.worker_to_parts.items())[1:] + remote_results = [ + self.client.submit( + lambda m: (m.labels_, m.inertia_), f, workers=[w] ) - ] - remote_inertias = [ - self.client.submit(getattr, f, "inertia_", workers=[w]) for f, (w, _) in zip( kmeans_fit[1:], list(data.worker_to_parts.items())[1:] ) ] - all_labels = [first.labels_] + self.client.gather(remote_labels) - all_inertias = [first.inertia_] + self.client.gather(remote_inertias) + results = self.client.gather(remote_results) + all_labels = [first.labels_] + [r[0] for r in results] + all_inertias = [first.inertia_] + [r[1] for r in results] first.labels_ = cp.concatenate(all_labels) first.inertia_ = sum(all_inertias) From 1acf402d437517be13d70a6b96884f91da54b56a Mon Sep 17 00:00:00 2001 From: vic Date: Thu, 19 Mar 2026 11:53:17 +0100 Subject: [PATCH 3/6] labels_ attribute as Dask array --- python/cuml/cuml/dask/cluster/kmeans.py | 53 ++++++++++++++++++------- 1 file changed, 38 insertions(+), 15 deletions(-) diff --git a/python/cuml/cuml/dask/cluster/kmeans.py b/python/cuml/cuml/dask/cluster/kmeans.py index 3fa29a18c1..7f9e464c62 100644 --- a/python/cuml/cuml/dask/cluster/kmeans.py +++ b/python/cuml/cuml/dask/cluster/kmeans.py @@ -3,6 +3,8 @@ # import cupy as cp +import dask +import dask.array as da from dask.distributed import get_worker from raft_dask.common.comms import Comms, get_raft_comm_state @@ -160,26 +162,47 @@ def fit(self, X, sample_weight=None): comms.destroy() # Collect the full model from only the first worker (for - # cluster_centers_ etc). Extract labels_ and inertia_ from the - # remaining workers remotely to avoid pulling N redundant copies - # of cluster_centers_ back to the client. + # cluster_centers_ etc). 
Since cluster centers are synchronized + # via NCCL, all workers have identical copies — pulling more + # than one would waste memory (N * n_clusters * n_features * 4B). + # + # Labels stay distributed as a dask.array to avoid transferring + # per-sample data to the client. Only the scalar inertia values + # are gathered. first = kmeans_fit[0].result() + workers = list(data.worker_to_parts.keys()) - remote_results = [ - self.client.submit( - lambda m: (m.labels_, m.inertia_), f, workers=[w] - ) - for f, (w, _) in zip( - kmeans_fit[1:], list(data.worker_to_parts.items())[1:] - ) + remote_labels = [ + self.client.submit(getattr, f, "labels_", workers=[w]) + for f, w in zip(kmeans_fit[1:], workers[1:]) + ] + remote_inertias = [ + self.client.submit(getattr, f, "inertia_", workers=[w]) + for f, w in zip(kmeans_fit[1:], workers[1:]) ] - results = self.client.gather(remote_results) - all_labels = [first.labels_] + [r[0] for r in results] - all_inertias = [first.inertia_] + [r[1] for r in results] + first.inertia_ += sum(self.client.gather(remote_inertias)) - first.labels_ = cp.concatenate(all_labels) - first.inertia_ = sum(all_inertias) + labels_dtype = first.labels_.dtype + label_chunks = [ + da.from_delayed( + dask.delayed(first.labels_, pure=True, traverse=False), + shape=(first.labels_.shape[0],), + dtype=labels_dtype, + meta=cp.zeros(0, dtype=labels_dtype), + ) + ] + [ + da.from_delayed( + f, + shape=(float("nan"),), + dtype=labels_dtype, + meta=cp.zeros(0, dtype=labels_dtype), + ) + for f in remote_labels + ] + first.labels_ = da.concatenate( + label_chunks, allow_unknown_chunksizes=True + ) self._set_internal_model(first) return self From dfd461a15cef7d34bbb0fd8c0b9ec4dadb564886 Mon Sep 17 00:00:00 2001 From: vic Date: Thu, 19 Mar 2026 16:13:35 +0100 Subject: [PATCH 4/6] storing attributes directly on Dask estimator --- python/cuml/cuml/dask/cluster/kmeans.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git 
a/python/cuml/cuml/dask/cluster/kmeans.py b/python/cuml/cuml/dask/cluster/kmeans.py index 7f9e464c62..22bfd18f37 100644 --- a/python/cuml/cuml/dask/cluster/kmeans.py +++ b/python/cuml/cuml/dask/cluster/kmeans.py @@ -181,7 +181,9 @@ def fit(self, X, sample_weight=None): for f, w in zip(kmeans_fit[1:], workers[1:]) ] - first.inertia_ += sum(self.client.gather(remote_inertias)) + self.inertia_ = first.inertia_ + sum( + self.client.gather(remote_inertias) + ) labels_dtype = first.labels_.dtype label_chunks = [ @@ -200,9 +202,10 @@ def fit(self, X, sample_weight=None): ) for f in remote_labels ] - first.labels_ = da.concatenate( + self.labels_ = da.concatenate( label_chunks, allow_unknown_chunksizes=True ) + self._set_internal_model(first) return self From fe977b79f30828daa21d881ff71a7c0703511831 Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Thu, 19 Mar 2026 12:13:01 -0500 Subject: [PATCH 5/6] Fixup --- python/cuml/cuml/dask/cluster/kmeans.py | 58 ++++++++++--------------- 1 file changed, 23 insertions(+), 35 deletions(-) diff --git a/python/cuml/cuml/dask/cluster/kmeans.py b/python/cuml/cuml/dask/cluster/kmeans.py index 22bfd18f37..e7e90a5582 100644 --- a/python/cuml/cuml/dask/cluster/kmeans.py +++ b/python/cuml/cuml/dask/cluster/kmeans.py @@ -3,7 +3,6 @@ # import cupy as cp -import dask import dask.array as da from dask.distributed import get_worker from raft_dask.common.comms import Comms, get_raft_comm_state @@ -19,6 +18,10 @@ from cuml.internals.validation import check_random_seed +def _get_inertia_and_n_samples(estimator): + return (estimator.inertia_, len(estimator.labels_)) + + class KMeans(BaseEstimator, DelayedPredictionMixin, DelayedTransformMixin): """ Multi-Node Multi-GPU implementation of KMeans. @@ -165,49 +168,34 @@ def fit(self, X, sample_weight=None): # cluster_centers_ etc). Since cluster centers are synchronized # via NCCL, all workers have identical copies — pulling more # than one would waste memory (N * n_clusters * n_features * 4B). 
- # - # Labels stay distributed as a dask.array to avoid transferring - # per-sample data to the client. Only the scalar inertia values - # are gathered. first = kmeans_fit[0].result() - workers = list(data.worker_to_parts.keys()) + self._set_internal_model(first) - remote_labels = [ - self.client.submit(getattr, f, "labels_", workers=[w]) - for f, w in zip(kmeans_fit[1:], workers[1:]) - ] - remote_inertias = [ - self.client.submit(getattr, f, "inertia_", workers=[w]) - for f, w in zip(kmeans_fit[1:], workers[1:]) - ] + workers = list(data.worker_to_parts.keys()) - self.inertia_ = first.inertia_ + sum( - self.client.gather(remote_inertias) + # Compute and store the total inertia_ + inertia_and_lengths = self.client.gather( + [ + self.client.submit(_get_inertia_and_n_samples, f, workers=[w]) + for f, w in zip(kmeans_fit, workers) + ] ) + self.inertia_ = sum(inertia for inertia, _ in inertia_and_lengths) - labels_dtype = first.labels_.dtype - label_chunks = [ - da.from_delayed( - dask.delayed(first.labels_, pure=True, traverse=False), - shape=(first.labels_.shape[0],), - dtype=labels_dtype, - meta=cp.zeros(0, dtype=labels_dtype), - ) - ] + [ - da.from_delayed( - f, - shape=(float("nan"),), - dtype=labels_dtype, - meta=cp.zeros(0, dtype=labels_dtype), - ) - for f in remote_labels + # Store labels_ as a distributed dask array. This attribute scales with + # n_samples, and shouldn't be pulled back to a local node. 
+ labels_meta = cp.zeros(0, dtype=first.labels_.dtype) + labels = [ + self.client.submit(getattr, f, "labels_", workers=[w]) + for f, w in zip(kmeans_fit, workers) ] self.labels_ = da.concatenate( - label_chunks, allow_unknown_chunksizes=True + [ + da.from_delayed(f, shape=(length,), meta=labels_meta) + for f, (_, length) in zip(labels, inertia_and_lengths) + ] ) - self._set_internal_model(first) - return self def fit_predict(self, X, sample_weight=None, delayed=True): From e6fbddd3dd7e1309b73bc5bc348124d509dfe841 Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Thu, 19 Mar 2026 12:54:35 -0500 Subject: [PATCH 6/6] Optimize `fit_predict` and fix `labels_` for dataframes --- python/cuml/cuml/dask/cluster/kmeans.py | 21 ++++++++------- python/cuml/tests/dask/test_dask_kmeans.py | 31 +++++++++++++--------- 2 files changed, 30 insertions(+), 22 deletions(-) diff --git a/python/cuml/cuml/dask/cluster/kmeans.py b/python/cuml/cuml/dask/cluster/kmeans.py index e7e90a5582..0189cf1346 100644 --- a/python/cuml/cuml/dask/cluster/kmeans.py +++ b/python/cuml/cuml/dask/cluster/kmeans.py @@ -4,6 +4,7 @@ import cupy as cp import dask.array as da +import dask.dataframe as dd from dask.distributed import get_worker from raft_dask.common.comms import Comms, get_raft_comm_state @@ -189,12 +190,15 @@ def fit(self, X, sample_weight=None): self.client.submit(getattr, f, "labels_", workers=[w]) for f, w in zip(kmeans_fit, workers) ] - self.labels_ = da.concatenate( - [ - da.from_delayed(f, shape=(length,), meta=labels_meta) - for f, (_, length) in zip(labels, inertia_and_lengths) - ] - ) + if self.datatype == "cudf": + self.labels_ = dd.from_delayed(labels) + else: + self.labels_ = da.concatenate( + [ + da.from_delayed(f, shape=(length,), meta=labels_meta) + for f, (_, length) in zip(labels, inertia_and_lengths) + ] + ) return self @@ -213,9 +217,8 @@ def fit_predict(self, X, sample_weight=None, delayed=True): Distributed object containing predictions """ - return self.fit(X, 
sample_weight=sample_weight).predict( - X, delayed=delayed - ) + self.fit(X, sample_weight=sample_weight) + return self.labels_ if delayed else self.labels_.persist() def predict(self, X, delayed=True): """ diff --git a/python/cuml/tests/dask/test_dask_kmeans.py b/python/cuml/tests/dask/test_dask_kmeans.py index 412dc6bd96..9de4552910 100644 --- a/python/cuml/tests/dask/test_dask_kmeans.py +++ b/python/cuml/tests/dask/test_dask_kmeans.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 # @@ -29,7 +29,7 @@ def test_end_to_end( nrows, ncols, nclusters, n_parts, delayed_predict, input_type, client ): - from cuml.dask.cluster import KMeans as cumlKMeans + from cuml.dask.cluster import KMeans from cuml.dask.datasets import make_blobs X, y = make_blobs( @@ -47,15 +47,15 @@ def test_end_to_end( elif input_type == "array": X_train, y_train = X, y - cumlModel = cumlKMeans( + model = KMeans( init="k-means||", n_clusters=nclusters, random_state=10, n_init="auto", ) - cumlModel.fit(X_train) - cumlLabels = cumlModel.predict(X_train, delayed=delayed_predict) + dask_fit_predict_labels = model.fit_predict(X_train) + dask_predict_labels = model.predict(X_train, delayed=delayed_predict) n_workers = len(list(client.has_what().keys())) @@ -66,19 +66,24 @@ def test_end_to_end( parts_len = n_workers if input_type == "dataframe": - assert cumlLabels.npartitions == parts_len - cumlPred = cumlLabels.compute().values + assert dask_predict_labels.npartitions == parts_len + pred_labels = dask_predict_labels.compute().values + fit_pred_labels = dask_fit_predict_labels.compute().values labels = y_train.compute().values elif input_type == "array": - assert len(cumlLabels.chunks[0]) == parts_len - cumlPred = cp.array(cumlLabels.compute()) + assert len(dask_predict_labels.chunks[0]) == parts_len + pred_labels = cp.array(dask_predict_labels.compute()) 
+ fit_pred_labels = cp.array(dask_fit_predict_labels.compute()) labels = cp.squeeze(y_train.compute()) - assert cumlPred.shape[0] == nrows - assert cp.max(cumlPred) == nclusters - 1 - assert cp.min(cumlPred) == 0 + assert pred_labels.shape[0] == nrows + assert cp.max(pred_labels) == nclusters - 1 + assert cp.min(pred_labels) == 0 + + # Assert fit_predict(X) and fit(X).predict(X) have same result + cp.testing.assert_array_equal(pred_labels, fit_pred_labels) - score = adjusted_rand_score(labels, cumlPred) + score = adjusted_rand_score(labels, pred_labels) assert 1.0 == score