Fix OOM in Dask KMeans by collecting only one model after fit #7908
```diff
@@ -159,10 +159,30 @@ def fit(self, X, sample_weight=None):

         comms.destroy()

-        models = [res.result() for res in kmeans_fit]
-        first = models[0]
-        first.labels_ = cp.concatenate([model.labels_ for model in models])
-        first.inertia_ = sum(model.inertia_ for model in models)
+        # Collect the full model from only the first worker (for
+        # cluster_centers_ etc). Extract labels_ and inertia_ from the
+        # remaining workers remotely to avoid pulling N redundant copies
+        # of cluster_centers_ back to the client.
+        first = kmeans_fit[0].result()
+
+        remote_labels = [
+            self.client.submit(getattr, f, "labels_", workers=[w])
+            for f, (w, _) in zip(
+                kmeans_fit[1:], list(data.worker_to_parts.items())[1:]
+            )
+        ]
+        remote_inertias = [
+            self.client.submit(getattr, f, "inertia_", workers=[w])
+            for f, (w, _) in zip(
+                kmeans_fit[1:], list(data.worker_to_parts.items())[1:]
+            )
+        ]
+
+        all_labels = [first.labels_] + self.client.gather(remote_labels)
+        all_inertias = [first.inertia_] + self.client.gather(remote_inertias)
```
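The key pattern in the added code is `client.submit(getattr, future, "attr", workers=[w])`, which fetches a single attribute of a remote object while the large fitted model stays on its worker. Below is a minimal, self-contained sketch of the same pattern in plain Dask (hypothetical names and NumPy payloads, not cuML code):

```python
from types import SimpleNamespace

import numpy as np
from dask.distributed import Client, LocalCluster


def train_dummy_model():
    # Stand-in for a per-worker fit: one large attribute plus two small ones.
    return SimpleNamespace(
        cluster_centers_=np.zeros((10_000, 128)),  # the "expensive" payload
        labels_=np.arange(1_000, dtype=np.int32),
        inertia_=42.0,
    )


if __name__ == "__main__":
    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        # The fitted object stays on a worker as a Future ...
        model_future = client.submit(train_dummy_model)
        # ... and only the requested attributes travel back to the client.
        labels = client.submit(getattr, model_future, "labels_").result()
        inertia = client.submit(getattr, model_future, "inertia_").result()
        print(labels.shape, inertia)  # (1000,) 42.0
```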
|
Member comment: There's an overhead to tasks - it'd be better to submit one task per worker that returns a tuple of `labels_` and `inertia_`.
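A rough sketch of that suggestion, reusing the names from the diff above (the `_labels_and_inertia` helper is hypothetical, and this is not necessarily the shape the final code took):

```python
def _labels_and_inertia(model):
    # Runs on the worker; ships back only the two small pieces we need.
    return model.labels_, model.inertia_

remote_parts = [
    self.client.submit(_labels_and_inertia, f, workers=[w])
    for f, (w, _) in zip(
        kmeans_fit[1:], list(data.worker_to_parts.items())[1:]
    )
]

# Each gathered item is a (labels_, inertia_) tuple, so a single gather
# replaces the separate remote_labels / remote_inertias rounds.
gathered = self.client.gather(remote_parts)
all_labels = [first.labels_] + [lbl for lbl, _ in gathered]
all_inertias = [first.inertia_] + [inr for _, inr in gathered]
```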
```diff
+        first.labels_ = cp.concatenate(all_labels)
```
Review comment: Use the already-imported `concatenate` helper instead of hard-coding `cp.concatenate` on line 184. Suggested fix:

```diff
-        all_labels = [first.labels_] + self.client.gather(remote_labels)
-        all_inertias = [first.inertia_] + self.client.gather(remote_inertias)
+        all_labels = [first.labels_, *self.client.gather(remote_labels)]
+        all_inertias = [first.inertia_, *self.client.gather(remote_inertias)]
-        first.labels_ = cp.concatenate(all_labels)
+        first.labels_ = concatenate(all_labels, axis=0)
         first.inertia_ = sum(all_inertias)
```

Ruff (0.15.6) also flags lines 181-182: consider iterable unpacking instead of list concatenation (RUF005).
```diff
+        first.inertia_ = sum(all_inertias)
```
Review comment: Please add regression tests for the new fit aggregation path. This is a behavior-critical bug fix; add coverage for both single-worker and multi-worker fit to assert correct `labels_` and `inertia_` aggregation, as per the coding guidelines. Ruff (0.15.6) also flags the `zip(...)` calls on lines 170-172 and 176-178 (B905: add an explicit `strict=` parameter) and the list concatenation on lines 181-182 (RUF005: prefer iterable unpacking).
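A minimal sketch of what such a regression test could look like (hypothetical test code, assuming a machine with at least two GPUs via `dask_cuda.LocalCUDACluster` and that the combined single-GPU model is retrievable through `get_combined_model()`):

```python
import cupy as cp
import dask.array as da
import pytest
from dask.distributed import Client
from dask_cuda import LocalCUDACluster

from cuml.dask.cluster import KMeans


@pytest.mark.parametrize("n_workers", [1, 2])
def test_kmeans_fit_aggregates_labels_and_inertia(n_workers):
    # Requires n_workers GPUs; exercises both the single-worker and
    # multi-worker aggregation paths touched by this PR.
    with LocalCUDACluster(n_workers=n_workers) as cluster, Client(cluster):
        n_samples, n_features = 10_000, 8
        X = da.random.random(
            (n_samples, n_features), chunks=(n_samples // n_workers, n_features)
        ).map_blocks(cp.asarray)

        km = KMeans(n_clusters=4, random_state=0)
        km.fit(X)

        model = km.get_combined_model()
        # labels_ must cover every sample, not just the first worker's chunk.
        assert model.labels_.shape[0] == n_samples
        # inertia_ must be the sum over all workers, i.e. a positive scalar.
        assert float(model.inertia_) > 0.0
```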
```diff
         self._set_internal_model(first)

         return self
```
Comment: Thank you for putting up a PR fix so quickly! May I know how large this `remote_labels` variable is if the dataset has 1 billion rows? Will that blow up scheduler memory?
Reply: `labels_` is a 1D array of length `n_samples` (in the dask case, split across the N workers). The dtype is typically `int32`, which brings you to about 4 GiB for the array total. In most deployments of dask the data doesn't go through the scheduler; it goes directly worker -> client (also note that in cases where the scheduler runs on the same node as the client, this distinction is meaningless). So you care more about the memory capacity client-side than on the scheduler itself.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for clarifying! In my deployment, the scheduler does run on the same node as the client, so they share the same GPU memory. That said, since it's a 1D int32 array, 4 GB seems manageable compared to the previous issue of collecting all workers' copies of the centroid matrix.