Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 74 additions & 53 deletions nlmod/dims/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import rasterio
import xarray as xr
from scipy.interpolate import griddata
from scipy.ndimage import distance_transform_edt
from scipy.spatial import cKDTree

from ..util import get_da_from_da_ds
Expand Down Expand Up @@ -102,8 +103,7 @@ def ds_to_structured_grid(
angrot=0.0,
method="nearest",
):
"""Resample a dataset (xarray) from a structured grid to a new dataset from a
different structured grid.
"""Resample a dataset from a structured grid to a different structured grid.

Parameters
----------
Expand Down Expand Up @@ -233,9 +233,9 @@ def _set_angrot_attributes(extent, xorigin, yorigin, angrot, attrs):
def fillnan_da_structured_grid(xar_in, method="nearest"):
    """Fill not-a-number values in a structured grid, DataArray.

    For method "nearest" on a 2D DataArray whose coordinate spacing is constant
    along both dimensions, the nearest cell with a value is looked up with
    scipy.ndimage.distance_transform_edt. In all other cases the fill values are
    determined using the scipy.interpolate.griddata function.

    Parameters
    ----------
    xar_in : xarray DataArray
        DataArray with nan values. When griddata is used, the DataArray must have
        exactly the dimensions ('y', 'x').
    method : str, optional
        Interpolation method. It is passed on to scipy.interpolate.griddata when
        griddata is used. The default is "nearest".

    Returns
    -------
    xar_out : xarray DataArray
        Copy of xar_in in which the nan values are replaced by fill values.

    Raises
    ------
    ValueError
        If griddata has to be used and the dimensions of xar_in are not
        ('y', 'x').

    Notes
    -----
    The griddata route can be slow if xar_in is a large raster.
    """
    xar_out = xar_in.copy()

    # The fast route indexes two dimensions, so it is only tried for 2D input.
    # Other input falls through to the dimension check below.
    if method == "nearest" and xar_in.ndim == 2:
        # absolute coordinate spacing along both dimensions
        ddim0 = np.abs(np.diff(xar_in.coords[xar_in.dims[0]].values))
        ddim1 = np.abs(np.diff(xar_in.coords[xar_in.dims[1]].values))

        # A dimension with fewer than two coordinates has no spacing to compare
        # (ddim[0] would raise an IndexError); such grids are left to griddata.
        uniform_spacing = (
            ddim0.size > 0
            and ddim1.size > 0
            and np.allclose(ddim0, ddim0[0])
            and np.allclose(ddim1, ddim1[0])
        )

        if uniform_spacing:
            # sampling is only needed when the cells are not square
            sampling = None if np.isclose(ddim0[0], ddim1[0]) else (ddim0[0], ddim1[0])

            # for every cell: the index of the nearest cell that is not nan
            idx = distance_transform_edt(
                input=xar_in.isnull().values,
                sampling=sampling,
                return_distances=False,
                return_indices=True,
            )
            xar_out.values = xar_in.values[tuple(idx)]

            return xar_out

    # check dimensions
    if xar_in.dims != ("y", "x"):
        raise ValueError(
            "expected dataarray with dimensions ('y' and 'x'), got dimensions -> "
            f"{xar_in.dims}"
        )

    # get list of coordinates from all points in raster
    xg, yg = np.meshgrid(xar_in.x.values, xar_in.y.values)
    points_all = np.column_stack((xg.ravel(), yg.ravel()))

    # get all values in DataArray (flatten returns a copy, xar_in is not changed)
    values_all = xar_in.values.flatten()

    # split the points in points with a value and points that need a value
    mask = np.isnan(values_all)
    points_in = points_all[~mask]
    values_in = values_all[~mask]
    points_out = points_all[mask]

    # get value for nan values
    values_all[mask] = griddata(points_in, values_in, points_out, method=method)
    xar_out.values = values_all.reshape(xar_in.shape)
    return xar_out


Expand Down Expand Up @@ -319,35 +328,45 @@ def fillnan_da_vertex_grid(xar_in, ds=None, x=None, y=None, method="nearest"):

