Skip to content

Commit b5fe5e6

Browse files
committed
✅ Make selective_tile unit test more tacit and fix it
To make data_prep.selective_tile more understandable, I've 'simplified' the doctest with a diagonal array, and used more geographically correct (corner-based) pixel bounds. Also updated test_data_prep.py to use the v0.7.0 release grid instead of v0.4.0. The issue with rasterio.open not having a proper affine transformation on NetCDF files was solved via the most ludicrous method of all - importing xarray before rasterio... Will submit a bug report after this, but in the meantime, the code can stay mostly intact, phew! A possible next step is to remove salem and refactor the selective_tile function a bit more, by using xarray.open_rasterio and the xarray.DataArray.sel method with x/y slices.
1 parent 9d8ebdf commit b5fe5e6

4 files changed

Lines changed: 44 additions & 59 deletions

File tree

data_prep.ipynb

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@
5050
"import yaml\n",
5151
"import zipfile\n",
5252
"\n",
53+
"# need to import before rasterio\n",
54+
"import xarray as xr\n",
55+
"import salem\n",
56+
"\n",
5357
"import geopandas as gpd\n",
5458
"import pygmt as gmt\n",
5559
"import IPython.display\n",
@@ -64,9 +68,6 @@
6468
"import shapely.geometry\n",
6569
"import skimage.util.shape\n",
6670
"import tqdm\n",
67-
"import xarray as xr\n",
68-
"\n",
69-
"import salem\n",
7071
"\n",
7172
"print(\"Python :\", sys.version.split(\"\\n\")[0])\n",
7273
"print(\"Geopandas :\", gpd.__version__)\n",
@@ -1171,7 +1172,7 @@
11711172
"def selective_tile(\n",
11721173
" filepath: str,\n",
11731174
" window_bounds: list,\n",
1174-
" padding: int = 0,\n",
1175+
" padding: int = 0, # in projected coordinate system units\n",
11751176
" out_shape: tuple = None,\n",
11761177
" gapfill_raster_filepath: str = None,\n",
11771178
") -> np.ndarray:\n",
@@ -1182,21 +1183,21 @@
11821183
" some desired shape/resolution.\n",
11831184
"\n",
11841185
" >>> xr.DataArray(\n",
1185-
" ... data=np.random.RandomState(seed=42).rand(64).reshape(8, 8),\n",
1186-
" ... coords={\"x\": np.arange(8), \"y\": np.arange(8)},\n",
1187-
" ... dims=[\"x\", \"y\"],\n",
1186+
" ... data=np.flipud(m=np.diag(v=np.arange(8))).astype(dtype=np.float32),\n",
1187+
" ... coords={\"y\": np.linspace(7, 0, 8), \"x\": np.linspace(0, 7, 8)},\n",
1188+
" ... dims=[\"y\", \"x\"],\n",
11881189
" ... ).to_netcdf(path=\"/tmp/tmp_st.nc\", mode=\"w\")\n",
11891190
" >>> selective_tile(\n",
11901191
" ... filepath=\"/tmp/tmp_st.nc\",\n",
1191-
" ... window_bounds=[(1.0, 4.0, 3.0, 6.0), (2.0, 5.0, 4.0, 7.0)],\n",
1192+
" ... window_bounds=[(0.5, 0.5, 2.5, 2.5), (2.5, 1.5, 4.5, 3.5)],\n",
11921193
" ... )\n",
11931194
" Tiling: /tmp/tmp_st.nc ... done!\n",
1194-
" array([[[[0.18485446, 0.96958464],\n",
1195-
" [0.4951769 , 0.03438852]]],\n",
1195+
" array([[[[0., 2.],\n",
1196+
" [1., 0.]]],\n",
11961197
" <BLANKLINE>\n",
11971198
" <BLANKLINE>\n",
1198-
" [[[0.04522729, 0.32533032],\n",
1199-
" [0.96958464, 0.77513283]]]], dtype=float32)\n",
1199+
" [[[3., 0.],\n",
1200+
" [0., 0.]]]], dtype=float32)\n",
12001201
" >>> os.remove(\"/tmp/tmp_st.nc\")\n",
12011202
" \"\"\"\n",
12021203
" array_list = []\n",
@@ -1226,6 +1227,7 @@
12261227
" )\n",
12271228
" assert array.ndim == 3 # check that we have shape like (1, height, width)\n",
12281229
" assert array.shape[0] == 1 # channel-first (assuming only 1 channel)\n",
1230+
" assert not 0 in array.shape # ensure no empty dimensions (invalid window)\n",
12291231
"\n",
12301232
" try:\n",
12311233
" assert not array.mask.any() # check that there are no NAN values\n",
@@ -1256,7 +1258,7 @@
12561258
" f\"WARN: Tile has missing data, try passing in gapfill_raster_filepath\"\n",
12571259
" )\n",
12581260
"\n",
1259-
" # assert array.shape[0] == array.shape[1] # check that height==width\n",
1261+
" # assert array.shape[1] == array.shape[2] # check that height==width\n",
12601262
" array_list.append(array.data.astype(dtype=np.float32))\n",
12611263
" print(\"done!\")\n",
12621264
"\n",

