Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions sparsity/dask/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
no_default, partial, partial_by_order,
split_evenly, check_divisions, hash_shard,
split_out_on_index, Index)
from dask.dataframe.groupby import _apply_chunk
from dask.dataframe.utils import _nonempty_index, make_meta
from dask.dataframe.utils import _nonempty_index
from dask.dataframe.utils import make_meta as dd_make_meta
from dask.delayed import Delayed
from dask.optimize import cull
from dask.optimization import cull
from dask.utils import derived_from
from scipy import sparse
from toolz import merge, remove, partition_all
Expand Down Expand Up @@ -65,9 +64,6 @@ def __init__(self, dsk, name, meta, divisions=None):
def __dask_graph__(self):
return self.dask

def __dask_keys__(self):
return self._keys()

__dask_scheduler__ = staticmethod(dask.threaded.get)

@staticmethod
Expand Down Expand Up @@ -112,7 +108,7 @@ def map_partitions(self, func, meta, *args, **kwargs):
return map_partitions(func, self, meta, *args, **kwargs)

def to_delayed(self):
return [Delayed(k, self.dask) for k in self._keys()]
return [Delayed(k, self.dask) for k in self.__dask_keys__()]

def assign(self, **kwargs):
for k, v in kwargs.items():
Expand All @@ -126,7 +122,7 @@ def assign(self, **kwargs):
df2 = self._meta.assign(**_extract_meta(kwargs))
return elemwise(methods.assign, self, *pairs, meta=df2)

def _keys(self):
def __dask_keys__(self):
return [(self._name, i) for i in range(self.npartitions)]

@property
Expand Down Expand Up @@ -572,8 +568,8 @@ def elemwise(op, *args, **kwargs):
if not isinstance(arg, (_Frame, Scalar, SparseFrame))]

# Get dsks graph tuple keys and adjust the key length of Scalar
keys = [d._keys() * n if isinstance(d, Scalar) or _is_broadcastable(d)
else d._keys() for d in dasks]
keys = [d.__dask_keys__() * n if isinstance(d, Scalar) or _is_broadcastable(d)
else d.__dask_keys__() for d in dasks]

if other:
dsk = {(_name, i):
Expand Down