Skip to content
16 changes: 10 additions & 6 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ Enhancements
- :py:class:`CFTimeIndex` uses slicing for string indexing when possible (like
:py:class:`pandas.DatetimeIndex`), which avoids unnecessary copies.
By `Stephan Hoyer <https://github.com/shoyer>`_
- Enable passing ``rasterio.io.DatasetReader`` or ``rasterio.vrt.WarpedVRT`` to
``open_rasterio`` instead of a file path string. This allows for in-memory
reprojection (see :issue:`2588`).
By `Scott Henderson <https://github.com/scottyhq>`_.

Bug fixes
~~~~~~~~~
Expand All @@ -56,15 +60,15 @@ Breaking changes
- ``Dataset.T`` has been removed as a shortcut for :py:meth:`Dataset.transpose`.
Call :py:meth:`Dataset.transpose` directly instead.
- Iterating over a ``Dataset`` now includes only data variables, not coordinates.
Similarly, calling ``len`` and ``bool`` on a ``Dataset`` now
Similarly, calling ``len`` and ``bool`` on a ``Dataset`` now
includes only data variables.
- ``DataArray.__contains__`` (used by Python's ``in`` operator) now checks
array data, not coordinates.
array data, not coordinates.
- The old resample syntax from before xarray 0.10, e.g.,
``data.resample('1D', dim='time', how='mean')``, is no longer supported and will
raise an error in most cases. You need to use the new resample syntax
instead, e.g., ``data.resample(time='1D').mean()`` or
``data.resample({'time': '1D'}).mean()``.
``data.resample({'time': '1D'}).mean()``.


- New deprecations (behavior will be changed in xarray 0.12):
Expand Down Expand Up @@ -101,13 +105,13 @@ Breaking changes
than by default trying to coerce them into ``np.datetime64[ns]`` objects.
A :py:class:`~xarray.CFTimeIndex` will be used for indexing along time
coordinates in these cases.
- A new method :py:meth:`~xarray.CFTimeIndex.to_datetimeindex` has been added
- A new method :py:meth:`~xarray.CFTimeIndex.to_datetimeindex` has been added
to aid in converting from a :py:class:`~xarray.CFTimeIndex` to a
:py:class:`pandas.DatetimeIndex` for the remaining use-cases where
using a :py:class:`~xarray.CFTimeIndex` is still a limitation (e.g. for
resample or plotting).
- Setting the ``enable_cftimeindex`` option is now a no-op and emits a
``FutureWarning``.
``FutureWarning``.

Enhancements
~~~~~~~~~~~~
Expand Down Expand Up @@ -194,7 +198,7 @@ Bug fixes
the dates must be encoded using cftime rather than NumPy (:issue:`2272`).
By `Spencer Clark <https://github.com/spencerkclark>`_.

- Chunked datasets can now roundtrip to Zarr storage continually
- Chunked datasets can now roundtrip to Zarr storage continually
with ``to_zarr`` and ``open_zarr`` (:issue:`2300`).
By `Lily Wang <https://github.com/lilyminium>`_.

Expand Down
72 changes: 64 additions & 8 deletions xarray/backends/rasterio_.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import warnings
from collections import OrderedDict
from distutils.version import LooseVersion

import numpy as np

from .. import DataArray
Expand All @@ -24,11 +23,13 @@
class RasterioArrayWrapper(BackendArray):
"""A wrapper around rasterio dataset objects"""

def __init__(self, manager):
def __init__(self, manager, vrt=None):
self.manager = manager

# cannot save riods as an attribute: this would break pickleability
riods = manager.acquire()
if vrt:
riods = vrt

self._shape = (riods.count, riods.height, riods.width)

Expand Down Expand Up @@ -123,6 +124,42 @@ def __getitem__(self, key):
key, self.shape, indexing.IndexingSupport.OUTER, self._getitem)


class RasterioVRTWrapper(RasterioArrayWrapper):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than adding a subclass with a lot of duplicated logic, could you add this into the base RasterioArrayWrapper class?

Something like:

def __init__(self, manager, vrt_params=None):
    ...
    riods = riods if vrt_params is None else WarpedVRT(riods, **vrt_params)

Copy link
Contributor Author

@scottyhq scottyhq Dec 19, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, just moved the changes to RasterioArrayWrapper

"""A wrapper around rasterio WarpedVRT objects"""

