Skip to content

Commit 7fd3345

Browse files
committed
♻️ Precise selective tiling and interpolation using xarray slices
Towards more fine-grained cropping of our image tiles! Basically have data_prep.selective_tile do the crop using exact geographic coordinate slice ranges instead of having to convert (sometimes imprecisely) to image-based coordinates. Uses xarray's subset 'sel'(ection) method which does away with the mess that is rasterio.windows and affine transformations. REMA tiles don't seem to require gapfilling anymore so we've temporarily disabled gapfilling (raise NotImplementedError) until it is needed for getting tiles for the whole of Antarctica again. Also using a nicer interpolation method in data_prep.selective_tile, especially relevant for W2_data aka the MEaSUREs Surface Ice Velocity which is resampled from 450m to 500m (since a8863e4). Still resampling bilinearly, but interpolation at the cropped tile's edges takes into account pixels beyond the border if available. I've actually inspected these new Ice Velocity tiles manually and they look awesome! Might help with the strange high-level checkerboard artifacts. Side effect is that interpolation runs slowly (mitigated somewhat by using dask), until we can vectorize the whole function properly.
1 parent c38f74a commit 7fd3345

File tree

2 files changed

+93
-126
lines changed

2 files changed

+93
-126
lines changed

data_prep.ipynb

