Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions python/cuml/cuml/dask/common/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import cudf.comm.serialize # noqa: F401
import dask
from dask.delayed import Delayed
from dask_cudf import Series as dcSeries
from distributed.client import Future
from raft_dask.common.comms import Comms
Expand Down Expand Up @@ -304,9 +305,11 @@ def _run_parallel_func(
if output_collection_type is None:
output_collection_type = self.datatype

model_delayed = dask.delayed(
self._get_internal_model(), pure=True, traverse=False
)
model_delayed = self._get_internal_model()
if not isinstance(model_delayed, (Future, Delayed)):
model_delayed = dask.delayed(
model_delayed, pure=True, traverse=False
)
Comment on lines +308 to +312
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That seems like a somewhat inefficient implementation. Can we instead modify _get_internal_model() to always return a Future or Delayed?


func = dask.delayed(func, pure=False, nout=1)
if isinstance(X, dcDataFrame):
Expand Down
9 changes: 8 additions & 1 deletion python/cuml/cuml/dask/common/input_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,14 @@ def calculate_parts_to_sizes(self, comms=None, ranks=None):
for idx, wf in enumerate(self.worker_to_parts.items())
]

sizes = self.client.compute(parts, sync=True)
worker_addresses, futures = zip(*parts)
results = self.client.gather(futures)
sizes = [
(worker_address, result)
for worker_address, result in zip(
worker_addresses, results, strict=True
)
]
Comment on lines +159 to +164
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
sizes = [
(worker_address, result)
for worker_address, result in zip(
worker_addresses, results, strict=True
)
]
sizes = list(zip(
worker_addresses, results, strict=True
))


for w, sizes_parts in sizes:
sizes, total = sizes_parts
Expand Down
6 changes: 5 additions & 1 deletion python/cuml/cuml/dask/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

import dask.array as da
import dask.delayed
from dask.delayed import Delayed
from distributed import Future

from cuml.internals.safe_imports import gpu_only_import

Expand All @@ -36,8 +38,10 @@ def _dask_array_from_delayed(part, dtype, nrows, ncols=None):
# and make an array of shape (nrows, 1)

shape = (nrows, ncols) if ncols else (nrows,)
if not isinstance(part, (Delayed, Future)):
part = dask.delayed(part)
return da.from_delayed(
dask.delayed(part), shape=shape, meta=cp.zeros((1)), dtype=dtype
part, shape=shape, meta=cp.zeros((0,) * len(shape), dtype=dtype)
Comment on lines +41 to +44
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is called once within the entire code base. Maybe we can ensure that it returns a Delayed or Future?

)


Expand Down
Loading