def __init__(self, manager, vrt_params):
    """Store the manager and VRT options, then cache shape and dtype.

    Parameters
    ----------
    manager : object
        Object whose ``acquire()`` returns the dataset that is handed to
        ``WarpedVRT`` as its source.
    vrt_params : dict
        Keyword arguments forwarded to ``rasterio.vrt.WarpedVRT``.

    Raises
    ------
    ValueError
        If the bands of the warped dataset do not all share one dtype.
    """
    from rasterio.vrt import WarpedVRT

    self.manager = manager
    self.vrt_params = vrt_params

    # The dataset handle is deliberately kept out of the instance
    # attributes: storing it would break pickleability.
    # NOTE(review): this WarpedVRT is never closed explicitly -- confirm
    # whether that is intentional.
    warped = WarpedVRT(manager.acquire(), **vrt_params)
    self._shape = (warped.count, warped.height, warped.width)

    band_dtypes = warped.dtypes
    first_dtype = band_dtypes[0]
    if not np.all(np.asarray(band_dtypes) == first_dtype):
        raise ValueError('All bands should have the same dtype')
    self._dtype = np.dtype(first_dtype)

def _getitem(self, key):
    """Return the values selected by ``key``, read through a ``WarpedVRT``.

    ``key`` is decomposed by ``self._get_indexer`` into the bands to
    read, a per-axis ``(start, stop)`` window, the axes to squeeze and a
    final NumPy indexer that is applied to the result.
    """
    from rasterio.vrt import WarpedVRT

    band_key, window, squeeze_axis, np_inds = self._get_indexer(key)

    nothing_to_read = (
        not band_key or any(lo == hi for (lo, hi) in window))
    if nothing_to_read:
        # Empty selection: build a zero-filled array of the right shape
        # instead of doing any IO.
        extents = tuple(hi - lo for (lo, hi) in window)
        out = np.zeros((len(band_key),) + extents, dtype=self.dtype)
    else:
        # A fresh WarpedVRT is built from the stored parameters on every
        # read.
        # NOTE(review): it is never closed explicitly -- confirm whether
        # that is intentional.
        warped = WarpedVRT(self.manager.acquire(), **self.vrt_params)
        out = warped.read(band_key, window=window)

    if squeeze_axis:
        out = np.squeeze(out, axis=squeeze_axis)
    return out[np_inds]


def _parse_envi(meta):
"""Parse ENVI metadata into Python data structures.

Expand Down Expand Up @@ -176,8 +213,8 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,

Parameters
----------
filename : str
Path to the file to open.
filename : str, rasterio.DatasetReader, or rasterio.WarpedVRT
Path to the file to open. Or already open rasterio dataset.
parse_coordinates : bool, optional
Whether to parse the x and y coordinates out of the file's
``transform`` attribute or not. The default is to automatically
Expand All @@ -204,12 +241,27 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
data : DataArray
The newly created DataArray.
"""

import rasterio

vrt_params = None
if isinstance(filename, rasterio.io.DatasetReader):
filename = filename.name
elif isinstance(filename, rasterio.vrt.WarpedVRT):
vrt = filename
filename = vrt.src_dataset.name
vrt_params = dict(crs=vrt.crs.to_string(),
resampling=vrt.resampling,
src_nodata=vrt.src_nodata,
dst_nodata=vrt.dst_nodata,
tolerance=vrt.tolerance,
warp_extras=vrt.warp_extras)

manager = CachingFileManager(rasterio.open, filename, mode='r')
riods = manager.acquire()

if vrt_params:
riods = rasterio.vrt.WarpedVRT(riods, **vrt_params)

if cache is None:
cache = chunks is None

Expand Down Expand Up @@ -282,13 +334,17 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
for k, v in meta.items():
# Add values as coordinates if they match the band count,
# as attributes otherwise
if (isinstance(v, (list, np.ndarray)) and
len(v) == riods.count):
if (isinstance(v, (list, np.ndarray))
and len(v) == riods.count):
coords[k] = ('band', np.asarray(v))
else:
attrs[k] = v

data = indexing.LazilyOuterIndexedArray(RasterioArrayWrapper(manager))
if vrt_params:
data = indexing.LazilyOuterIndexedArray(RasterioVRTWrapper(manager,
vrt_params))
else:
data = indexing.LazilyOuterIndexedArray(RasterioArrayWrapper(manager))

# this lets you write arrays loaded with rasterio
data = indexing.CopyOnWriteArray(data)
Expand Down
113 changes: 90 additions & 23 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,9 @@ def check_dtypes_roundtripped(self, expected, actual):
actual_dtype = actual.variables[k].dtype
# TODO: check expected behavior for string dtypes more carefully
string_kinds = {'O', 'S', 'U'}
assert (expected_dtype == actual_dtype or
(expected_dtype.kind in string_kinds and
actual_dtype.kind in string_kinds))
assert (expected_dtype == actual_dtype
or (expected_dtype.kind in string_kinds and
actual_dtype.kind in string_kinds))

