Merged
4 changes: 4 additions & 0 deletions doc/whats-new.rst
@@ -73,6 +73,10 @@ Internal Changes

- :py:func:`as_variable` now consistently includes the variable name in any exceptions
raised. (:pull:`7995`). By `Peter Hill <https://github.com/ZedThree>`_
- Redirect cumulative reduction functions internally through the :py:class:`ChunkManagerEntryPoint`,
potentially allowing :py:meth:`~xarray.DataArray.ffill` and :py:meth:`~xarray.DataArray.bfill` to
use non-dask chunked array types.
(:pull:`8019`) By `Tom Nicholas <https://github.com/TomNicholas>`_.

.. _whats-new.2023.07.0:

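As a usage illustration of the changelog entry above (a sketch only, not part of this changeset; it assumes dask and bottleneck are installed), ``ffill`` on a chunked ``DataArray`` now reaches the registered chunk manager's ``scan`` method:

import numpy as np
import xarray as xr

# ffill on a chunked DataArray is dispatched through the chunk manager's scan
da = xr.DataArray([1.0, np.nan, np.nan, 4.0, np.nan], dims="x").chunk({"x": 2})
filled = da.ffill(dim="x")
print(filled.compute().values)  # [1. 1. 1. 4. 4.]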
39 changes: 0 additions & 39 deletions xarray/core/dask_array_ops.py
@@ -53,42 +53,3 @@ def least_squares(lhs, rhs, rcond=None, skipna=False):
# See issue dask/dask#6516
coeffs, residuals, _, _ = da.linalg.lstsq(lhs_da, rhs)
return coeffs, residuals


def push(array, n, axis):
"""
Dask-aware bottleneck.push
"""
import bottleneck
import dask.array as da
import numpy as np

def _fill_with_last_one(a, b):
# cumreduction apply the push func over all the blocks first so, the only missing part is filling
# the missing values using the last data of the previous chunk
return np.where(~np.isnan(b), b, a)

if n is not None and 0 < n < array.shape[axis] - 1:
arange = da.broadcast_to(
da.arange(
array.shape[axis], chunks=array.chunks[axis], dtype=array.dtype
).reshape(
tuple(size if i == axis else 1 for i, size in enumerate(array.shape))
),
array.shape,
array.chunks,
)
valid_arange = da.where(da.notnull(array), arange, np.nan)
valid_limits = (arange - push(valid_arange, None, axis)) <= n
# omit the forward fill that violate the limit
return da.where(valid_limits, push(array, None, axis), np.nan)

# The method parameter makes that the tests for python 3.7 fails.
return da.reductions.cumreduction(
func=bottleneck.push,
binop=_fill_with_last_one,
ident=np.nan,
x=array,
axis=axis,
dtype=array.dtype,
)
22 changes: 22 additions & 0 deletions xarray/core/daskmanager.py
@@ -97,6 +97,28 @@ def reduction(
keepdims=keepdims,
)

def scan(
self,
func: Callable,
binop: Callable,
ident: float,
arr: T_ChunkedArray,
axis: int | None = None,
dtype: np.dtype | None = None,
**kwargs,
) -> DaskArray:
from dask.array.reductions import cumreduction

return cumreduction(
func,
binop,
ident,
arr,
axis=axis,
dtype=dtype,
**kwargs,
)

def apply_gufunc(
self,
func: Callable,
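For context, a sketch of how the ``scan`` method added to ``DaskManager`` above is exercised (illustrative only; it assumes dask and bottleneck are installed and that the class is importable as shown, and it mirrors the call made from ``duck_array_ops`` below rather than adding any new API):

import bottleneck
import dask.array as da
import numpy as np

from xarray.core.daskmanager import DaskManager

def _fill_with_last_one(a, b):
    # carry the last value of the previous block into the gaps of this block
    return np.where(~np.isnan(b), b, a)

arr = da.from_array(np.array([1.0, np.nan, np.nan, 4.0, np.nan]), chunks=2)
filled = DaskManager().scan(
    func=bottleneck.push,
    binop=_fill_with_last_one,
    ident=np.nan,
    arr=arr,
    axis=0,
    dtype=arr.dtype,
)
print(filled.compute())  # [1. 1. 1. 4. 4.]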
47 changes: 45 additions & 2 deletions xarray/core/duck_array_ops.py
@@ -671,11 +671,54 @@ def least_squares(lhs, rhs, rcond=None, skipna=False):
return nputils.least_squares(lhs, rhs, rcond=rcond, skipna=skipna)


def _chunked_push(array, n, axis):
"""
Chunk-aware bottleneck.push
"""
import bottleneck
import numpy as np

chunkmanager = get_chunked_array_type(array)
xp = chunkmanager.array_api

def _fill_with_last_one(a, b):
# cumreduction applies the push func within each block first, so the only
# missing part is filling the gaps using the last value of the previous chunk
return np.where(~np.isnan(b), b, a)

if n is not None and 0 < n < array.shape[axis] - 1:
arange = xp.arange(
array.shape[axis], chunks=array.chunks[axis], dtype=array.dtype
)
broadcasted_arange = xp.broadcast_to(
xp.reshape(
arange,
tuple(size if i == axis else 1 for i, size in enumerate(array.shape)),
),
array.shape,
array.chunks,
)
valid_arange = xp.where(xp.notnull(array), broadcasted_arange, np.nan)
valid_limits = (broadcasted_arange - push(valid_arange, None, axis)) <= n
# omit the forward-fills that violate the limit
return xp.where(valid_limits, push(array, None, axis), np.nan)

# The `method` parameter is not forwarded, as it made the tests for Python 3.7 fail.
return chunkmanager.scan(
func=bottleneck.push,
binop=_fill_with_last_one,
ident=np.nan,
arr=array,
axis=axis,
dtype=array.dtype,
)


def push(array, n, axis):
from bottleneck import push

if is_duck_dask_array(array):
return dask_array_ops.push(array, n, axis)
if is_chunked_array(array):
return _chunked_push(array, n, axis)
else:
return push(array, n, axis)

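The limit handling in ``_chunked_push`` above can be illustrated with plain numpy (a sketch only; the array values and the names ``distance`` and ``limited_fill`` are invented for the example): the distance of each position from the most recent valid value is compared against ``n``, and fills beyond that distance are masked out again.

import numpy as np
from bottleneck import push

array = np.array([1.0, np.nan, np.nan, np.nan, 5.0, np.nan])
n = 2  # maximum number of consecutive NaNs to fill

arange = np.arange(array.shape[0], dtype=array.dtype)
valid_arange = np.where(~np.isnan(array), arange, np.nan)  # index of each valid value
distance = arange - push(valid_arange, None, 0)            # gap since the last valid value
limited_fill = np.where(distance <= n, push(array, None, 0), np.nan)
print(limited_fill)  # [ 1.  1.  1. nan  5.  5.]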
37 changes: 37 additions & 0 deletions xarray/core/parallelcompat.py
@@ -403,6 +403,43 @@ def reduction(
"""
raise NotImplementedError()

def scan(
self,
func: Callable,
binop: Callable,
ident: float,
arr: T_ChunkedArray,
axis: int | None = None,
dtype: np.dtype | None = None,
**kwargs,
) -> T_ChunkedArray:
"""
General version of a 1D scan, also known as a cumulative array reduction.

Used in ``ffill`` and ``bfill`` in xarray.

Parameters
----------
func: callable
Cumulative function like np.cumsum or np.cumprod
binop: callable
Associated binary operator like ``np.cumsum->add`` or ``np.cumprod->mul``
ident: Number
Associated identity like ``np.cumsum->0`` or ``np.cumprod->1``
arr: chunked array
axis: int, optional
dtype: dtype

Returns
-------
Chunked array

See also
--------
dask.array.reductions.cumreduction
"""
raise NotImplementedError()

@abstractmethod
def apply_gufunc(
self,
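To make the contract described in the ``scan`` docstring above concrete, here is a toy sketch of the same semantics on a plain list of numpy blocks (``toy_scan`` and its exact argument handling are invented for illustration; they are not part of this PR or of dask):

import numpy as np


def toy_scan(func, binop, ident, blocks, axis=0, dtype=None):
    # 1. apply the cumulative func within each block independently
    scanned = [func(block, axis=axis) for block in blocks]
    # 2. carry results across block boundaries with the binary operator,
    #    seeding the carry with the identity element
    carry = np.full(scanned[0].take([-1], axis=axis).shape, ident, dtype=dtype)
    out = []
    for block in scanned:
        combined = binop(carry, block)
        out.append(combined)
        carry = combined.take([-1], axis=axis)
    return out


blocks = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]
print(np.concatenate(toy_scan(np.cumsum, np.add, 0, blocks)))  # [ 1.  3.  6. 10.]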