Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ Bug fixes

- Numbers are properly formatted in a plot's title (:issue:`5788`, :pull:`5789`).
By `Maxime Liquet <https://github.com/maximlt>`_.
- Subclasses of ``byte`` and ``str`` (e.g. ``np.str_`` and ``np.bytes_``) will now serialise to disk rather than raising a ``ValueError: unsupported dtype for netCDF4 variable: object`` as they did previously (:pull:`5264`).
By `Zeb Nicholls <https://github.com/znicholls>`_.

Documentation
~~~~~~~~~~~~~
Expand Down
2 changes: 2 additions & 0 deletions xarray/coding/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@


def create_vlen_dtype(element_type):
if element_type not in (str, bytes):
raise TypeError("unsupported type for vlen_dtype: {!r}".format(element_type))
# based on h5py.special_dtype
return np.dtype("O", metadata={"element_type": element_type})

Expand Down
8 changes: 6 additions & 2 deletions xarray/conventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,12 @@ def _infer_dtype(array, name=None):
return np.dtype(float)

element = array[(0,) * array.ndim]
if isinstance(element, (bytes, str)):
return strings.create_vlen_dtype(type(element))
# We use the base types to avoid subclasses of bytes and str (which might
# not play nice with e.g. hdf5 datatypes), such as those from numpy
if isinstance(element, bytes):
return strings.create_vlen_dtype(bytes)
elif isinstance(element, str):
return strings.create_vlen_dtype(str)

dtype = np.array(element).dtype
if dtype.kind != "O":
Expand Down
36 changes: 35 additions & 1 deletion xarray/tests/test_coding_strings.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
from contextlib import suppress

import numpy as np
import pandas as pd
import pytest

from xarray import Variable
from xarray.coding import strings
from xarray.core import indexing

from . import IndexerMaker, assert_array_equal, assert_identical, requires_dask
from . import (
IndexerMaker,
assert_array_equal,
assert_identical,
requires_dask,
requires_netCDF4,
)

with suppress(ImportError):
import dask.array as da
Expand All @@ -29,6 +36,33 @@ def test_vlen_dtype() -> None:
assert strings.check_vlen_dtype(np.dtype(object)) is None


@pytest.mark.parametrize("numpy_str_type", (np.str_, np.bytes_))
def test_numpy_subclass_handling(numpy_str_type) -> None:
with pytest.raises(TypeError, match="unsupported type for vlen_dtype"):
strings.create_vlen_dtype(numpy_str_type)


@requires_netCDF4
@pytest.mark.parametrize("str_type", (str, np.str_))
def test_write_file_from_np_str(str_type) -> None:
# should be moved elsewhere probably
scenarios = [str_type(v) for v in ["scenario_a", "scenario_b", "scenario_c"]]
years = range(2015, 2100 + 1)
tdf = pd.DataFrame(
data=np.random.random((len(scenarios), len(years))),
columns=years,
index=scenarios,
)
tdf.index.name = "scenario"
tdf.columns.name = "year"
tdf = tdf.stack()
tdf.name = "tas"

txr = tdf.to_xarray()

txr.to_netcdf("test.nc")


def test_EncodedStringCoder_decode() -> None:
coder = strings.EncodedStringCoder()

Expand Down