4 changes: 4 additions & 0 deletions doc/whats-new.rst
@@ -38,6 +38,10 @@ Breaking changes

New Features
~~~~~~~~~~~~
- Added the ``sparse`` option to :py:meth:`~xarray.DataArray.unstack`,
:py:meth:`~xarray.Dataset.unstack`, :py:meth:`~xarray.DataArray.reindex`,
:py:meth:`~xarray.Dataset.reindex` (:issue:`3518`).
By `Keisuke Fujii <https://github.com/fujiisoup>`_.

- Added the ``fill_value`` option to :py:meth:`~xarray.DataArray.unstack` and
:py:meth:`~xarray.Dataset.unstack` (:issue:`3518`).
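
A minimal usage sketch of the new ``sparse`` option (illustrative data and names; requires the sparse package to be installed):

import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"var": (("x", "y"), np.arange(6.0).reshape(2, 3))},
    coords={"x": [0, 1], "y": list("abc")},
)
# Stack, drop one element, then unstack: the missing cell makes the index
# incomplete, so unstack reindexes and, with sparse=True, returns sparse data.
stacked = ds.stack(index=("x", "y")).isel(index=slice(1, None))
unstacked = stacked.unstack("index", sparse=True)
print(type(unstacked["var"].data))  # a sparse.COO array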
5 changes: 5 additions & 0 deletions xarray/core/alignment.py
@@ -466,6 +466,7 @@ def reindex_variables(
tolerance: Any = None,
copy: bool = True,
fill_value: Optional[Any] = dtypes.NA,
sparse: bool = False,
) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, pd.Index]]:
"""Conform a dictionary of aligned variables onto a new set of variables,
filling in missing values with NaN.
@@ -503,6 +504,8 @@ def reindex_variables(
the input. In either case, new xarray objects are always returned.
fill_value : scalar, optional
Value to use for newly missing values
sparse : bool, optional
Use a sparse array instead of a dense one. By default, False.

Returns
-------
@@ -571,6 +574,8 @@

for name, var in variables.items():
if name not in indexers:
if sparse:
var = var._as_sparse(fill_value=fill_value)
key = tuple(
slice(None) if d in unchanged_dims else int_indexers.get(d, slice(None))
for d in var.dims
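
The branch above hands reindex's fill_value straight to the variable. With the default NA fill, the dtype is promoted just as in the dense path; a tiny sketch using the private helper (illustrative):

import numpy as np
from xarray import Variable

ivar = Variable(("x",), np.arange(3))    # int64 data
svar = ivar._as_sparse()                 # default fill promotes to float64 with NaN fill
print(svar.dtype, svar.data.fill_value)  # float64 nan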
4 changes: 3 additions & 1 deletion xarray/core/dataarray.py
@@ -1729,6 +1729,7 @@ def unstack(
self,
dim: Union[Hashable, Sequence[Hashable], None] = None,
fill_value: Any = dtypes.NA,
sparse: bool = False,
) -> "DataArray":
"""
Unstack existing dimensions corresponding to MultiIndexes into
@@ -1742,6 +1743,7 @@ def unstack(
Dimension(s) over which to unstack. By default unstacks all
MultiIndexes.
fill_value: value to use for newly missing values. By default, np.nan
sparse: use a sparse array if True. By default, False

Returns
-------
@@ -1773,7 +1775,7 @@
--------
DataArray.stack
"""
ds = self._to_temp_dataset().unstack(dim, fill_value)
ds = self._to_temp_dataset().unstack(dim, fill_value, sparse)
return self._from_temp_dataset(ds)

def to_unstacked_dataset(self, dim, level=0):
13 changes: 10 additions & 3 deletions xarray/core/dataset.py
@@ -2254,6 +2254,7 @@ def reindex(
tolerance: Number = None,
copy: bool = True,
fill_value: Any = dtypes.NA,
sparse: bool = False,
**indexers_kwargs: Any,
) -> "Dataset":
"""Conform this object onto a new set of indexes, filling in
@@ -2286,6 +2287,7 @@ def reindex(
the input. In either case, a new xarray object is always returned.
fill_value : scalar, optional
Value to use for newly missing values
sparse : bool, optional
Use a sparse array. By default, False.
**indexers_kwargs : {dim: indexer, ...}, optional
Keyword arguments in the same form as ``indexers``.
One of indexers or indexers_kwargs must be provided.
@@ -2444,6 +2446,7 @@ def reindex(
tolerance,
copy=copy,
fill_value=fill_value,
sparse=sparse,
)
coord_names = set(self._coord_names)
coord_names.update(indexers)
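
End to end, the keyword keeps the reindexed result sparse; a small sketch (illustrative names and values):

import numpy as np
import xarray as xr

ds = xr.Dataset({"var": ("x", np.arange(3.0))}, coords={"x": [0, 1, 2]})
out = ds.reindex(x=[0, 1, 2, 3, 4], sparse=True)
print(out["var"].data.density)  # < 1.0: newly missing entries are not stored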
@@ -3333,7 +3336,7 @@ def ensure_stackable(val):

return data_array

def _unstack_once(self, dim: Hashable, fill_value) -> "Dataset":
def _unstack_once(self, dim: Hashable, fill_value, sparse) -> "Dataset":
index = self.get_index(dim)
index = index.remove_unused_levels()
full_idx = pd.MultiIndex.from_product(index.levels, names=index.names)
@@ -3342,7 +3345,9 @@ def _unstack_once(self, dim: Hashable, fill_value) -> "Dataset":
if index.equals(full_idx):
obj = self
else:
obj = self.reindex({dim: full_idx}, copy=False, fill_value=fill_value)
obj = self.reindex(
{dim: full_idx}, copy=False, fill_value=fill_value, sparse=sparse
)

new_dim_names = index.names
new_dim_sizes = [lev.size for lev in index.levels]
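
For reference, the completion check above reduces to this pandas-level sketch (illustrative values):

import pandas as pd

index = pd.MultiIndex.from_tuples([(0, "a"), (1, "b")], names=["x", "y"])
full_idx = pd.MultiIndex.from_product(index.levels, names=index.names)
print(full_idx.equals(index))  # False: the object must be reindexed first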
@@ -3372,6 +3377,7 @@ def unstack(
self,
dim: Union[Hashable, Iterable[Hashable]] = None,
fill_value: Any = dtypes.NA,
sparse: bool = False,
) -> "Dataset":
"""
Unstack existing dimensions corresponding to MultiIndexes into
@@ -3385,6 +3391,7 @@ def unstack(
Dimension(s) over which to unstack. By default unstacks all
MultiIndexes.
fill_value: value to use for newly missing values. By default, np.nan
sparse: use a sparse array if True. By default, False

Returns
-------
@@ -3422,7 +3429,7 @@

result = self.copy(deep=False)
for dim in dims:
result = result._unstack_once(dim, fill_value)
result = result._unstack_once(dim, fill_value, sparse)
return result

def update(self, other: "CoercibleMapping", inplace: bool = None) -> "Dataset":
7 changes: 6 additions & 1 deletion xarray/core/duck_array_ops.py
@@ -13,12 +13,14 @@

from . import dask_array_ops, dtypes, npcompat, nputils
from .nputils import nanfirst, nanlast
from .pycompat import dask_array_type
from .pycompat import dask_array_type, sparse_array_type

try:
import dask.array as dask_array
import sparse
except ImportError:
dask_array = None # type: ignore
sparse = None # type: ignore


def _dask_or_eager_func(
@@ -251,6 +253,9 @@ def count(data, axis=None):

def where(condition, x, y):
"""Three argument where() with better dtype promotion rules."""
# sparse support
if isinstance(x, sparse_array_type) or isinstance(y, sparse_array_type):
return sparse.where(condition, x, y)
return _where(condition, *as_shared_dtype([x, y]))

Member:
I am a little surprised this is necessary. Does sparse not support __array_function__ for np.where?

Member Author:
Well, yes. sparse does not seem to work with np.result_type and astype(copy=False). I'll add a TODO here.

Member:
Do you have the latest version of sparse installed? When I test this on my machine, it works:

In [13]: import sparse

In [14]: import numpy as np

In [15]: import xarray

In [16]: x = sparse.COO(np.arange(3))

In [17]: xarray.core.duck_array_ops.where(x > 1, x, x)
Out[17]: <COO: shape=(3,), dtype=int64, nnz=2, fill_value=0>

Member Author:
Thanks, you are right. I was running with sparse 0.7.0; with 0.8.0 it works.


34 changes: 34 additions & 0 deletions xarray/core/variable.py
@@ -993,6 +993,32 @@ def chunk(self, chunks=None, name=None, lock=False):

return type(self)(self.dims, data, self._attrs, self._encoding, fastpath=True)

def _as_sparse(self, sparse_format=_default, fill_value=dtypes.NA):

Member Author:
Currently, this is a private method. We can probably expose it publicly and add the same method to DataArray and Dataset in the future.

"""
Use a sparse array as the backend.
"""
import sparse

# TODO what to do if dask-backed?

Member:
Hopefully sparse will raise an error if you try to convert a dask array into a sparse array! If not, we should do that ourselves.
Long term, the best solution would be to convert a dask array from dense chunks to sparse chunks.

if fill_value is dtypes.NA:
dtype, fill_value = dtypes.maybe_promote(self.dtype)
else:
dtype = self.dtype

if sparse_format is _default:
sparse_format = "coo"
as_sparse = getattr(sparse, "as_{}".format(sparse_format.lower()))

Member:
I like the idea of not hard-coding supported sparse formats, but I wonder if we could be a little more careful here if AttributeError is raised. We should probably catch and re-raise AttributeError with a more informative message if this fails.
Otherwise, I expect we might see bug reports from confused users, e.g., when sparse_format='csr' raises a confusing message.
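
A sketch of the suggested catch-and-re-raise (hypothetical; not part of this diff):

try:
    as_sparse = getattr(sparse, "as_{}".format(sparse_format.lower()))
except AttributeError:
    # e.g. sparse_format="csr" when sparse offers no as_csr helper
    raise ValueError("{!r} is not a valid sparse format".format(sparse_format))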

data = as_sparse(self.data.astype(dtype), fill_value=fill_value)
return self._replace(data=data)

def _to_dense(self):

Member Author:
Also private, as is _as_sparse.

Contributor:
We should make these public in DataArray and Dataset. See discussion here: #3245 (comment)
Can be left for a future PR though :)

Member Author:
Thanks, @dcherian. I would like to expose them to the public, but what would be the best names for these functions? #3245

"""
Change the backend from sparse to np.ndarray.
"""
if hasattr(self._data, "todense"):
return self._replace(data=self._data.todense())
return self.copy(deep=False)
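
A round-trip sketch of the two private helpers defined above (illustrative):

import numpy as np
from xarray import Variable

var = Variable(("x", "y"), np.eye(3))
svar = var._as_sparse(sparse_format="coo", fill_value=0)  # dense -> sparse.COO
print(svar.data.nnz)                                      # 3: only the diagonal is stored
dvar = svar._to_dense()                                   # sparse -> dense ndarray
np.testing.assert_array_equal(dvar.data, var.data)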

def isel(
self: VariableType,
indexers: Mapping[Hashable, Any] = None,
@@ -2021,6 +2047,14 @@ def chunk(self, chunks=None, name=None, lock=False):
# Dummy - do not chunk. This method is invoked e.g. by Dataset.chunk()
return self.copy(deep=False)

def _as_sparse(self, sparse_format=_default, fill_value=_default):
# Dummy
return self.copy(deep=False)

def _to_dense(self):
# Dummy
return self.copy(deep=False)

def _finalize_indexing_result(self, dims, data):
if getattr(data, "ndim", 0) != 1:
# returns Variable rather than IndexVariable if multi-dimensional
36 changes: 36 additions & 0 deletions xarray/tests/test_dataset.py
@@ -1748,6 +1748,23 @@ def test_reindex(self):
actual = ds.reindex(x=[0, 1, 3], y=[0, 1])
assert_identical(expected, actual)

@requires_sparse
def test_reindex_sparse(self):
data = create_test_data()
dim3 = list("abdeghijk")
actual = data.reindex(dim3=dim3, sparse=True)
expected = data.reindex(dim3=dim3, sparse=False)
for k, v in data.data_vars.items():
np.testing.assert_equal(actual[k].data.todense(), expected[k].data)

Member Author:
Currently, assert_equal cannot be used, as we need to explicitly densify the array for the comparison.

assert actual["var3"].data.density < 1.0

data["var3"] = data["var3"].astype(int)
actual = data.reindex(dim3=dim3, sparse=True, fill_value=-10)
expected = data.reindex(dim3=dim3, sparse=False, fill_value=-10)
for k, v in data.data_vars.items():
np.testing.assert_equal(actual[k].data.todense(), expected[k].data)

Collaborator:
I think we generally use assert_array_equal for numpy arrays (but I can't immediately recall the difference...).

Member:
I think these actually end up doing the exact same checks.

Member Author:
The values property does not work for a sparse-backed Variable, resulting in the failure of assert_array_equal. I'll add a TODO comment for this.

assert actual["var3"].data.density < 1.0

def test_reindex_warning(self):
data = create_test_data()

@@ -2811,6 +2828,25 @@ def test_unstack_fill_value(self):
expected = ds["var"].unstack("index").fillna(-1).astype(np.int)
assert actual.equals(expected)

@requires_sparse
def test_unstack_sparse(self):
ds = xr.Dataset(
{"var": (("x",), np.arange(6))},
coords={"x": [0, 1, 2] * 2, "y": (("x",), ["a"] * 3 + ["b"] * 3)},
)
# make ds incomplete
ds = ds.isel(x=[0, 2, 3, 4]).set_index(index=["x", "y"])
# compare sparse and dense unstacking
actual = ds.unstack("index", sparse=True)
expected = ds.unstack("index")
assert actual["var"].variable._to_dense().equals(expected["var"].variable)
assert actual["var"].data.density < 1.0

actual = ds["var"].unstack("index", sparse=True)
expected = ds["var"].unstack("index")
assert actual.variable._to_dense().equals(expected.variable)

Collaborator:
Do we test whether actual.variable is actually sparse?

assert actual.data.density < 1.0
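
In the spirit of the question above, an explicit backend check could be added (hypothetical):

# inside test_unstack_sparse, after unstacking with sparse=True:
import sparse

assert isinstance(actual.variable.data, sparse.COO)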

def test_stack_unstack_fast(self):
ds = Dataset(
{
12 changes: 12 additions & 0 deletions xarray/tests/test_variable.py
@@ -33,6 +33,7 @@
assert_identical,
raises_regex,
requires_dask,
requires_sparse,
source_ndarray,
)

@@ -1862,6 +1863,17 @@ def test_getitem_with_mask_nd_indexer(self):
)


@requires_sparse
class TestVariableWithSparse:
# TODO inherit VariableSubclassobjects to cover more tests

Collaborator:
👍

def test_as_sparse(self):
data = np.arange(12).reshape(3, 4)
var = Variable(("x", "y"), data)._as_sparse(fill_value=-1)
actual = var._to_dense()
assert_identical(var, actual)


class TestIndexVariable(VariableSubclassobjects):
cls = staticmethod(IndexVariable)

Expand Down