Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions xarray/core/_typed_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,6 @@ def round(self, *args, **kwargs):
def argsort(self, *args, **kwargs):
return self._unary_op(ops.argsort, *args, **kwargs)

def clip(self, *args, **kwargs):
return self._unary_op(ops.clip, *args, **kwargs)

def conj(self, *args, **kwargs):
return self._unary_op(ops.conj, *args, **kwargs)

Expand Down Expand Up @@ -195,7 +192,6 @@ def conjugate(self, *args, **kwargs):
__invert__.__doc__ = operator.invert.__doc__
round.__doc__ = ops.round_.__doc__
argsort.__doc__ = ops.argsort.__doc__
clip.__doc__ = ops.clip.__doc__
conj.__doc__ = ops.conj.__doc__
conjugate.__doc__ = ops.conjugate.__doc__

Expand Down Expand Up @@ -338,9 +334,6 @@ def round(self, *args, **kwargs):
def argsort(self, *args, **kwargs):
return self._unary_op(ops.argsort, *args, **kwargs)

def clip(self, *args, **kwargs):
return self._unary_op(ops.clip, *args, **kwargs)

def conj(self, *args, **kwargs):
return self._unary_op(ops.conj, *args, **kwargs)

Expand Down Expand Up @@ -389,7 +382,6 @@ def conjugate(self, *args, **kwargs):
__invert__.__doc__ = operator.invert.__doc__
round.__doc__ = ops.round_.__doc__
argsort.__doc__ = ops.argsort.__doc__
clip.__doc__ = ops.clip.__doc__
conj.__doc__ = ops.conj.__doc__
conjugate.__doc__ = ops.conjugate.__doc__

Expand Down Expand Up @@ -532,9 +524,6 @@ def round(self, *args, **kwargs):
def argsort(self, *args, **kwargs):
return self._unary_op(ops.argsort, *args, **kwargs)

def clip(self, *args, **kwargs):
return self._unary_op(ops.clip, *args, **kwargs)

def conj(self, *args, **kwargs):
return self._unary_op(ops.conj, *args, **kwargs)

Expand Down Expand Up @@ -583,7 +572,6 @@ def conjugate(self, *args, **kwargs):
__invert__.__doc__ = operator.invert.__doc__
round.__doc__ = ops.round_.__doc__
argsort.__doc__ = ops.argsort.__doc__
clip.__doc__ = ops.clip.__doc__
conj.__doc__ = ops.conj.__doc__
conjugate.__doc__ = ops.conjugate.__doc__

Expand Down
3 changes: 0 additions & 3 deletions xarray/core/_typed_ops.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ class DatasetOpsMixin:
def __invert__(self: T_Dataset) -> T_Dataset: ...
def round(self: T_Dataset, *args, **kwargs) -> T_Dataset: ...
def argsort(self: T_Dataset, *args, **kwargs) -> T_Dataset: ...
def clip(self: T_Dataset, *args, **kwargs) -> T_Dataset: ...
def conj(self: T_Dataset, *args, **kwargs) -> T_Dataset: ...
def conjugate(self: T_Dataset, *args, **kwargs) -> T_Dataset: ...

Expand Down Expand Up @@ -235,7 +234,6 @@ class DataArrayOpsMixin:
def __invert__(self: T_DataArray) -> T_DataArray: ...
def round(self: T_DataArray, *args, **kwargs) -> T_DataArray: ...
def argsort(self: T_DataArray, *args, **kwargs) -> T_DataArray: ...
def clip(self: T_DataArray, *args, **kwargs) -> T_DataArray: ...
def conj(self: T_DataArray, *args, **kwargs) -> T_DataArray: ...
def conjugate(self: T_DataArray, *args, **kwargs) -> T_DataArray: ...

Expand Down Expand Up @@ -406,7 +404,6 @@ class VariableOpsMixin:
def __invert__(self: T_Variable) -> T_Variable: ...
def round(self: T_Variable, *args, **kwargs) -> T_Variable: ...
def argsort(self: T_Variable, *args, **kwargs) -> T_Variable: ...
def clip(self: T_Variable, *args, **kwargs) -> T_Variable: ...
def conj(self: T_Variable, *args, **kwargs) -> T_Variable: ...
def conjugate(self: T_Variable, *args, **kwargs) -> T_Variable: ...

