
Fix OOM in Dask KMeans by collecting only one model after fit #7908

Merged
rapids-bot[bot] merged 7 commits into rapidsai:release/26.04 from viclafargue:fix-oom-dask-kmeans
Mar 19, 2026

Conversation

@viclafargue
Contributor

After MNMG KMeans fit, the client was calling .result() on every worker's future, pulling the full fitted estimator (including cluster_centers_) from all workers back to the client. Since cluster centers are synchronized across workers via NCCL, every copy is identical, making all but one transfer redundant.

This PR changes the post-fit aggregation to:

  • Collect the full model from only the first worker
  • Extract labels_ and inertia_ from the remaining workers remotely via client.submit(getattr, ...)
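As a rough illustration of the pattern (not the PR's actual code: FittedModel and the shard values below are made up, and plain Python lists stand in for Dask futures and CuPy arrays):

```python
# Hypothetical stand-in for one worker's fitted estimator after an
# MNMG KMeans fit (plain lists stand in for CuPy arrays).
class FittedModel:
    def __init__(self, labels, inertia, centers):
        self.labels_ = labels            # shard-local, scales with n_samples
        self.inertia_ = inertia          # shard-local scalar
        self.cluster_centers_ = centers  # NCCL-synced: identical on every worker

# Three per-worker models; note the identical cluster_centers_.
models = [
    FittedModel([0, 1, 1], 2.0, [[0.0], [1.0]]),
    FittedModel([1, 0], 3.5, [[0.0], [1.0]]),
    FittedModel([0, 0, 1], 1.5, [[0.0], [1.0]]),
]

# New aggregation: take the full model from the first worker only...
first = models[0]

# ...and pull just labels_ and inertia_ from the rest. In the PR this is
# client.submit(getattr, future, "labels_"), so the attribute is extracted
# worker-side and only the small pieces travel to the client.
all_labels = [first.labels_, *(getattr(m, "labels_") for m in models[1:])]
all_inertias = [first.inertia_, *(getattr(m, "inertia_") for m in models[1:])]

first.labels_ = [x for chunk in all_labels for x in chunk]  # cp.concatenate stand-in
first.inertia_ = sum(all_inertias)

print(first.labels_)   # [0, 1, 1, 1, 0, 0, 0, 1]
print(first.inertia_)  # 7.0
```

The key saving is that cluster_centers_ (identical everywhere) crosses the wire once instead of once per worker.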

@viclafargue viclafargue self-assigned this Mar 18, 2026
@viclafargue viclafargue requested a review from a team as a code owner March 18, 2026 13:39
@viclafargue viclafargue requested a review from betatim March 18, 2026 13:39
@viclafargue viclafargue added bug Something isn't working non-breaking Non-breaking change labels Mar 18, 2026
@github-actions github-actions Bot added the Cython / Python Cython or Python issue label Mar 18, 2026
@coderabbitai

coderabbitai Bot commented Mar 18, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough


fit() now fetches the full model only from the first worker, adds a module helper _get_inertia_and_n_samples(estimator), aggregates total inertia_ by summing per-worker inertia scalars on the scheduler, and builds labels_ as a distributed dask.array by concatenating da.from_delayed chunks with explicit chunk shapes.
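The scalar-aggregation and chunk-shape bookkeeping can be sketched without a cluster; the (inertia, n_samples) pairs below are invented, and the variable names are illustrative rather than the PR's:

```python
# Invented per-worker (inertia, n_samples) pairs, as a helper like
# _get_inertia_and_n_samples(estimator) might return them.
inertia_and_n_samples = [(2.0, 3), (3.5, 2), (1.5, 3)]

# Global inertia_ is just a sum of per-worker scalars.
total_inertia = sum(inertia for inertia, _ in inertia_and_n_samples)

# Explicit chunk shapes for da.from_delayed: one 1-D chunk per worker,
# sized by that worker's sample count.
chunk_shapes = [(n,) for _, n in inertia_and_n_samples]

# After da.concatenate along axis 0, the dask array's .chunks would be
# a tuple of per-chunk lengths:
dask_chunks = (tuple(n for _, n in inertia_and_n_samples),)

print(total_inertia)  # 7.0
print(chunk_shapes)   # [(3,), (2,), (3,)]
print(dask_chunks)    # ((3, 2, 3),)
```

Because the chunk shapes are known up front, no per-worker labels need to be materialized to build the distributed labels_ array.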

Changes

Cohort / File(s): Dask KMeans Aggregation & Distributed Labels (python/cuml/cuml/dask/cluster/kmeans.py)
Summary: Added _get_inertia_and_n_samples(estimator); reworked KMeans.fit() to: retrieve the full model only from the first worker; gather per-worker (inertia, n_samples) pairs and sum inertia scalars on the scheduler; construct labels_ as a dask.array via da.from_delayed chunks and da.concatenate using explicit chunk shapes; removed client-side materialization and concatenation of all per-worker models/labels.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested labels

improvement

Suggested reviewers

jcrist

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Docstring Coverage: ⚠️ Warning. Docstring coverage is 66.67%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)

Title check: ✅ Passed. The title clearly summarizes the main change: fixing OOM in Dask KMeans by collecting only one model after fit, which aligns with the core objective of reducing redundant full-model transfers.
Description check: ✅ Passed. The description is directly related to the changeset, explaining the problem (redundant transfers of cluster_centers_) and the solution (collecting only the first model and extracting labels/inertia remotely).




@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: b5a7e78e-fd3f-454d-bb6a-3d93b39be31b

📥 Commits

Reviewing files that changed from the base of the PR and between 95b7be3 and 06894d4.

📒 Files selected for processing (1)
  • python/cuml/cuml/dask/cluster/kmeans.py

Comment thread python/cuml/cuml/dask/cluster/kmeans.py Outdated
Comment on lines +181 to +184
all_labels = [first.labels_] + self.client.gather(remote_labels)
all_inertias = [first.inertia_] + self.client.gather(remote_inertias)

first.labels_ = cp.concatenate(all_labels)


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

(CodeRabbit ran several shell scripts (head, sed, wc, cat) against python/cuml/cuml/dask/cluster/kmeans.py and python/cuml/cuml/dask/common/input_utils.py to inspect the changed lines, the file's imports, and the concatenate helper; the raw script output is omitted.)

Use the imported concatenate() helper to preserve input type and avoid unnecessary copies.

Line 184 hard-codes cp.concatenate(), which causes type coercion when labels_ is a cuDF object and allocates a new buffer even in the single-worker case. The already-imported concatenate() from cuml.dask.common.input_utils preserves the input type (cuDF, CuPy, or NumPy) and skips the copy when there is only one shard.

Suggested fix
-        all_labels = [first.labels_] + self.client.gather(remote_labels)
-        all_inertias = [first.inertia_] + self.client.gather(remote_inertias)
+        all_labels = [first.labels_, *self.client.gather(remote_labels)]
+        all_inertias = [first.inertia_, *self.client.gather(remote_inertias)]

-        first.labels_ = cp.concatenate(all_labels)
+        first.labels_ = concatenate(all_labels, axis=0)
         first.inertia_ = sum(all_inertias)
🧰 Tools
🪛 Ruff (0.15.6)

[warning] 181-181: Consider iterable unpacking instead of concatenation

Replace with iterable unpacking

(RUF005)


[warning] 182-182: Consider iterable unpacking instead of concatenation

Replace with iterable unpacking

(RUF005)


@viclafargue viclafargue changed the base branch from main to release/26.04 March 18, 2026 14:16
@viclafargue viclafargue requested review from a team as code owners March 18, 2026 14:16
@viclafargue viclafargue requested a review from AyodeAwe March 18, 2026 14:16
@viclafargue viclafargue force-pushed the fix-oom-dask-kmeans branch from 06894d4 to e416630 Compare March 18, 2026 14:19

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
python/cuml/cuml/dask/cluster/kmeans.py (1)

181-185: ⚠️ Potential issue | 🟠 Major

Use concatenate() helper for labels_ aggregation instead of cp.concatenate().

Line 184 hard-codes CuPy concatenation, which can break output-type preservation and misses the single-shard no-copy path already implemented in cuml.dask.common.input_utils.concatenate().

Suggested fix
-        all_labels = [first.labels_] + self.client.gather(remote_labels)
-        all_inertias = [first.inertia_] + self.client.gather(remote_inertias)
+        all_labels = [first.labels_, *self.client.gather(remote_labels)]
+        all_inertias = [first.inertia_, *self.client.gather(remote_inertias)]

-        first.labels_ = cp.concatenate(all_labels)
+        first.labels_ = concatenate(all_labels, axis=0)
         first.inertia_ = sum(all_inertias)

Based on learnings: Correctly handle cuDF, pandas, and NumPy inputs using input_to_cuml_array() for consistent conversion; preserve input type in output where sensible; handle both row-major (C) and column-major (F) memory order.

#!/bin/bash
# Verify current aggregation API usage in the changed block
rg -n "all_labels|all_inertias|cp\\.concatenate|concatenate\\(" python/cuml/cuml/dask/cluster/kmeans.py -C 2

📥 Commits: reviewing files that changed between 06894d4 and e416630.
📒 Files selected for processing (1)
  • python/cuml/cuml/dask/cluster/kmeans.py

Comment thread python/cuml/cuml/dask/cluster/kmeans.py Outdated
Comment on lines +162 to +185
# Collect the full model from only the first worker (for
# cluster_centers_ etc). Extract labels_ and inertia_ from the
# remaining workers remotely to avoid pulling N redundant copies
# of cluster_centers_ back to the client.
first = kmeans_fit[0].result()

remote_labels = [
    self.client.submit(getattr, f, "labels_", workers=[w])
    for f, (w, _) in zip(
        kmeans_fit[1:], list(data.worker_to_parts.items())[1:]
    )
]
remote_inertias = [
    self.client.submit(getattr, f, "inertia_", workers=[w])
    for f, (w, _) in zip(
        kmeans_fit[1:], list(data.worker_to_parts.items())[1:]
    )
]

all_labels = [first.labels_] + self.client.gather(remote_labels)
all_inertias = [first.inertia_] + self.client.gather(remote_inertias)

first.labels_ = cp.concatenate(all_labels)
first.inertia_ = sum(all_inertias)


⚠️ Potential issue | 🟠 Major

Please add regression tests for the new fit aggregation path.

This is a behavior-critical bug fix; add coverage for both single-worker and multi-worker fit to assert correct labels_/inertia_ aggregation and prevent reintroducing redundant model materialization.

As per coding guidelines **/*.py: Update unit tests when making code changes.

🧰 Tools
🪛 Ruff (0.15.6)

[warning] 170-172: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)


[warning] 176-178: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)


[warning] 181-181: Consider iterable unpacking instead of concatenation

Replace with iterable unpacking

(RUF005)


[warning] 182-182: Consider iterable unpacking instead of concatenation

Replace with iterable unpacking

(RUF005)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@python/cuml/cuml/dask/cluster/kmeans.py` around lines 162 - 185, Add
regression tests that exercise the new fit aggregation path used in kmeans (the
block that builds `first = kmeans_fit[0].result()`, then remotely fetches
`labels_`/`inertia_` via `self.client.submit(getattr, f, "labels_", ...)` /
`getattr(..., "inertia_", ...)`, gathers them with `self.client.gather`, and
concatenates/sums into `first.labels_`/`first.inertia_`). Create two tests: one
that runs fit on a single-worker Dask cluster and verifies `labels_` and
`inertia_` match the non-distributed scikit-learn/cuml reference, and one that
runs on a multi-worker cluster and asserts the aggregated `labels_` (after
`cp.concatenate`) and aggregated `inertia_` (after sum) equal the expected
global values and that only the first worker materializes the full model (no
redundant cluster_centers_ copies). Ensure tests use the same api symbols
(`kmeans_fit`, `first`, `labels_`, `inertia_`, `client.gather`) so future
changes to aggregation logic are covered.

Comment thread python/cuml/cuml/dask/cluster/kmeans.py Outdated
Comment on lines +168 to +182
remote_labels = [
    self.client.submit(getattr, f, "labels_", workers=[w])
    for f, (w, _) in zip(
        kmeans_fit[1:], list(data.worker_to_parts.items())[1:]
    )
]
remote_inertias = [
    self.client.submit(getattr, f, "inertia_", workers=[w])
    for f, (w, _) in zip(
        kmeans_fit[1:], list(data.worker_to_parts.items())[1:]
    )
]

all_labels = [first.labels_] + self.client.gather(remote_labels)
all_inertias = [first.inertia_] + self.client.gather(remote_inertias)
Member


There's an overhead to tasks - it'd be better to submit one task that returns a tuple of (labels, inertia) than 2 simple tasks.
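A sketch of that suggestion, with concurrent.futures standing in for the Dask client (the helper name _labels_and_inertia is hypothetical; with Dask this would be a single client.submit per worker):

```python
from concurrent.futures import ThreadPoolExecutor

class Model:
    """Stand-in for one worker's fitted estimator."""
    def __init__(self, labels, inertia):
        self.labels_ = labels
        self.inertia_ = inertia

def _labels_and_inertia(model):
    # One submitted task per worker instead of two getattr tasks:
    # halves the number of tasks the scheduler must track.
    return model.labels_, model.inertia_

models = [Model([1, 0], 3.5), Model([0, 1], 1.5)]

with ThreadPoolExecutor() as pool:
    futures = [pool.submit(_labels_and_inertia, m) for m in models]
    pairs = [f.result() for f in futures]  # stands in for client.gather

remote_labels = [labels for labels, _ in pairs]
remote_inertias = [inertia for _, inertia in pairs]

print(remote_labels)    # [[1, 0], [0, 1]]
print(remote_inertias)  # [3.5, 1.5]
```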

Comment thread python/cuml/cuml/dask/cluster/kmeans.py Outdated
# of cluster_centers_ back to the client.
first = kmeans_fit[0].result()

remote_labels = [
Contributor


Thank you for putting up a PR fix so quickly!

May I know how large this remote_labels variable is if the dataset has 1 billion rows? Will that blow up scheduler memory?

Member


labels_ is a 1D array of length n_samples (in the dask case, split across N-workers). The dtype is typically int32, which brings you to 4 GiB for the array total.

In most deployments of dask the data doesn't go through the scheduler, it goes directly worker->client (also note that in cases where the scheduler runs on the same node as the client this distinction is meaningless). So you care more about the memory capacity client-side than on the scheduler itself.
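The back-of-envelope arithmetic behind that figure:

```python
n_samples = 1_000_000_000   # 1 billion rows
dtype_bytes = 4             # int32 labels
total_bytes = n_samples * dtype_bytes

print(total_bytes / 2**30)  # ~3.73 GiB, i.e. roughly "4 GiB"
```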

Contributor


Thanks for clarifying! In my deployment, the scheduler does run on the same node as the client, so they share the same GPU memory. That said, since it's a 1D int32 array, 4 GB seems manageable compared to the previous issue of collecting all workers' copies of the centroid matrix.

@jcrist
Member

jcrist commented Mar 19, 2026

If we were to redesign this estimator (and probably the other clustering estimators) today, I might reconsider making the labels_ attribute a cupy array here and instead leave it as a dask array.

It's not needed for prediction, it's basically an artifact of model.fit_predict(X) stored on the estimator itself. And since this attribute scales with n_samples (rather than n_features), forcing it to be pulled all to one node is suboptimal. Further, the output of fit_predict itself is a dask array, so it's really only the attribute itself that's forcing this behavior.

A few options:

  • Leave things as they are - the fix in this PR reduces the memory usage somewhat.
  • Make the attribute a cached property, and don't compute it by default. This means for typical usage the labels will never be pulled to a single node, but isn't a breaking change for the interface.
  • Move labels_ to be a dask.array value instead (this is what I would do if I were designing this class fresh today, and fits much better with dask's conventions). I'm not sure of a clean way to deprecate towards this, but since the labels_ attribute itself is fully undocumented on the class, maybe we don't care about that.

I'd vote for the third option if possible. It's the cleanest, most idiomatic dask interface, and would avoid this inefficiency.
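For illustration, the second option (a lazy, non-breaking attribute) could be sketched with functools.cached_property; everything here is hypothetical, and the list flattening stands in for gathering labels from workers:

```python
from functools import cached_property

class DistributedKMeans:
    """Hypothetical sketch; not cuml's actual class."""
    def __init__(self, per_worker_labels):
        self._per_worker_labels = per_worker_labels
        self.pulls = 0  # counts how often labels are materialized

    @cached_property
    def labels_(self):
        # Runs only if someone actually touches labels_, so typical
        # fit/predict workflows never pull labels to a single node.
        self.pulls += 1
        return [x for chunk in self._per_worker_labels for x in chunk]

model = DistributedKMeans([[0, 1], [1, 1]])
assert model.pulls == 0     # nothing gathered yet
labels = model.labels_      # first access triggers the "gather"
_ = model.labels_           # cached: no second gather

print(labels)       # [0, 1, 1, 1]
print(model.pulls)  # 1
```

The third option would instead return a dask.array directly from the attribute, keeping the data distributed even on access.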


@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (1)
python/cuml/cuml/dask/cluster/kmeans.py (1)

175-182: Add strict=True to zip() calls.

Both zip() calls iterate over slices that should always have matching lengths. Adding strict=True guards against subtle bugs if the data structures ever diverge.

Suggested fix
         remote_labels = [
             self.client.submit(getattr, f, "labels_", workers=[w])
-            for f, w in zip(kmeans_fit[1:], workers[1:])
+            for f, w in zip(kmeans_fit[1:], workers[1:], strict=True)
         ]
         remote_inertias = [
             self.client.submit(getattr, f, "inertia_", workers=[w])
-            for f, w in zip(kmeans_fit[1:], workers[1:])
+            for f, w in zip(kmeans_fit[1:], workers[1:], strict=True)
         ]

📥 Commits: reviewing files that changed between e416630 and 1acf402.
📒 Files selected for processing (1)
  • python/cuml/cuml/dask/cluster/kmeans.py

@viclafargue
Contributor Author

@jcrist, the first commit extracts labels and inertia via a single unified Dask task. The second commit takes this further by keeping labels_ as a distributed dask.array. The first estimator's labels and the remaining workers' labels are wrapped as delayed chunks rather than gathered to the client to form a CuPy array, while inertia is still collected and summed locally. If CI passes cleanly, we avoid transferring the labels to the client entirely. Otherwise, we can revert this last commit and fall back to the first.

Comment thread python/cuml/cuml/dask/cluster/kmeans.py Outdated

@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (1)
python/cuml/cuml/dask/cluster/kmeans.py (1)

175-182: Add strict=True to zip() calls for defensive validation.

Both kmeans_fit[1:] and workers[1:] should always have the same length, but adding strict=True catches mismatches early if future refactoring breaks that invariant.

Also, per the earlier reviewer suggestion: submitting a single task returning (labels_, inertia_) instead of two separate getattr tasks per worker would reduce scheduler overhead.

Proposed fix for strict= parameter
         remote_labels = [
             self.client.submit(getattr, f, "labels_", workers=[w])
-            for f, w in zip(kmeans_fit[1:], workers[1:])
+            for f, w in zip(kmeans_fit[1:], workers[1:], strict=True)
         ]
         remote_inertias = [
             self.client.submit(getattr, f, "inertia_", workers=[w])
-            for f, w in zip(kmeans_fit[1:], workers[1:])
+            for f, w in zip(kmeans_fit[1:], workers[1:], strict=True)
         ]

📥 Commits: reviewing files that changed between 1acf402 and dfd461a.
📒 Files selected for processing (1)
  • python/cuml/cuml/dask/cluster/kmeans.py

@jcrist
Member

jcrist commented Mar 19, 2026

I pushed a small fixup commit with some improvements for dask idioms. The biggest one is that we don't want to store the labels_ array from the first partition in the graph of the created labels_ dask array. That would cause any downstream usage of the labels_ array to re-serialize the partition's labels to a worker, when it already exists on the cluster. Better to just reference the data that's already on the cluster.

I also tried to fix what I viewed as a performance bug in fit_predict (where it re-runs the predict call, when the output should really be the existing labels_ array). However, this turned up what I think is a bug - namely that in the distributed versions model.fit_predict(X) and model.fit(X).predict(X) are not equivalent. I'll open a followup issue for this.

Anyway, I believe the actual memory issue here is resolved for now. Happy to merge.


@coderabbitai coderabbitai Bot left a comment


♻️ Duplicate comments (2)
python/cuml/cuml/dask/cluster/kmeans.py (2)

172-197: ⚠️ Potential issue | 🟠 Major

internal_model now exposes shard-local labels_ and inertia_.

After Line 172, the global self.inertia_ and distributed self.labels_ are written only on the Dask estimator; the stored combined model still keeps the first worker’s local values. Any get_combined_model() or other internal-model consumer can now observe fitted state that disagrees with self. Please sync those attributes when materializing the combined model, or override the accessor to source them from the Dask estimator.


167-172: ⚠️ Potential issue | 🟠 Major

Keep internal_model as a future instead of calling .result() here.

Line 171 still brings one whole fitted estimator back to the client, including that worker’s labels_ shard. For large or skewed partitions this can still be the dominant allocation and reintroduce the OOM the PR is trying to remove. BaseEstimator._set_internal_model() already accepts futures, so store kmeans_fit[0] directly and fetch only lightweight metadata (for example the labels dtype) with small helper tasks.


📥 Commits: reviewing files that changed between dfd461a and fe977b7.
📒 Files selected for processing (1)
  • python/cuml/cuml/dask/cluster/kmeans.py

Contributor Author

@viclafargue viclafargue left a comment


The change looks good to me. We should consider having the attributes that scale with n_samples as Dask arrays in other Dask estimators too.

Comment thread python/cuml/cuml/dask/cluster/kmeans.py Outdated
@jcrist
Member

jcrist commented Mar 19, 2026

I also tried to fix what I viewed as a performance bug in fit_predict (where it re-runs the predict call, when the output should really be the existing labels_ array). However, this turned up what I think is a bug - namely that in the distributed versions model.fit_predict(X) and model.fit(X).predict(X) are not equivalent. I'll open a followup issue for this.

Nevermind, this was an actual bug in the implementation for cudf inputs. I've fixed this, and added a test. fit_predict now returns self.labels_ as expected, avoiding the extra cost of an unnecessary predict call.

@jcrist
Member

jcrist commented Mar 19, 2026

/merge

@rapids-bot rapids-bot Bot merged commit 2e65466 into rapidsai:release/26.04 Mar 19, 2026
172 of 174 checks passed

Labels

  • bug: Something isn't working
  • Cython / Python: Cython or Python issue
  • non-breaking: Non-breaking change


4 participants