def test_roundtrip_test_data(self):
expected = create_test_data()
Expand Down Expand Up @@ -376,17 +376,17 @@ def test_roundtrip_cftime_datetime_data(self):
with self.roundtrip(expected, save_kwargs=kwds) as actual:
abs_diff = abs(actual.t.values - expected_decoded_t)
assert (abs_diff <= np.timedelta64(1, 's')).all()
assert (actual.t.encoding['units'] ==
'days since 0001-01-01 00:00:00.000000')
assert (actual.t.encoding['calendar'] ==
expected_calendar)
assert (actual.t.encoding['units']
== 'days since 0001-01-01 00:00:00.000000')
assert (actual.t.encoding['calendar']
== expected_calendar)

abs_diff = abs(actual.t0.values - expected_decoded_t0)
assert (abs_diff <= np.timedelta64(1, 's')).all()
assert (actual.t0.encoding['units'] ==
'days since 0001-01-01')
assert (actual.t.encoding['calendar'] ==
expected_calendar)
assert (actual.t0.encoding['units']
== 'days since 0001-01-01')
assert (actual.t.encoding['calendar']
== expected_calendar)

def test_roundtrip_timedelta_data(self):
time_deltas = pd.to_timedelta(['1h', '2h', 'NaT'])
Expand Down Expand Up @@ -622,20 +622,20 @@ def test_unsigned_roundtrip_mask_and_scale(self):
encoded = create_encoded_unsigned_masked_scaled_data()
with self.roundtrip(decoded) as actual:
for k in decoded.variables:
assert (decoded.variables[k].dtype ==
actual.variables[k].dtype)
assert (decoded.variables[k].dtype
== actual.variables[k].dtype)
assert_allclose(decoded, actual, decode_bytes=False)
with self.roundtrip(decoded,
open_kwargs=dict(decode_cf=False)) as actual:
for k in encoded.variables:
assert (encoded.variables[k].dtype ==
actual.variables[k].dtype)
assert (encoded.variables[k].dtype
== actual.variables[k].dtype)
assert_allclose(encoded, actual, decode_bytes=False)
with self.roundtrip(encoded,
open_kwargs=dict(decode_cf=False)) as actual:
for k in encoded.variables:
assert (encoded.variables[k].dtype ==
actual.variables[k].dtype)
assert (encoded.variables[k].dtype
== actual.variables[k].dtype)
assert_allclose(encoded, actual, decode_bytes=False)
# make sure roundtrip encoding didn't change the
# original dataset.
Expand Down Expand Up @@ -2520,8 +2520,8 @@ def myatts(**attrs):
'ULOD_FLAG': '-7777', 'ULOD_VALUE': 'N/A',
'LLOD_FLAG': '-8888',
'LLOD_VALUE': ('N/A, N/A, N/A, N/A, 0.025'),
'OTHER_COMMENTS': ('www-air.larc.nasa.gov/missions/etc/' +
'IcarttDataFormat.htm'),
'OTHER_COMMENTS': ('www-air.larc.nasa.gov/missions/etc/'
+ 'IcarttDataFormat.htm'),
'REVISION': 'R0',
'R0': 'No comments for this revision.',
'TFLAG': 'Start_UTC'
Expand Down Expand Up @@ -2610,8 +2610,8 @@ def test_uamiv_format_read(self):
expected = xr.Variable(('TSTEP',), data,
dict(bounds='time_bounds',
long_name=('synthesized time coordinate ' +
'from SDATE, STIME, STEP ' +
'global attributes')))
'from SDATE, STIME, STEP '
+ 'global attributes')))
actual = camxfile.variables['time']
assert_allclose(expected, actual)
camxfile.close()
Expand Down Expand Up @@ -2640,8 +2640,8 @@ def test_uamiv_format_mfread(self):
data = np.concatenate([data1] * 2, axis=0)
attrs = dict(bounds='time_bounds',
long_name=('synthesized time coordinate ' +
'from SDATE, STIME, STEP ' +
'global attributes'))
'from SDATE, STIME, STEP '
+ 'global attributes'))
expected = xr.Variable(('TSTEP',), data, attrs)
actual = camxfile.variables['time']
assert_allclose(expected, actual)
Expand Down Expand Up @@ -3057,6 +3057,73 @@ def test_http_url(self):
import dask.array as da
assert isinstance(actual.data, da.Array)

