diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py index 98ff9002856..c84259050d6 100644 --- a/xarray/core/dask_array_ops.py +++ b/xarray/core/dask_array_ops.py @@ -5,25 +5,15 @@ def dask_rolling_wrapper(moving_func, a, window, min_count=None, axis=-1): """Wrapper to apply bottleneck moving window funcs on dask arrays""" - import dask.array as da - - dtype, fill_value = dtypes.maybe_promote(a.dtype) - a = a.astype(dtype) - # inputs for overlap - if axis < 0: - axis = a.ndim + axis - depth = {d: 0 for d in range(a.ndim)} - depth[axis] = (window + 1) // 2 - boundary = {d: fill_value for d in range(a.ndim)} - # Create overlap array. - ag = da.overlap.overlap(a, depth=depth, boundary=boundary) - # apply rolling func - out = da.map_blocks( - moving_func, ag, window, min_count=min_count, axis=axis, dtype=a.dtype + dtype, _ = dtypes.maybe_promote(a.dtype) + return a.data.map_overlap( + moving_func, + depth={axis: (window - 1, 0)}, + axis=axis, + dtype=dtype, + window=window, + min_count=min_count, ) - # trim array - result = da.overlap.trim_internal(out, depth) - return result def least_squares(lhs, rhs, rcond=None, skipna=False): diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index 93fdda52e8c..781550207ff 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -10,7 +10,7 @@ import numpy as np from packaging.version import Version -from xarray.core import dtypes, duck_array_ops, utils +from xarray.core import dask_array_ops, dtypes, duck_array_ops, utils from xarray.core.arithmetic import CoarsenArithmetic from xarray.core.options import OPTIONS, _get_keep_attrs from xarray.core.types import CoarsenBoundaryOptions, SideOptions, T_Xarray @@ -597,16 +597,18 @@ def _bottleneck_reduce(self, func, keep_attrs, **kwargs): padded = padded.pad({self.dim[0]: (0, -shift)}, mode="constant") if is_duck_dask_array(padded.data): - raise AssertionError("should not be reachable") + values = dask_array_ops.dask_rolling_wrapper( + func, padded, axis=axis, window=self.window[0], min_count=min_count + ) else: values = func( padded.data, window=self.window[0], min_count=min_count, axis=axis ) - # index 0 is at the rightmost edge of the window - # need to reverse index here - # see GH #8541 - if func in [bottleneck.move_argmin, bottleneck.move_argmax]: - values = self.window[0] - 1 - values + # index 0 is at the rightmost edge of the window + # need to reverse index here + # see GH #8541 + if func in [bottleneck.move_argmin, bottleneck.move_argmax]: + values = self.window[0] - 1 - values if self.center[0]: values = values[valid] @@ -669,12 +671,12 @@ def _array_reduce( if ( OPTIONS["use_bottleneck"] and bottleneck_move_func is not None - and not is_duck_dask_array(self.obj.data) + and ( + not is_duck_dask_array(self.obj.data) + or module_available("dask", "2024.11.0") + ) and self.ndim == 1 ): - # TODO: re-enable bottleneck with dask after the issues - # underlying https://github.com/pydata/xarray/issues/2940 are - # fixed. return self._bottleneck_reduce( bottleneck_move_func, keep_attrs=keep_attrs, **kwargs ) diff --git a/xarray/tests/test_rolling.py b/xarray/tests/test_rolling.py index 9d880969a82..57bf08b48a7 100644 --- a/xarray/tests/test_rolling.py +++ b/xarray/tests/test_rolling.py @@ -107,12 +107,13 @@ def test_rolling_properties(self, da) -> None: ): da.rolling(foo=2) + @requires_dask @pytest.mark.parametrize( "name", ("sum", "mean", "std", "min", "max", "median", "argmin", "argmax") ) @pytest.mark.parametrize("center", (True, False, None)) @pytest.mark.parametrize("min_periods", (1, None)) - @pytest.mark.parametrize("backend", ["numpy"], indirect=True) + @pytest.mark.parametrize("backend", ["numpy", "dask"], indirect=True) def test_rolling_wrapped_bottleneck( self, da, name, center, min_periods, compute_backend ) -> None: