diff --git a/dask_ml/_partial.py b/dask_ml/_partial.py index 0bff82b53..0495edf9e 100644 --- a/dask_ml/_partial.py +++ b/dask_ml/_partial.py @@ -7,6 +7,7 @@ import numpy as np import sklearn.utils from dask.delayed import Delayed +from dask.highlevelgraph import HighLevelGraph from toolz import partial logger = logging.getLogger(__name__) @@ -120,19 +121,10 @@ def fit( } ) - graphs = {x_name: x.__dask_graph__(), name: dsk} - if hasattr(y, "__dask_graph__"): - graphs[y_name] = y.__dask_graph__() - - try: - from dask.highlevelgraph import HighLevelGraph - - new_dsk = HighLevelGraph.merge(*graphs.values()) - except ImportError: - from dask import sharedict - - new_dsk = sharedict.merge(*graphs.values()) - + dependencies = [x] + if y is not None: + dependencies.append(y) + new_dsk = HighLevelGraph.from_collections(name, dsk, dependencies=dependencies) value = Delayed((name, nblocks - 1), new_dsk) if compute: