diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 19e52f3a3ed..de358d6d6b7 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -73,6 +73,9 @@ Enhancements - 0d slices of ndarrays are now obtained directly through indexing, rather than extracting and wrapping a scalar, avoiding unnecessary copying. By `Daniel Wennberg `_. +- Added support for ``fill_value`` with + :py:meth:`~xarray.DataArray.shift` and :py:meth:`~xarray.Dataset.shift`. + By `Maximilian Roos `_. Bug fixes ~~~~~~~~~ diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index e8531a62f4f..25a66e529ae 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -6,7 +6,8 @@ import numpy as np import pandas as pd -from . import computation, groupby, indexing, ops, resample, rolling, utils +from . import ( + computation, dtypes, groupby, indexing, ops, resample, rolling, utils) from ..plot.plot import _PlotMethods from .accessors import DatetimeAccessor from .alignment import align, reindex_like_indexers @@ -2085,7 +2086,7 @@ def diff(self, dim, n=1, label='upper'): ds = self._to_temp_dataset().diff(n=n, dim=dim, label=label) return self._from_temp_dataset(ds) - def shift(self, shifts=None, **shifts_kwargs): + def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs): """Shift this array by an offset along one or more dimensions. Only the data is moved; coordinates stay in place. Values shifted from @@ -2098,6 +2099,8 @@ def shift(self, shifts=None, **shifts_kwargs): Integer offset to shift along each of the given dimensions. Positive offsets shift to the right; negative offsets shift to the left. + fill_value: scalar, optional + Value to use for newly missing values **shifts_kwargs: The keyword arguments form of ``shifts``. One of shifts or shifts_kwarg must be provided. 
@@ -2122,8 +2125,9 @@ def shift(self, shifts=None, **shifts_kwargs): Coordinates: * x (x) int64 0 1 2 """ - ds = self._to_temp_dataset().shift(shifts=shifts, **shifts_kwargs) - return self._from_temp_dataset(ds) + variable = self.variable.shift( + shifts=shifts, fill_value=fill_value, **shifts_kwargs) + return self._replace(variable=variable) def roll(self, shifts=None, roll_coords=None, **shifts_kwargs): """Roll this array by an offset along one or more dimensions. diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 7ac3b458232..62c6e98c954 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -13,8 +13,8 @@ import xarray as xr from . import ( - alignment, duck_array_ops, formatting, groupby, indexing, ops, pdcompat, - resample, rolling, utils) + alignment, dtypes, duck_array_ops, formatting, groupby, indexing, ops, + pdcompat, resample, rolling, utils) from ..coding.cftimeindex import _parse_array_of_cftime_strings from .alignment import align from .common import ( @@ -3476,7 +3476,7 @@ def diff(self, dim, n=1, label='upper'): else: return difference - def shift(self, shifts=None, **shifts_kwargs): + def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs): """Shift this dataset by an offset along one or more dimensions. Only data variables are moved; coordinates stay in place. This is @@ -3488,6 +3488,8 @@ def shift(self, shifts=None, **shifts_kwargs): Integer offset to shift along each of the given dimensions. Positive offsets shift to the right; negative offsets shift to the left. + fill_value: scalar, optional + Value to use for newly missing values **shifts_kwargs: The keyword arguments form of ``shifts``. One of shifts or shifts_kwarg must be provided. 
@@ -3522,9 +3524,10 @@ def shift(self, shifts=None, **shifts_kwargs): variables = OrderedDict() for name, var in iteritems(self.variables): if name in self.data_vars: - var_shifts = dict((k, v) for k, v in shifts.items() - if k in var.dims) - variables[name] = var.shift(**var_shifts) + var_shifts = {k: v for k, v in shifts.items() + if k in var.dims} + variables[name] = var.shift( + fill_value=fill_value, shifts=var_shifts) else: variables[name] = var diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index 883dbb34dff..09b632e47a6 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -301,7 +301,7 @@ def wrapped_func(self, **kwargs): else: shift = (-self.window // 2) + 1 valid = (slice(None), ) * axis + (slice(-shift, None), ) - padded = padded.pad_with_fill_value(**{self.dim: (0, -shift)}) + padded = padded.pad_with_fill_value({self.dim: (0, -shift)}) if isinstance(padded.data, dask_array_type): values = dask_rolling_wrapper(func, padded, diff --git a/xarray/core/variable.py b/xarray/core/variable.py index cabab259446..243487db034 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -933,7 +933,7 @@ def squeeze(self, dim=None): dims = common.get_squeeze_dims(self, dim) return self.isel({d: 0 for d in dims}) - def _shift_one_dim(self, dim, count): + def _shift_one_dim(self, dim, count, fill_value=dtypes.NA): axis = self.get_axis_num(dim) if count > 0: @@ -944,7 +944,11 @@ def _shift_one_dim(self, dim, count): keep = slice(None) trimmed_data = self[(slice(None),) * axis + (keep,)].data - dtype, fill_value = dtypes.maybe_promote(self.dtype) + + if fill_value is dtypes.NA: + dtype, fill_value = dtypes.maybe_promote(self.dtype) + else: + dtype = self.dtype shape = list(self.shape) shape[axis] = min(abs(count), shape[axis]) @@ -956,12 +960,12 @@ def _shift_one_dim(self, dim, count): else: full = np.full - nans = full(shape, fill_value, dtype=dtype) + filler = full(shape, fill_value, dtype=dtype) if count > 0: - arrays = [nans, 
trimmed_data] + arrays = [filler, trimmed_data] else: - arrays = [trimmed_data, nans] + arrays = [trimmed_data, filler] data = duck_array_ops.concatenate(arrays, axis) @@ -973,7 +977,7 @@ def _shift_one_dim(self, dim, count): return type(self)(self.dims, data, self._attrs, fastpath=True) - def shift(self, shifts=None, **shifts_kwargs): + def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs): """ Return a new Variable with shifted data. @@ -983,6 +987,8 @@ def shift(self, shifts=None, **shifts_kwargs): Integer offset to shift along each of the given dimensions. Positive offsets shift to the right; negative offsets shift to the left. + fill_value: scalar, optional + Value to use for newly missing values **shifts_kwargs: The keyword arguments form of ``shifts``. One of shifts or shifts_kwarg must be provided. @@ -995,7 +1001,7 @@ def shift(self, shifts=None, **shifts_kwargs): shifts = either_dict_or_kwargs(shifts, shifts_kwargs, 'shift') result = self for dim, count in shifts.items(): - result = result._shift_one_dim(dim, count) + result = result._shift_one_dim(dim, count, fill_value=fill_value) return result def pad_with_fill_value(self, pad_widths=None, fill_value=dtypes.NA, diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 5468905a320..53f53574031 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -14,6 +14,7 @@ DataArray, Dataset, IndexVariable, Variable, align, broadcast) from xarray.coding.times import CFDatetimeCoder, _import_cftime from xarray.convert import from_cdms2 +from xarray.core import dtypes from xarray.core.common import ALL_DIMS, full_like from xarray.core.pycompat import OrderedDict, iteritems from xarray.tests import ( @@ -3128,12 +3129,19 @@ def test_coordinate_diff(self): actual = lon.diff('lon') assert_equal(expected, actual) - @pytest.mark.parametrize('offset', [-5, -2, -1, 0, 1, 2, 5]) - def test_shift(self, offset): + @pytest.mark.parametrize('offset', [-5, 0, 1, 
2]) + @pytest.mark.parametrize('fill_value, dtype', + [(2, int), (dtypes.NA, float)]) + def test_shift(self, offset, fill_value, dtype): arr = DataArray([1, 2, 3], dims='x') - actual = arr.shift(x=1) - expected = DataArray([np.nan, 1, 2], dims='x') - assert_identical(expected, actual) + actual = arr.shift(x=1, fill_value=fill_value) + if fill_value == dtypes.NA: + # if we supply the default, we expect the missing value for a + # float array + fill_value = np.nan + expected = DataArray([fill_value, 1, 2], dims='x') + assert_identical(expected, actual) + assert actual.dtype == dtype arr = DataArray([1, 2, 3], [('x', ['a', 'b', 'c'])]) expected = DataArray(arr.to_pandas().shift(offset)) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 521f2395758..6f6287efcac 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -15,7 +15,7 @@ from xarray import ( ALL_DIMS, DataArray, Dataset, IndexVariable, MergeError, Variable, align, backends, broadcast, open_dataset, set_options) -from xarray.core import indexing, npcompat, utils +from xarray.core import dtypes, indexing, npcompat, utils from xarray.core.common import full_like from xarray.core.pycompat import ( OrderedDict, integer_types, iteritems, unicode_type) @@ -3917,12 +3917,17 @@ def test_dataset_diff_exception_label_str(self): with raises_regex(ValueError, '\'label\' argument has to'): ds.diff('dim2', label='raise_me') - def test_shift(self): + @pytest.mark.parametrize('fill_value', [dtypes.NA, 2, 2.0]) + def test_shift(self, fill_value): coords = {'bar': ('x', list('abc')), 'x': [-4, 3, 2]} attrs = {'meta': 'data'} ds = Dataset({'foo': ('x', [1, 2, 3])}, coords, attrs) - actual = ds.shift(x=1) - expected = Dataset({'foo': ('x', [np.nan, 1, 2])}, coords, attrs) + actual = ds.shift(x=1, fill_value=fill_value) + if fill_value == dtypes.NA: + # if we supply the default, we expect the missing value for a + # float array + fill_value = np.nan + expected = Dataset({'foo': 
('x', [fill_value, 1, 2])}, coords, attrs) assert_identical(expected, actual) with raises_regex(ValueError, 'dimensions'): diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 91bc0e555c0..08cab4b3541 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -13,7 +13,7 @@ import pytz from xarray import Coordinate, Dataset, IndexVariable, Variable, set_options -from xarray.core import indexing +from xarray.core import dtypes, indexing from xarray.core.common import full_like, ones_like, zeros_like from xarray.core.indexing import ( BasicIndexer, CopyOnWriteArray, DaskIndexingAdapter, @@ -1179,24 +1179,32 @@ def test_indexing_0d_unicode(self): expected = Variable((), u'tmax') assert_identical(actual, expected) - def test_shift(self): + @pytest.mark.parametrize('fill_value', [dtypes.NA, 2, 2.0]) + def test_shift(self, fill_value): v = Variable('x', [1, 2, 3, 4, 5]) assert_identical(v, v.shift(x=0)) assert v is not v.shift(x=0) - expected = Variable('x', [np.nan, 1, 2, 3, 4]) - assert_identical(expected, v.shift(x=1)) - expected = Variable('x', [np.nan, np.nan, 1, 2, 3]) assert_identical(expected, v.shift(x=2)) - expected = Variable('x', [2, 3, 4, 5, np.nan]) - assert_identical(expected, v.shift(x=-1)) + if fill_value == dtypes.NA: + # if we supply the default, we expect the missing value for a + # float array + fill_value_exp = np.nan + else: + fill_value_exp = fill_value + + expected = Variable('x', [fill_value_exp, 1, 2, 3, 4]) + assert_identical(expected, v.shift(x=1, fill_value=fill_value)) + + expected = Variable('x', [2, 3, 4, 5, fill_value_exp]) + assert_identical(expected, v.shift(x=-1, fill_value=fill_value)) - expected = Variable('x', [np.nan] * 5) - assert_identical(expected, v.shift(x=5)) - assert_identical(expected, v.shift(x=6)) + expected = Variable('x', [fill_value_exp] * 5) + assert_identical(expected, v.shift(x=5, fill_value=fill_value)) + assert_identical(expected, v.shift(x=6, 
fill_value=fill_value)) with raises_regex(ValueError, 'dimension'): v.shift(z=0) @@ -1204,8 +1212,8 @@ def test_shift(self): v = Variable('x', [1, 2, 3, 4, 5], {'foo': 'bar'}) assert_identical(v, v.shift(x=0)) - expected = Variable('x', [np.nan, 1, 2, 3, 4], {'foo': 'bar'}) - assert_identical(expected, v.shift(x=1)) + expected = Variable('x', [fill_value_exp, 1, 2, 3, 4], {'foo': 'bar'}) + assert_identical(expected, v.shift(x=1, fill_value=fill_value)) def test_shift2d(self): v = Variable(('x', 'y'), [[1, 2], [3, 4]])