Lines changed: 52 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@
5454
"import xarray as xr\n",
5555
"import salem\n",
5656
"\n",
57+
"import dask\n",
58+
"import dask.diagnostics\n",
5759
"import geopandas as gpd\n",
5860
"import pygmt as gmt\n",
5961
"import IPython.display\n",
@@ -1200,69 +1202,50 @@
12001202
" [0., 0.]]]], dtype=float32)\n",
12011203
" >>> os.remove(\"/tmp/tmp_st.nc\")\n",
12021204
" \"\"\"\n",
1203-
" array_list = []\n",
12041205
"\n",
1205-
" with rasterio.open(filepath) as dataset:\n",
1206+
" # Convert list of bounding box tuples to nice rasterio.coords.BoundingBox class\n",
1207+
" window_bounds = [\n",
1208+
" rasterio.coords.BoundingBox(\n",
1209+
" left=x0 - padding, bottom=y0 - padding, right=x1 + padding, top=y1 + padding\n",
1210+
" )\n",
1211+
" for x0, y0, x1, y1 in window_bounds # xmin, ymin, xmax, ymax\n",
1212+
" ]\n",
1213+
"\n",
1214+
" with xr.open_rasterio(\n",
1215+
" filepath, chunks=None if out_shape is None else {}, cache=False\n",
1216+
" ) as dataset:\n",
12061217
" print(f\"Tiling: {filepath} ... \", end=\"\")\n",
1207-
" for window_bound in window_bounds:\n",
12081218
"\n",
1209-
" if padding > 0:\n",
1210-
" window_bound = (\n",
1211-
" window_bound[0] - padding, # minx\n",
1212-
" window_bound[1] - padding, # miny\n",
1213-
" window_bound[2] + padding, # maxx\n",
1214-
" window_bound[3] + padding, # maxy\n",
1219+
" # Subset dataset according to window bound (wb)\n",
1220+
" daarray_list = [\n",
1221+
" dataset.sel(y=slice(wb.top, wb.bottom), x=slice(wb.left, wb.right))\n",
1222+
" for wb in window_bounds\n",
1223+
" ]\n",
1224+
" # Bilinear interpolate to new shape if out_shape is set\n",
1225+
" if out_shape is not None:\n",
1226+
" daarray_list = [\n",
1227+
" dataset.interp(\n",
1228+
" y=np.linspace(da.y[0], da.y[-1], num=out_shape[0]),\n",
1229+
" x=np.linspace(da.x[0], da.x[-1], num=out_shape[1]),\n",
1230+
" method=\"linear\",\n",
12151231
" )\n",
1232+
" for da in daarray_list\n",
1233+
" ]\n",
1234+
" daarray_stack = dask.array.stack(seq=daarray_list)\n",
12161235
"\n",
1217-
" window = rasterio.windows.from_bounds(\n",
1218-
" *window_bound, transform=dataset.transform, precision=None\n",
1219-
" ).round_offsets()\n",
1220-
"\n",
1221-
" # Read the raster according to the crop window\n",
1222-
" array = dataset.read(\n",
1223-
" indexes=list(range(1, dataset.count + 1)),\n",
1224-
" masked=True,\n",
1225-
" window=window,\n",
1226-
" out_shape=out_shape,\n",
1227-
" )\n",
1228-
" assert array.ndim == 3 # check that we have shape like (1, height, width)\n",
1229-
" assert array.shape[0] == 1 # channel-first (assuming only 1 channel)\n",
1230-
" assert not 0 in array.shape # ensure no empty dimensions (invalid window)\n",
1231-
"\n",
1232-
" try:\n",
1233-
" assert not array.mask.any() # check that there are no NAN values\n",
1234-
" except AssertionError:\n",
1235-
" # Replace pixels from another raster if available, else raise error\n",
1236-
" if gapfill_raster_filepath is not None:\n",
1237-
" with rasterio.open(gapfill_raster_filepath) as dataset2:\n",
1238-
" window2 = rasterio.windows.from_bounds(\n",
1239-
" *window_bound, transform=dataset2.transform, precision=None\n",
1240-
" ).round_offsets()\n",
1241-
"\n",
1242-
" array2 = dataset2.read(\n",
1243-
" indexes=list(range(1, dataset2.count + 1)),\n",
1244-
" masked=True,\n",
1245-
" window=window2,\n",
1246-
" out_shape=array.shape[1:],\n",
1247-
" )\n",
1248-
"\n",
1249-
" np.copyto(\n",
1250-
" dst=array, src=array2, where=array.mask\n",
1251-
" ) # fill in gaps where mask is True\n",
1252-
"\n",
1253-
" # assert not array.mask.any() # ensure no NAN values after gapfill\n",
1254-
" else:\n",
1255-
" plt.imshow(array.data[0, :, :])\n",
1256-
" plt.show()\n",
1257-
" print(\n",
1258-
" f\"WARN: Tile has missing data, try passing in gapfill_raster_filepath\"\n",
1259-
" )\n",
1260-
"\n",
1261-
" # assert array.shape[1] == array.shape[2] # check that height==width\n",
1262-
" array_list.append(array.data.astype(dtype=np.float32))\n",
1236+
" assert daarray_stack.ndim == 4 # check that shape is like (m, 1, height, width)\n",
1237+
" assert daarray_stack.shape[1] == 1 # channel-first (assuming only 1 channel)\n",
1238+
" assert not 0 in daarray_stack.shape # ensure no empty dimensions (bad window)\n",
12631239
" print(\"done!\")\n",
12641240
"\n",
1265-
" return np.stack(arrays=array_list)"
1241+
" with dask.diagnostics.ProgressBar(minimum=5.0):\n",
1242+
" try:\n",
1243+
" out_tiles = daarray_stack.compute().astype(dtype=np.float32)\n",
1244+
" assert not np.isnan(out_tiles).any() # check that there are no NAN values\n",
1245+
" except AssertionError:\n",
1246+
" raise NotImplementedError(\"gapfilling on dask xarray not yet implemented\")\n",
1247+
" finally:\n",
1248+
" return out_tiles"
12661249
]
12671250
},
12681251
{
@@ -1370,7 +1353,7 @@
13701353
" filepath=\"misc/REMA_100m_dem.tif\",\n",
13711354
" window_bounds=window_bounds_concat,\n",
13721355
" padding=1000,\n",
1373-
" gapfill_raster_filepath=\"misc/REMA_200m_dem_filled.tif\",\n",
1356+
" # gapfill_raster_filepath=\"misc/REMA_200m_dem_filled.tif\",\n",
13741357
")\n",
13751358
"print(rema.shape, rema.dtype)"
13761359
]
@@ -1407,6 +1390,7 @@
14071390
"output_type": "stream",
14081391
"text": [
14091392
"Tiling: misc/MEaSUREs_IceFlowSpeed_450m.tif ... done!\n",
1393+
"[########################################] | 100% Completed | 22.1s\n",
14101394
"(2347, 1, 20, 20) float32\n"
14111395
]
14121396
}
@@ -1501,7 +1485,7 @@
15011485
"name": "stdin",
15021486
"output_type": "stream",
15031487
"text": [
1504-
"Enter the code from the webpage: eyJjb2RlIjogImVmOTRiMzMzLTZkNDItNDJkYi1hM2Y1LTQ4NGNmZjc4OTIzOSIsICJpZCI6ICIyOWI4YzUyNS1lZmM1LTQ5NTItOGQ4Yy03NzQyYTg1YmI1MmEifQ==\n"
1488+
"Enter the code from the webpage: eyJjb2RlIjogIjg4ODljZTY0LTA1ODMtNGIxYS04YjE2LTQ0MjFjZDViMTQxNCIsICJpZCI6ICIyOWI4YzUyNS1lZmM1LTQ5NTItOGQ4Yy03NzQyYTg1YmI1MmEifQ==\n"
15051489
]
15061490
}
15071491
],
@@ -1573,32 +1557,32 @@
15731557
"name": "stderr",
15741558
"output_type": "stream",
15751559
"text": [
1576-
"100%|█████████| 6.73G/6.74G [00:01<00:00, 143MB/s]"
1560+
" 96%|█████████| 6.47G/6.74G [00:01<04:17, 1.04MB/s] "
15771561
]
15781562
},
15791563
{
15801564
"name": "stdout",
15811565
"output_type": "stream",
15821566
"text": [
1567+
"Fragment 16ed97cce049cd2859a379964a8fa7575d9b871ec126d33c824542b126eab177 already uploaded; skipping.\n",
1568+
"Fragment c665815f043b87cfe94d51caabd1b57d8f6f6773d632503de6db0725f20d391c already uploaded; skipping.\n",
1569+
"Fragment 1f66fe557ce079c063597f0b04d15862f67af2c9dd4f286801851e0c71f0e869 already uploaded; skipping.\n",
1570+
"Fragment 4a4efc3a84204c3d67887e8d7fa1186467b51e696451f2832ebbea3ca491c8a8 already uploaded; skipping.\n",
15831571
"Fragment 28e2ca7656d61b0bc7f8f8c1db41914023e0cab1634e0ee645f38a87d894b416 already uploaded; skipping.\n",
1572+
"Fragment 6ef3a2439a508de0919bd33a713976b5aa4895929a9d7981c09f722ce702e16a already uploaded; skipping.\n",
15841573
"Fragment 80c9fa41ccc69be1d2cd4a367d56168321d1079e7260a1996089810db25172f6 already uploaded; skipping.\n",
1585-
"Fragment 4a4efc3a84204c3d67887e8d7fa1186467b51e696451f2832ebbea3ca491c8a8 already uploaded; skipping.\n",
1586-
"Fragment 1f66fe557ce079c063597f0b04d15862f67af2c9dd4f286801851e0c71f0e869 already uploaded; skipping.\n",
1587-
"Fragment e6b139801bf4541f1e4989a8aa8b26ab37eca81bb5eaffa8028b744782455db0 already uploaded; skipping.\n",
15881574
"Fragment ca9c41a8dd56097e40865d2e65c65d299c22fc17608ddb6c604c532a69936307 already uploaded; skipping.\n",
1589-
"Fragment 704bb2fafcc9a6411047f799030dde3b4c2fb14de2e8d1eccfe651dcc6a455bf already uploaded; skipping.\n",
1590-
"Fragment bb9e1e7a62187671e58009533d2e930265f6c0827925216d354b984e2d506996 already uploaded; skipping.\n",
1591-
"Fragment f750893861a1a268c8ffe0ba7db36c933223bbf5fcbb786ecef3f052b20f9b8a already uploaded; skipping.\n",
1592-
"Fragment c7d98ad4258130d8cdea7ec6c9fbb33a868e64d4a14a57955f759ba3d35180c4 already uploaded; skipping.\n",
1575+
"Fragment 04a52d9a52901d8f7f74fd9ef6fc9fc215d6c9d787540511f68630f5cca16094 already uploaded; skipping.\n",
15931576
"Fragment f1f660d1287225c30b8b2cbf2a727283d807a1ee443153519cbf407a08937965 already uploaded; skipping.\n",
1594-
"Fragment fae1c9c2308c944488a9bc4703518395f3056cbeb55fd11f0f114282eb8cdf32 already uploaded; skipping.\n"
1577+
"Fragment f750893861a1a268c8ffe0ba7db36c933223bbf5fcbb786ecef3f052b20f9b8a already uploaded; skipping.\n",
1578+
"Fragment e6b139801bf4541f1e4989a8aa8b26ab37eca81bb5eaffa8028b744782455db0 already uploaded; skipping.\n"
15951579
]
15961580
},
15971581
{
15981582
"name": "stderr",
15991583
"output_type": "stream",
16001584
"text": [
1601-
"100%|██████████| 6.74G/6.74G [00:03<00:00, 2.11GB/s]\n"
1585+
"100%|██████████| 6.74G/6.74G [00:03<00:00, 1.77GB/s]\n"
16021586
]
16031587
},
16041588
{

data_prep.py

Lines changed: 41 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
import xarray as xr
3939
import salem
4040

41+
import dask
42+
import dask.diagnostics
4143
import geopandas as gpd
4244
import pygmt as gmt
4345
import IPython.display
@@ -610,69 +612,50 @@ def selective_tile(
610612
[0., 0.]]]], dtype=float32)
611613
>>> os.remove("/tmp/tmp_st.nc")
612614
"""
613-
array_list = []
614615

615-
with rasterio.open(filepath) as dataset:
616-
print(f"Tiling: {filepath} ... ", end="")
617-
for window_bound in window_bounds:
618-
619-
if padding > 0:
620-
window_bound = (
621-
window_bound[0] - padding, # minx
622-
window_bound[1] - padding, # miny
623-
window_bound[2] + padding, # maxx
624-
window_bound[3] + padding, # maxy
625-
)
616+
# Convert list of bounding box tuples to nice rasterio.coords.BoundingBox class
617+
window_bounds = [
618+
rasterio.coords.BoundingBox(
619+
left=x0 - padding, bottom=y0 - padding, right=x1 + padding, top=y1 + padding
620+
)
621+
for x0, y0, x1, y1 in window_bounds # xmin, ymin, xmax, ymax
622+
]
626623

627-
window = rasterio.windows.from_bounds(
628-
*window_bound, transform=dataset.transform, precision=None
629-
).round_offsets()
624+
with xr.open_rasterio(
625+
filepath, chunks=None if out_shape is None else {}, cache=False
626+
) as dataset:
627+
print(f"Tiling: {filepath} ... ", end="")
630628

631-
# Read the raster according to the crop window
632-
array = dataset.read(
633-
indexes=list(range(1, dataset.count + 1)),
634-
masked=True,
635-
window=window,
636-
out_shape=out_shape,
637-
)
638-
assert array.ndim == 3 # check that we have shape like (1, height, width)
639-
assert array.shape[0] == 1 # channel-first (assuming only 1 channel)
640-
assert not 0 in array.shape # ensure no empty dimensions (invalid window)
629+
# Subset dataset according to window bound (wb)
630+
daarray_list = [
631+
dataset.sel(y=slice(wb.top, wb.bottom), x=slice(wb.left, wb.right))
632+
for wb in window_bounds
633+
]
634+
# Bilinear interpolate to new shape if out_shape is set
635+
if out_shape is not None:
636+
daarray_list = [
637+
dataset.interp(
638+
y=np.linspace(da.y[0], da.y[-1], num=out_shape[0]),
639+
x=np.linspace(da.x[0], da.x[-1], num=out_shape[1]),
640+
method="linear",
641+
)
642+
for da in daarray_list
643+
]
644+
daarray_stack = dask.array.stack(seq=daarray_list)
641645

642-
try:
643-
assert not array.mask.any() # check that there are no NAN values
644-
except AssertionError:
645-
# Replace pixels from another raster if available, else raise error
646-
if gapfill_raster_filepath is not None:
647-
with rasterio.open(gapfill_raster_filepath) as dataset2:
648-
window2 = rasterio.windows.from_bounds(
649-
*window_bound, transform=dataset2.transform, precision=None
650-
).round_offsets()
651-
652-
array2 = dataset2.read(
653-
indexes=list(range(1, dataset2.count + 1)),
654-
masked=True,
655-
window=window2,
656-
out_shape=array.shape[1:],
657-
)
658-
659-
np.copyto(
660-
dst=array, src=array2, where=array.mask
661-
) # fill in gaps where mask is True
662-
663-
# assert not array.mask.any() # ensure no NAN values after gapfill
664-
else:
665-
plt.imshow(array.data[0, :, :])
666-
plt.show()
667-
print(
668-
f"WARN: Tile has missing data, try passing in gapfill_raster_filepath"
669-
)
670-
671-
# assert array.shape[1] == array.shape[2] # check that height==width
672-
array_list.append(array.data.astype(dtype=np.float32))
646+
assert daarray_stack.ndim == 4 # check that shape is like (m, 1, height, width)
647+
assert daarray_stack.shape[1] == 1 # channel-first (assuming only 1 channel)
648+
assert not 0 in daarray_stack.shape # ensure no empty dimensions (bad window)
673649
print("done!")
674650

675-
return np.stack(arrays=array_list)
651+
with dask.diagnostics.ProgressBar(minimum=5.0):
652+
try:
653+
out_tiles = daarray_stack.compute().astype(dtype=np.float32)
654+
assert not np.isnan(out_tiles).any() # check that there are no NAN values
655+
except AssertionError:
656+
raise NotImplementedError("gapfilling on dask xarray not yet implemented")
657+
finally:
658+
return out_tiles
676659

677660

678661
# %%
@@ -712,7 +695,7 @@ def selective_tile(
712695
filepath="misc/REMA_100m_dem.tif",
713696
window_bounds=window_bounds_concat,
714697
padding=1000,
715-
gapfill_raster_filepath="misc/REMA_200m_dem_filled.tif",
698+
# gapfill_raster_filepath="misc/REMA_200m_dem_filled.tif",
716699
)
717700
print(rema.shape, rema.dtype)
718701

0 commit comments

Comments
 (0)