Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 5 additions & 13 deletions dask_ml/_partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import sklearn.utils
from dask.delayed import Delayed
from dask.highlevelgraph import HighLevelGraph
from toolz import partial

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -120,19 +121,10 @@ def fit(
}
)

graphs = {x_name: x.__dask_graph__(), name: dsk}
if hasattr(y, "__dask_graph__"):
graphs[y_name] = y.__dask_graph__()

try:
from dask.highlevelgraph import HighLevelGraph

new_dsk = HighLevelGraph.merge(*graphs.values())
except ImportError:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: this ImportError branch is no longer needed based on our current lower bound of supported dask versions

dask-ml/setup.py

Lines 13 to 14 in f250432

install_requires = [
"dask[array,dataframe]>=2.4.0",

from dask import sharedict

new_dsk = sharedict.merge(*graphs.values())

dependencies = [x]
if y is not None:
dependencies.append(y)
new_dsk = HighLevelGraph.from_collections(name, dsk, dependencies=dependencies)
value = Delayed((name, nblocks - 1), new_dsk)

if compute:
Expand Down