data_prep.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@
3434
import yaml
3535
import zipfile
3636

37+
# need to import before rasterio
38+
import xarray as xr
39+
import salem
40+
3741
import geopandas as gpd
3842
import pygmt as gmt
3943
import IPython.display
@@ -48,9 +52,6 @@
4852
import shapely.geometry
4953
import skimage.util.shape
5054
import tqdm
51-
import xarray as xr
52-
53-
import salem
5455

5556
print("Python :", sys.version.split("\n")[0])
5657
print("Geopandas :", gpd.__version__)
@@ -581,7 +582,7 @@ def get_window_bounds(
581582
def selective_tile(
582583
filepath: str,
583584
window_bounds: list,
584-
padding: int = 0,
585+
padding: int = 0, # in projected coordinate system units
585586
out_shape: tuple = None,
586587
gapfill_raster_filepath: str = None,
587588
) -> np.ndarray:
@@ -592,21 +593,21 @@ def selective_tile(
592593
some desired shape/resolution.
593594
594595
>>> xr.DataArray(
595-
... data=np.random.RandomState(seed=42).rand(64).reshape(8, 8),
596-
... coords={"x": np.arange(8), "y": np.arange(8)},
597-
... dims=["x", "y"],
596+
... data=np.flipud(m=np.diag(v=np.arange(8))).astype(dtype=np.float32),
597+
... coords={"y": np.linspace(7, 0, 8), "x": np.linspace(0, 7, 8)},
598+
... dims=["y", "x"],
598599
... ).to_netcdf(path="/tmp/tmp_st.nc", mode="w")
599600
>>> selective_tile(
600601
... filepath="/tmp/tmp_st.nc",
601-
... window_bounds=[(1.0, 4.0, 3.0, 6.0), (2.0, 5.0, 4.0, 7.0)],
602+
... window_bounds=[(0.5, 0.5, 2.5, 2.5), (2.5, 1.5, 4.5, 3.5)],
602603
... )
603604
Tiling: /tmp/tmp_st.nc ... done!
604-
array([[[[0.18485446, 0.96958464],
605-
[0.4951769 , 0.03438852]]],
605+
array([[[[0., 2.],
606+
[1., 0.]]],
606607
<BLANKLINE>
607608
<BLANKLINE>
608-
[[[0.04522729, 0.32533032],
609-
[0.96958464, 0.77513283]]]], dtype=float32)
609+
[[[3., 0.],
610+
[0., 0.]]]], dtype=float32)
610611
>>> os.remove("/tmp/tmp_st.nc")
611612
"""
612613
array_list = []
@@ -636,6 +637,7 @@ def selective_tile(
636637
)
637638
assert array.ndim == 3 # check that we have shape like (1, height, width)
638639
assert array.shape[0] == 1 # channel-first (assuming only 1 channel)
640+
assert not 0 in array.shape # ensure no empty dimensions (invalid window)
639641

640642
try:
641643
assert not array.mask.any() # check that there are no NAN values
@@ -666,7 +668,7 @@ def selective_tile(
666668
f"WARN: Tile has missing data, try passing in gapfill_raster_filepath"
667669
)
668670

669-
# assert array.shape[0] == array.shape[1] # check that height==width
671+
# assert array.shape[1] == array.shape[2] # check that height==width
670672
array_list.append(array.data.astype(dtype=np.float32))
671673
print("done!")
672674

features/steps/test_data_prep.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def get_a_raster_grid(context, dataset_type, raster_grid):
6262
context.raster_grid = raster_grid
6363
context.filepath = os.path.join(dataset_type, raster_grid)
6464
url = (
65-
f"https://github.com/weiji14/deepbedmap/releases/download/v0.4.0/{raster_grid}"
65+
f"https://github.com/weiji14/deepbedmap/releases/download/v0.7.0/{raster_grid}"
6666
)
6767
context.data_prep.download_to_path(path=context.filepath, url=url)
6868

test_ipynb.ipynb

Lines changed: 13 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -170,43 +170,26 @@
170170
"ok\n",
171171
"Trying:\n",
172172
" xr.DataArray(\n",
173-
" data=np.random.RandomState(seed=42).rand(64).reshape(8, 8),\n",
174-
" coords={\"x\": np.arange(8), \"y\": np.arange(8)},\n",
175-
" dims=[\"x\", \"y\"],\n",
173+
" data=np.flipud(m=np.diag(v=np.arange(8))).astype(dtype=np.float32),\n",
174+
" coords={\"y\": np.linspace(7, 0, 8), \"x\": np.linspace(0, 7, 8)},\n",
175+
" dims=[\"y\", \"x\"],\n",
176176
" ).to_netcdf(path=\"/tmp/tmp_st.nc\", mode=\"w\")\n",
177177
"Expecting nothing\n",
178178
"ok\n",
179179
"Trying:\n",
180180
" selective_tile(\n",
181181
" filepath=\"/tmp/tmp_st.nc\",\n",
182-
" window_bounds=[(1.0, 4.0, 3.0, 6.0), (2.0, 5.0, 4.0, 7.0)],\n",
182+
" window_bounds=[(0.5, 0.5, 2.5, 2.5), (2.5, 1.5, 4.5, 3.5)],\n",
183183
" )\n",
184184
"Expecting:\n",
185185
" Tiling: /tmp/tmp_st.nc ... done!\n",
186-
" array([[[[0.18485446, 0.96958464],\n",
187-
" [0.4951769 , 0.03438852]]],\n",
188-
" <BLANKLINE>\n",
189-
" <BLANKLINE>\n",
190-
" [[[0.04522729, 0.32533032],\n",
191-
" [0.96958464, 0.77513283]]]], dtype=float32)\n",
192-
"**********************************************************************\n",
193-
"File \"data_prep\", line 627, in data_prep.selective_tile\n",
194-
"Failed example:\n",
195-
" selective_tile(\n",
196-
" filepath=\"/tmp/tmp_st.nc\",\n",
197-
" window_bounds=[(1.0, 4.0, 3.0, 6.0), (2.0, 5.0, 4.0, 7.0)],\n",
198-
" )\n",
199-
"Expected:\n",
200-
" Tiling: /tmp/tmp_st.nc ... done!\n",
201-
" array([[[[0.18485446, 0.96958464],\n",
202-
" [0.4951769 , 0.03438852]]],\n",
186+
" array([[[[0., 2.],\n",
187+
" [1., 0.]]],\n",
203188
" <BLANKLINE>\n",
204189
" <BLANKLINE>\n",
205-
" [[[0.04522729, 0.32533032],\n",
206-
" [0.96958464, 0.77513283]]]], dtype=float32)\n",
207-
"Got:\n",
208-
" Tiling: /tmp/tmp_st.nc ... done!\n",
209-
" array([], shape=(2, 1, 0, 2), dtype=float32)\n",
190+
" [[[3., 0.],\n",
191+
" [0., 0.]]]], dtype=float32)\n",
192+
"ok\n",
210193
"Trying:\n",
211194
" os.remove(\"/tmp/tmp_st.nc\")\n",
212195
"Expecting nothing\n",
@@ -241,19 +224,17 @@
241224
"2 items had no tests:\n",
242225
" data_prep\n",
243226
" data_prep.parse_datalist\n",
244-
"6 items passed all tests:\n",
227+
"7 items passed all tests:\n",
245228
" 6 tests in data_prep.ascii_to_xyz\n",
246229
" 3 tests in data_prep.check_sha256\n",
247230
" 3 tests in data_prep.download_to_path\n",
248231
" 2 tests in data_prep.get_region\n",
249232
" 3 tests in data_prep.get_window_bounds\n",
233+
" 3 tests in data_prep.selective_tile\n",
250234
" 5 tests in data_prep.xyz_to_grid\n",
251-
"**********************************************************************\n",
252-
"1 items had failures:\n",
253-
" 1 of 3 in data_prep.selective_tile\n",
254235
"25 tests in 9 items.\n",
255-
"24 passed and 1 failed.\n",
256-
"***Test Failed*** 1 failures.\n"
236+
"25 passed and 0 failed.\n",
237+
"Test passed.\n"
257238
]
258239
}
259240
],

0 commit comments

Comments
 (0)