Notes
-----
can be slow if the xar_in is a large raster
If x is provided, x will be used over ds and the x coordinate part of xar_in.
If x is not provided, ds will be used to get the x coordinates.
If x is not provided and "x" is not in xar_in.coords, an error will be raised.
"""
if xar_in.dims != ("icell2d",):
raise ValueError(
f"expected dataarray with dimensions ('icell2d'), got dimensions -> {xar_in.dims}"
"expected dataarray with dimensions ('icell2d'), got dimensions -> "
f"{xar_in.dims}"
)

# get list of coordinates from all points in raster
if x is None:
x = ds["x"].data
if y is None:
y = ds["y"].data

xyi = np.column_stack((x, y))
xar_out = xar_in.copy()

# fill nan values in DataArray
values_all = xar_in.data

# get 1d arrays with only values where DataArray is not nan
mask1 = ~np.isnan(values_all)
xyi_in = xyi[mask1]
values_in = values_all[mask1]
if x is not None:
pass
elif x is None and ds is not None:
x = ds["x"].values
elif x is None and "x" in xar_in.coords:
x = xar_in.coords["x"].values
else:
raise ValueError("x or ds must be provided to get x coordinates")
if y is not None:
pass
elif y is None and ds is not None:
y = ds["y"].values
elif y is None and "y" in xar_in.coords:
y = xar_in.coords["y"].values
else:
raise ValueError("y or ds must be provided to get y coordinates")

# get value for all nan values
values_out = griddata(xyi_in, values_in, xyi, method=method)
points_all = np.column_stack((x, y))
values_all = xar_out.values

# create DataArray without nan values
xar_out = xr.DataArray(values_out, dims=("icell2d"))
mask = np.isnan(values_all)
points_in = points_all[~mask]
points_out = points_all[mask]
values_in = values_all[~mask]

# get value for all nan value
xar_out.values[mask] = griddata(points_in, values_in, points_out, method=method)
return xar_out


Expand Down Expand Up @@ -408,13 +427,14 @@ def vertex_da_to_ds(da, ds, method="nearest"):

if "gridtype" in ds.attrs and ds.gridtype == "vertex":
if len(da.dims) == 1:
xi = list(zip(ds.x.values, ds.y.values))
xi = list(zip(ds.x.values, ds.y.values, strict=True))
z = griddata(points, da.values, xi, method=method)
coords = {"icell2d": ds.icell2d}
return xr.DataArray(z, dims="icell2d", coords=coords)
else:
raise NotImplementedError(
"Resampling from multidmensional vertex da to vertex ds not yet supported"
"Resampling from multidimensional vertex da to vertex ds not yet "
"supported"
)

xg, yg = np.meshgrid(ds.x, ds.y)
Expand Down Expand Up @@ -574,7 +594,7 @@ def structured_da_to_ds(da, ds, method="average", nodata=np.nan):
if "grid_mapping" in da_out.encoding:
del da_out.encoding["grid_mapping"]

# remove the long_name, standard_name and units attributes of the x and y coordinates
# remove the long_name, standard_name and units attributes of x and y coordinates
for coord in ["x", "y"]:
if coord not in da_out.coords:
continue
Expand All @@ -586,6 +606,7 @@ def structured_da_to_ds(da, ds, method="average", nodata=np.nan):


def extent_to_polygon(extent):
"""Convert an extent to a shapely Polygon."""
logger.warning(
"nlmod.resample.extent_to_polygon is deprecated. "
"Use nlmod.util.extent_to_polygon instead."
Expand Down
134 changes: 124 additions & 10 deletions tests/test_026_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,19 +119,133 @@ def test_vertex_da_to_ds():


def test_fillnan_da():
    """Test fillnan_da on a structured grid and on a vertex grid."""
    # structured grid: blank out a single cell and fill it again
    ds = get_structured_model_ds()
    original = ds["top"].copy()
    ds["top"][5, 5] = np.nan

    # With method="nearest" the fill value has to come from a directly adjacent
    # cell. Every direct neighbour inside the grid is accepted: several of them
    # can be equally far away, and which one is selected in that case is an
    # implementation detail of the fill routine.
    nrow, ncol = original.shape
    candidates = [
        float(original[row, col])
        for row, col in ((5, 4), (5, 6), (4, 5), (6, 5))
        if row < nrow and col < ncol
    ]

    top_nearest = nlmod.resample.fillnan_da(ds["top"], ds=ds, method="nearest")
    filled = float(top_nearest[5, 5])
    assert not np.isnan(filled)
    assert any(np.isclose(filled, value, rtol=1e-10) for value in candidates)

    # Test linear method
    top_linear = nlmod.resample.fillnan_da(ds["top"], ds=ds, method="linear")
    assert not np.isnan(top_linear[5, 5])

    # Test that original values are preserved where not NaN
    mask_valid = ~ds["top"].isnull()
    np.testing.assert_allclose(
        top_nearest.where(mask_valid), original.where(mask_valid), equal_nan=True
    )

    # vertex grid: coordinates are taken from the ds argument
    ds_vertex = get_vertex_model_ds()
    ds_vertex["top"][100] = np.nan

    top_vertex_ds = nlmod.resample.fillnan_da(ds_vertex["top"], ds=ds_vertex)
    assert not np.isnan(top_vertex_ds[100])

    # vertex grid: coordinates are attached to the DataArray itself, no ds is given
    x_coords = ds_vertex["x"].values
    y_coords = ds_vertex["y"].values
    ds_vertex_coords = ds_vertex.copy()
    ds_vertex_coords["top"] = ds_vertex_coords["top"].assign_coords(
        x=("icell2d", x_coords), y=("icell2d", y_coords)
    )
    top_vertex_coords = nlmod.resample.fillnan_da(ds_vertex_coords["top"])
    assert not np.isnan(top_vertex_coords[100])


def test_fillnan_da_vertex_grid_coordinates():
    """Check the ways fillnan_da_vertex_grid can be given cell coordinates."""
    ds = get_vertex_model_ds()
    ds["top"][100] = np.nan
    xc = ds["x"].values
    yc = ds["y"].values

    # coordinates read from the dataset
    filled_from_ds = nlmod.resample.fillnan_da_vertex_grid(ds["top"], ds=ds)
    assert not np.isnan(filled_from_ds[100])

    # coordinates passed explicitly: same fill value as with the dataset
    filled_from_xy = nlmod.resample.fillnan_da_vertex_grid(ds["top"], x=xc, y=yc)
    assert not np.isnan(filled_from_xy[100])
    assert np.isclose(filled_from_ds[100], filled_from_xy[100])

    # coordinates attached to the DataArray, neither ds nor x and y are passed
    da_with_xy = ds["top"].assign_coords(x=("icell2d", xc), y=("icell2d", yc))
    filled_from_coords = nlmod.resample.fillnan_da_vertex_grid(da_with_xy)
    assert not np.isnan(filled_from_coords[100])


def test_fillnan_da_uniform_vs_nonuniform():
    """Test filling a NaN on a grid before and after making its spacing non-uniform."""
    ds = get_structured_model_ds()

    # Create test data with known pattern. The dtype has to be float: an integer
    # array cannot hold the NaN that is written below.
    test_values = np.arange(ds["top"].size, dtype=float).reshape(ds["top"].shape)
    ds["top"].values = test_values

    # Add NaN in center
    center_y, center_x = ds["top"].shape[0] // 2, ds["top"].shape[1] // 2
    ds["top"][center_y, center_x] = np.nan

    # Fill on the grid as delivered by the fixture
    result_uniform = nlmod.resample.fillnan_da(ds["top"], ds=ds, method="nearest")

    # Create non-uniform grid by adjusting coordinates. Work on a copy of the
    # coordinate values so the coordinates of ds itself are not modified.
    x_coords = ds.x.values.copy()
    x_coords[5:] += 10  # Make spacing non-uniform
    ds_nonuniform = ds.assign_coords(x=x_coords)

    # Fill on the grid with non-uniform spacing
    result_nonuniform = nlmod.resample.fillnan_da(
        ds_nonuniform["top"], ds=ds_nonuniform, method="nearest"
    )

    # Both should fill the NaN but may give different results
    assert not np.isnan(result_uniform[center_y, center_x])
    assert not np.isnan(result_nonuniform[center_y, center_x])


def test_fillnan_da_error_handling():
    """Check the ValueErrors raised by fillnan_da_vertex_grid."""
    import pytest
    import xarray as xr

    ds = get_vertex_model_ds()

    # a DataArray whose only dimension is not called icell2d is rejected
    renamed_da = ds["top"].rename({"icell2d": "wrong_dim"})
    with pytest.raises(
        ValueError, match="expected dataarray with dimensions \\('icell2d'\\)"
    ):
        nlmod.resample.fillnan_da_vertex_grid(renamed_da, ds=ds)

    # a DataArray that carries no x and y coordinates of its own
    bare_da = xr.DataArray(
        ds["top"].values,
        dims=("icell2d",),
        coords={"icell2d": ds.icell2d},
    )

    # without ds, x and y there is no source for the x coordinates
    with pytest.raises(ValueError, match="x or ds must be provided"):
        nlmod.resample.fillnan_da_vertex_grid(bare_da)

    # with only x given there is still no source for the y coordinates
    with pytest.raises(ValueError, match="y or ds must be provided"):
        nlmod.resample.fillnan_da_vertex_grid(bare_da, x=ds["x"].values)


def test_interpolate_gdf_to_array():
Expand Down
Loading