Expand Down
12 changes: 12 additions & 0 deletions xarray/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,18 @@ def squeeze(
dims = get_squeeze_dims(self, dim, axis)
return self.isel(drop=drop, **{d: 0 for d in dims})

def clip(self, min=None, max=None, *, keep_attrs=None):
from .computation import apply_ufunc

if keep_attrs is None:
# When this was a unary func, the default was True, so retaining the
# default.
keep_attrs = _get_keep_attrs(default=True)

return apply_ufunc(
np.clip, self, min, max, keep_attrs=keep_attrs, dask="allowed"
)

def get_index(self, key: Hashable) -> pd.Index:
"""Get an index for a dimension, with fall-back to a default RangeIndex"""
if key not in self.dims:
Expand Down
1 change: 0 additions & 1 deletion xarray/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,6 @@ def inplace_to_noninplace_op(f):

# _typed_ops.py uses the following wrapped functions as a kind of unary operator
argsort = _method_wrapper("argsort")
clip = _method_wrapper("clip")
conj = _method_wrapper("conj")
conjugate = _method_wrapper("conjugate")
round_ = _func_slash_method_wrapper(duck_array_ops.around, name="round")
Expand Down
5 changes: 5 additions & 0 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1665,6 +1665,11 @@ def fillna(self, value):
def where(self, cond, other=dtypes.NA):
return ops.where_method(self, cond, other)

def clip(self, min=None, max=None):
Copy link
Collaborator Author

@max-sixty max-sixty Apr 19, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder whether there should be a class that DataArray & Variable should inherit from (but not AbstractArray) where they share methods.

I don't think there's enough overlap to justify it at the moment, but worth considering. That may mean we could unify the tests, which are currently highly duplicated for clip for example.

from .computation import apply_ufunc

return apply_ufunc(np.clip, self, min, max)

def reduce(
self,
func,
Expand Down
67 changes: 54 additions & 13 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -6434,28 +6434,43 @@ def test_idxminmax_dask(self, op, ndim):
assert_equal(getattr(ar0_dsk, op)(dim="x"), getattr(ar0_raw, op)(dim="x"))


@pytest.fixture(params=["numpy", pytest.param("dask", marks=requires_dask)])
def backend(request):
return request.param


@pytest.fixture(params=[1])
def da(request):
def da(request, backend):
if request.param == 1:
times = pd.date_range("2000-01-01", freq="1D", periods=21)
values = np.random.random((3, 21, 4))
da = DataArray(values, dims=("a", "time", "x"))
da["time"] = times
return da
da = DataArray(
np.random.random((3, 21, 4)),
dims=("a", "time", "x"),
coords=dict(time=times),
)

if request.param == 2:
return DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time")
da = DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time")

if request.param == "repeating_ints":
return DataArray(
da = DataArray(
np.tile(np.arange(12), 5).reshape(5, 4, 3),
coords={"x": list("abc"), "y": list("defg")},
dims=list("zyx"),
)

if backend == "dask":
return da.chunk()
elif backend == "numpy":
return da
else:
raise ValueError


@pytest.fixture
def da_dask(seed=123):
# TODO: if possible, use the `da` fixture parameterized with backends rather than
# this.
pytest.importorskip("dask.array")
rs = np.random.RandomState(seed)
times = pd.date_range("2000-01-01", freq="1D", periods=21)
Expand Down Expand Up @@ -6596,6 +6611,7 @@ def test_rolling_properties(da):
@pytest.mark.parametrize("name", ("sum", "mean", "std", "min", "max", "median"))
@pytest.mark.parametrize("center", (True, False, None))
@pytest.mark.parametrize("min_periods", (1, None))
@pytest.mark.parametrize("backend", ["numpy"], indirect=True)
def test_rolling_wrapped_bottleneck(da, name, center, min_periods):
bn = pytest.importorskip("bottleneck", minversion="1.1")

Expand Down Expand Up @@ -6898,12 +6914,7 @@ def test_rolling_keep_attrs_deprecated():
data = np.linspace(10, 15, 100)
coords = np.linspace(1, 10, 100)

da = DataArray(
data,
dims=("coord"),
coords={"coord": coords},
attrs=attrs_da,
)
da = DataArray(data, dims=("coord"), coords={"coord": coords}, attrs=attrs_da,)

# deprecated option
with pytest.warns(
Expand Down Expand Up @@ -7206,6 +7217,7 @@ def test_fallback_to_iris_AuxCoord(self, coord_values):
@pytest.mark.parametrize(
"window_type, window", [["span", 5], ["alpha", 0.5], ["com", 0.5], ["halflife", 5]]
)
@pytest.mark.parametrize("backend", ["numpy"], indirect=True)
def test_rolling_exp(da, dim, window_type, window):
da = da.isel(a=0)
da = da.where(da > 0.2)
Expand All @@ -7225,6 +7237,7 @@ def test_rolling_exp(da, dim, window_type, window):


@requires_numbagg
@pytest.mark.parametrize("backend", ["numpy"], indirect=True)
def test_rolling_exp_keep_attrs(da):
attrs = {"attrs": "da"}
da.attrs = attrs
Expand Down Expand Up @@ -7313,3 +7326,31 @@ def test_deepcopy_obj_array():
x0 = DataArray(np.array([object()]))
x1 = deepcopy(x0)
assert x0.values[0] is not x1.values[0]


def test_clip(da):
result = da.clip(min=0.5)
assert result.min(...) >= 0.5

result = da.clip(max=0.5)
assert result.max(...) <= 0.5

result = da.clip(min=0.25, max=0.75)
assert result.min(...) >= 0.25
assert result.max(...) <= 0.75

result = da.clip(min=da.mean("x"), max=da.mean("a"))
assert result.dims == da.dims
assert_array_equal(
result.data,
np.clip(da.data, da.mean("x").data[:, :, np.newaxis], da.mean("a").data),
)

with_nans = da.isel(time=[0, 1]).reindex_like(da)
result = da.clip(with_nans)
# The values should be the same where there were NaNs.
assert_array_equal(result.isel(time=[0, 1]), with_nans.isel(time=[0, 1]))

# Unclear whether we want this work, OK to adjust the test when we have decided.
with pytest.raises(ValueError, match="arguments without labels along dimension"):
result = da.clip(min=da.mean("x"), max=da.mean("a").isel(x=[0, 1]))
35 changes: 25 additions & 10 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6069,16 +6069,16 @@ def test_dir_unicode(data_set):
def ds(request):
if request.param == 1:
return Dataset(
{
"z1": (["y", "x"], np.random.randn(2, 8)),
"z2": (["time", "y"], np.random.randn(10, 2)),
},
{
"x": ("x", np.linspace(0, 1.0, 8)),
"time": ("time", np.linspace(0, 1.0, 10)),
"c": ("y", ["a", "b"]),
"y": range(2),
},
dict(
z1=(["y", "x"], np.random.randn(2, 8)),
z2=(["time", "y"], np.random.randn(10, 2)),
),
dict(
x=("x", np.linspace(0, 1.0, 8)),
time=("time", np.linspace(0, 1.0, 10)),
c=("y", ["a", "b"]),
y=range(2),
),
)

if request.param == 2:
Expand Down Expand Up @@ -6845,3 +6845,18 @@ def test_deepcopy_obj_array():
x0 = Dataset(dict(foo=DataArray(np.array([object()]))))
x1 = deepcopy(x0)
assert x0["foo"].values[0] is not x1["foo"].values[0]


def test_clip(ds):
result = ds.clip(min=0.5)
assert result.min(...) >= 0.5

result = ds.clip(max=0.5)
assert result.max(...) <= 0.5

result = ds.clip(min=0.25, max=0.75)
assert result.min(...) >= 0.25
assert result.max(...) <= 0.75

result = ds.clip(min=ds.mean("y"), max=ds.mean("y"))
assert result.dims == ds.dims
29 changes: 29 additions & 0 deletions xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@
]


@pytest.fixture
def var():
return Variable(dims=list("xyz"), data=np.random.rand(3, 4, 5))


class VariableSubclassobjects:
def test_properties(self):
data = 0.5 * np.arange(10)
Expand Down Expand Up @@ -2510,3 +2515,27 @@ def test_DaskIndexingAdapter(self):
v = Variable(dims=("x", "y"), data=CopyOnWriteArray(DaskIndexingAdapter(da)))
self.check_orthogonal_indexing(v)
self.check_vectorized_indexing(v)


def test_clip(var):
# Copied from test_dataarray (would there be a way to combine the tests?)
result = var.clip(min=0.5)
assert result.min(...) >= 0.5

result = var.clip(max=0.5)
assert result.max(...) <= 0.5

result = var.clip(min=0.25, max=0.75)
assert result.min(...) >= 0.25
assert result.max(...) <= 0.75

result = var.clip(min=var.mean("x"), max=var.mean("z"))
assert result.dims == var.dims
assert_array_equal(
result.data,
np.clip(
var.data,
var.mean("x").data[np.newaxis, :, :],
var.mean("z").data[:, :, np.newaxis],
),
)
1 change: 0 additions & 1 deletion xarray/util/generate_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@
OTHER_UNARY_METHODS = (
("round", "ops.round_"),
("argsort", "ops.argsort"),
("clip", "ops.clip"),
("conj", "ops.conj"),
("conjugate", "ops.conjugate"),
)
Expand Down