def test_rasterio_environment(self):
    """Opening a temporary GeoTIFF inside a restrictive Env must raise."""
    import rasterio

    with create_tmp_geotiff() as (tmp_file, expected):
        # The open is expected to fail under this environment.
        # NOTE(review): presumably because GDAL_SKIP='GTiff' disables the
        # GeoTIFF driver -- confirm.
        with pytest.raises(Exception):
            env_options = dict(GDAL_SKIP='GTiff')
            with rasterio.Env(**env_options), \
                    xr.open_rasterio(tmp_file) as actual:
                assert_allclose(actual, expected)

def test_rasterio_vrt(self):
    """Check that ``open_rasterio`` accepts an already open ``WarpedVRT``.

    A temporary GeoTIFF is warped to EPSG:4326 and the resulting VRT is
    passed to ``xr.open_rasterio``. The shape, CRS, resolution and the
    value of the centre pixel of the DataArray must match what the VRT
    itself reports.
    """
    import rasterio

    # NOTE(review): the temporary file is said to default to a UTM CRS
    # (epsg:32618) -- confirm against create_tmp_geotiff.
    with create_tmp_geotiff() as (tmp_file, expected):
        with rasterio.open(tmp_file) as src:
            with rasterio.vrt.WarpedVRT(src, crs='epsg:4326') as vrt:
                expected_shape = (vrt.width, vrt.height)
                expected_crs = vrt.crs
                expected_res = vrt.res
                # Value of the single pixel in the centre of the image.
                # ``xy`` takes (row, col), i.e. the height index first.
                lon, lat = vrt.xy(vrt.height // 2, vrt.width // 2)
                expected_val = next(vrt.sample([(lon, lat)]))
                with xr.open_rasterio(vrt) as da:
                    actual_shape = (da.sizes['x'], da.sizes['y'])
                    actual_crs = da.crs
                    actual_res = da.res
                    actual_val = da.sel(dict(x=lon, y=lat),
                                        method='nearest').data

                    assert actual_crs == expected_crs
                    assert actual_res == expected_res
                    assert actual_shape == expected_shape
                    # Compare the pixel values themselves; comparing the
                    # two ``.all()`` results would only compare booleans.
                    np.testing.assert_array_equal(actual_val,
                                                  expected_val)

@network
def test_rasterio_vrt_network(self):
    """Check ``open_rasterio`` with a ``WarpedVRT`` over a remote GeoTIFF.

    A Landsat band hosted on Google Cloud Storage is warped to EPSG:4326
    and the resulting VRT is passed to ``xr.open_rasterio``. The shape,
    CRS, resolution and the value of the centre pixel of the DataArray
    must match what the VRT itself reports.

    NOTE(review): presumably the ``network`` decorator skips this test
    unless network tests are enabled -- confirm.
    """
    import rasterio

    # Adjacent string literals are concatenated, so the URL stays intact
    # regardless of how these lines are indented.
    url = ('https://storage.googleapis.com/'
           'gcp-public-data-landsat/LC08/01/047/027/'
           'LC08_L1TP_047027_20130421_20170310_01_T1/'
           'LC08_L1TP_047027_20130421_20170310_01_T1_B4.TIF')
    env = rasterio.Env(GDAL_DISABLE_READDIR_ON_OPEN='EMPTY_DIR',
                       CPL_VSIL_CURL_USE_HEAD=False,
                       CPL_VSIL_CURL_ALLOWED_EXTENSIONS='TIF')
    with env:
        with rasterio.open(url) as src:
            with rasterio.vrt.WarpedVRT(src, crs='epsg:4326') as vrt:
                expected_shape = (vrt.width, vrt.height)
                expected_crs = vrt.crs
                expected_res = vrt.res
                # Value of the single pixel in the centre of the image.
                # ``xy`` takes (row, col), i.e. the height index first.
                lon, lat = vrt.xy(vrt.height // 2, vrt.width // 2)
                expected_val = next(vrt.sample([(lon, lat)]))
                with xr.open_rasterio(vrt) as da:
                    actual_shape = (da.sizes['x'], da.sizes['y'])
                    actual_crs = da.crs
                    actual_res = da.res
                    actual_val = da.sel(dict(x=lon, y=lat),
                                        method='nearest').data

                    # Plain comparisons: these are tuples, a CRS and
                    # NumPy values, not xarray objects.
                    assert actual_shape == expected_shape
                    assert actual_crs == expected_crs
                    assert actual_res == expected_res
                    np.testing.assert_array_equal(actual_val,
                                                  expected_val)


class TestEncodingInvalid(object):

Expand Down