Skip to content

Commit 41653c8

Browse files
committed
Recreate @gajomi's pydata#2070 to keep attrs when calling astype()
1 parent d9ebcaf commit 41653c8

6 files changed

Lines changed: 91 additions & 1 deletion

File tree

xarray/core/common.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1299,6 +1299,46 @@ def isin(self, test_elements):
12991299
dask="allowed",
13001300
)
13011301

1302+
def astype(self, dtype, casting="unsafe", copy=True):
1303+
"""
1304+
Copy of the xarray object, with data cast to a specified type.
1305+
Leaves coordinate dtype unchanged.
1306+
1307+
Parameters
1308+
----------
1309+
dtype : str or dtype
1310+
Typecode or data-type to which the array is cast.
1311+
casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
1312+
Controls what kind of data casting may occur. Defaults to 'unsafe'
1313+
for backwards compatibility.
1314+
1315+
* 'no' means the data types should not be cast at all.
1316+
* 'equiv' means only byte-order changes are allowed.
1317+
* 'safe' means only casts which can preserve values are allowed.
1318+
* 'same_kind' means only safe casts or casts within a kind,
1319+
like float64 to float32, are allowed.
1320+
* 'unsafe' means any data conversions may be done.
1321+
copy : bool, optional
1322+
By default, astype always returns a newly allocated array. If this
1323+
is set to False and the `dtype` requirement is satisfied, the input
1324+
array is returned instead of a copy.
1325+
1326+
See also
1327+
--------
1328+
np.ndarray.astype
1329+
dask.array.Array.astype
1330+
"""
1331+
from .computation import apply_ufunc
1332+
1333+
return apply_ufunc(
1334+
duck_array_ops.astype,
1335+
self,
1336+
dtype,
1337+
keep_attrs=True,
1338+
kwargs={"casting": casting, "copy": copy},
1339+
dask="allowed",
1340+
)
1341+
13021342
def __enter__(self: T) -> T:
13031343
return self
13041344

xarray/core/duck_array_ops.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,10 @@ def trapz(y, x, axis):
150150
)
151151

152152

153+
def astype(data, dtype, **kwargs):
154+
return data.astype(dtype, **kwargs)
155+
156+
153157
def asarray(data, xp=np):
154158
return (
155159
data

xarray/core/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
NUMPY_SAME_METHODS = ["item", "searchsorted"]
4343
# methods which don't modify the data shape, so the result should still be
4444
# wrapped in an Variable/DataArray
45-
NUMPY_UNARY_METHODS = ["astype", "argsort", "clip", "conj", "conjugate"]
45+
NUMPY_UNARY_METHODS = ["argsort", "clip", "conj", "conjugate"]
4646
PANDAS_UNARY_FUNCTIONS = ["isnull", "notnull"]
4747
# methods which remove an axis
4848
REDUCE_METHODS = ["all", "any"]

xarray/core/variable.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,37 @@ def data(self, data):
360360
)
361361
self._data = data
362362

363+
def astype(self, dtype, casting="unsafe", copy=True):
364+
"""
365+
Copy of the Variable object, with data cast to a specified type.
366+
367+
Parameters
368+
----------
369+
dtype : str or dtype
370+
Typecode or data-type to which the array is cast.
371+
casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
372+
Controls what kind of data casting may occur. Defaults to 'unsafe'
373+
for backwards compatibility.
374+
375+
* 'no' means the data types should not be cast at all.
376+
* 'equiv' means only byte-order changes are allowed.
377+
* 'safe' means only casts which can preserve values are allowed.
378+
* 'same_kind' means only safe casts or casts within a kind,
379+
like float64 to float32, are allowed.
380+
* 'unsafe' means any data conversions may be done.
381+
copy : bool, optional
382+
By default, astype always returns a newly allocated array. If this
383+
is set to False and the `dtype` requirement is satisfied, the input
384+
array is returned instead of a copy.
385+
386+
See also
387+
--------
388+
np.ndarray.astype
389+
dask.array.Array.astype
390+
"""
391+
self.data = duck_array_ops.astype(self.data, dtype, casting=casting, copy=copy)
392+
return self
393+
363394
def load(self, **kwargs):
364395
"""Manually trigger loading of this variable's data from disk or a
365396
remote source into memory and return this variable.

xarray/tests/test_dataarray.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1874,6 +1874,13 @@ def test_array_interface(self):
18741874
bar = Variable(["x", "y"], np.zeros((10, 20)))
18751875
assert_equal(self.dv, np.maximum(self.dv, bar))
18761876

1877+
def test_astype_attrs(self):
1878+
mda1 = self.mda.copy()
1879+
mda1.attrs["foo"] = "bar"
1880+
mda2 = mda1.astype(bool)
1881+
1882+
assert list(mda1.attrs.items()) == list(mda2.attrs.items())
1883+
18771884
def test_is_null(self):
18781885
x = np.random.RandomState(42).randn(5, 6)
18791886
x[x < 0] = np.nan

xarray/tests/test_dataset.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5634,6 +5634,14 @@ def test_pad(self):
56345634
np.testing.assert_equal(padded["var1"].isel(dim2=[0, -1]).data, 42)
56355635
np.testing.assert_equal(padded["dim2"][[0, -1]].data, np.nan)
56365636

5637+
def test_astype_attrs(self):
5638+
data = create_test_data(seed=123)
5639+
data.attrs["foo"] = "bar"
5640+
databool = data.astype(bool)
5641+
5642+
assert list(data.attrs.items()) == list(databool.attrs.items())
5643+
assert list(data.var1.attrs.items()) == list(databool.var1.attrs.items())
5644+
56375645

56385646
# Py.test tests
56395647

0 commit comments

Comments
 (0)