Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 122 additions & 62 deletions nlmod/dims/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@
import numpy as np
import rasterio
import xarray as xr
from flopy.discretization.vertexgrid import VertexGrid
from scipy.interpolate import griddata
from scipy.ndimage import binary_dilation, distance_transform_edt
from scipy.spatial import cKDTree

import nlmod

from ..util import get_da_from_da_ds
from .shared import get_area
from .shared import get_area, is_structured, is_vertex

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -102,8 +106,7 @@ def ds_to_structured_grid(
angrot=0.0,
method="nearest",
):
"""Resample a dataset (xarray) from a structured grid to a new dataset from a
different structured grid.
"""Resample a dataset from a structured grid to a different structured grid.

Parameters
----------
Expand Down Expand Up @@ -233,9 +236,11 @@ def _set_angrot_attributes(extent, xorigin, yorigin, angrot, attrs):
def fillnan_da_structured_grid(xar_in, method="nearest"):
"""Fill not-a-number values in a structured grid, DataArray.

The fill values are determined using the 'nearest' method of the
scipy.interpolate.griddata function
The nans are replaced with the interpolated values using distance_transform_edt when
all cells have the same shape and method is nearest, otherwise the values are
interpolated using griddata.

Note that extrapolation is only applied if method is "nearest".

Parameters
----------
Expand All @@ -251,50 +256,69 @@ def fillnan_da_structured_grid(xar_in, method="nearest"):
xar_out : xarray DataArray
DataArray without nan values. DataArray has 2 dimensions
(y and x)

Notes
-----
can be slow if the xar_in is a large raster
"""
# check dimensions
# if "x" not in xar_in.dims or "y" not in xar_in.dims:
if xar_in.dims != ("y", "x"):
raise ValueError(
f"expected dataarray with dimensions ('y' and 'x'), got dimensions -> {xar_in.dims}"
"expected dataarray with dimensions ('y' and 'x'), got dimensions -> "
f"{xar_in.dims}"
)

# get list of coordinates from all points in raster
mg = np.meshgrid(xar_in.x.data, xar_in.y.data)
points_all = np.vstack((mg[0].ravel(), mg[1].ravel())).T
# Create a deep copy to avoid modifying the original data
# Using manual DataArray construction because copy(deep=True) still shares memory
# in some xarray contexts when data is loaded from netCDF files
xar_out = xr.DataArray(
xar_in.values.copy(),
coords=xar_in.coords,
dims=xar_in.dims,
attrs=xar_in.attrs,
)

if method == "nearest":
y = xar_in.coords["y"].values
x = xar_in.coords["x"].values
dy = np.abs(y[1:] - y[:-1])
dx = np.abs(x[1:] - x[:-1])

# get all values in DataArray
values_all = xar_in.data.flatten()
if np.allclose(dy, dy[0]) and np.allclose(dx, dx[0]):
sampling = None if np.isclose(dy[0], dx[0]) else (dy[0], dx[0])

# get 1d arrays with only values where DataArray is not nan
mask1 = ~np.isnan(values_all)
points_in = points_all[np.where(mask1)[0]]
values_in = values_all[np.where(mask1)[0]]
idx = distance_transform_edt(
input=xar_in.isnull(),
sampling=sampling,
return_distances=False,
return_indices=True,
)
xar_out.values = xar_in.values[tuple(idx)]

# get value for all nan values
values_out = griddata(points_in, values_in, points_all, method=method)
arr_out = values_out.reshape(xar_in.shape)
return xar_out

# create DataArray without nan values
xar_out = xr.DataArray(
arr_out,
dims=("y", "x"),
coords={"x": xar_in.x.data, "y": xar_in.y.data},
)
# xar_out = xar_in.rio.interpolate_na(method=method)
xg, yg = np.meshgrid(xar_in.x.values, xar_in.y.values)

is_invalid = np.isnan(xar_in)
points_out = np.column_stack((xg[is_invalid], yg[is_invalid]))

# Get coordinates and values of the valid cells that are passed to griddata
if method in ("nearest", "linear"):
# We can cheaply isolate the neighboring cells and only pass that
# to griddata. Spline uses thicker outline of nan areas.
is_valid = binary_dilation(is_invalid) & ~is_invalid
else:
is_valid = ~is_invalid

points_in = np.column_stack((xg[is_valid], yg[is_valid]))
values_in = xar_in.values[is_valid]

xar_out.values[is_invalid] = griddata(
points_in, values_in, points_out, method=method
)
return xar_out


def fillnan_da_vertex_grid(xar_in, ds=None, x=None, y=None, method="nearest"):
"""Fill not-a-number values in a vertex grid, DataArray.

The fill values are determined using the 'nearest' method of the
scipy.interpolate.griddata function
Note that extrapolation is only applied if method is "nearest".

Parameters
----------
Expand All @@ -319,43 +343,76 @@ def fillnan_da_vertex_grid(xar_in, ds=None, x=None, y=None, method="nearest"):

Notes
-----
can be slow if the xar_in is a large raster
If x is provided, it takes precedence over the x coordinates of ds and of xar_in.
If x is not provided, the x coordinates are taken from ds, or from xar_in.coords
when ds is None.
If x is not provided, ds is None and "x" is not in xar_in.coords, an error is raised.
The same rules apply to y.
"""
if xar_in.dims != ("icell2d",):
if not is_vertex(xar_in):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this will allow layered vertex arrays to pass which is not what you want (and avoided for the structured grid). I did not think of this when recommending the functions in the shared module.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @OnnoEbbens I don't follow you here. The code raises an error for everything that is not a vertex grid, which is desired here in this vertex function, right?

raise ValueError(
f"expected dataarray with dimensions ('icell2d'), got dimensions -> {xar_in.dims}"
"expected dataarray with dimensions ('icell2d'), got dimensions -> "
f"{xar_in.dims}"
)

# get list of coordinates from all points in raster
if x is None:
x = ds["x"].data
if y is None:
y = ds["y"].data

xyi = np.column_stack((x, y))

# fill nan values in DataArray
values_all = xar_in.data

# get 1d arrays with only values where DataArray is not nan
mask1 = ~np.isnan(values_all)
xyi_in = xyi[mask1]
values_in = values_all[mask1]
# Create a deep copy to avoid modifying the original data
# Using manual DataArray construction because copy(deep=True) still shares memory
# in some xarray contexts when data is loaded from netCDF files
xar_out = xr.DataArray(
xar_in.values.copy(),
coords=xar_in.coords,
dims=xar_in.dims,
attrs=xar_in.attrs,
)

# get value for all nan values
values_out = griddata(xyi_in, values_in, xyi, method=method)
if x is not None:
pass
elif x is None and ds is not None:
x = ds["x"].values
elif x is None and "x" in xar_in.coords:
x = xar_in.coords["x"].values
else:
raise ValueError("x or ds must be provided to get x coordinates")
if y is not None:
pass
elif y is None and ds is not None:
y = ds["y"].values
elif y is None and "y" in xar_in.coords:
y = xar_in.coords["y"].values
else:
raise ValueError("y or ds must be provided to get y coordinates")

is_invalid = np.isnan(xar_out)
points_out = np.column_stack((x[is_invalid], y[is_invalid]))

if method in ("nearest", "linear") and ds is not None:
# We can cheaply isolate the neighboring cells and only pass those
# to griddata. Similar to for structured grids:
# is_valid = binary_dilation(is_invalid) & ~is_invalid
vertices = nlmod.grid.get_vertices_from_ds(ds)
cell2d = nlmod.grid.get_cell2d_from_ds(ds)
mg = VertexGrid(vertices=vertices, cell2d=cell2d)
_cell_connections = mg.neighbors()
cell_connections = {k: _cell_connections[k] for k in np.where(is_invalid)[0]}
unicons = set().union(*cell_connections.values()) - set(cell_connections.keys())

is_valid = np.zeros_like(xar_in, dtype=bool)
is_valid[list(unicons)] = True
else:
is_valid = ~is_invalid

# create DataArray without nan values
xar_out = xr.DataArray(values_out, dims=("icell2d"))
points_in = np.column_stack((x[is_valid], y[is_valid]))
values_in = xar_out.values[is_valid]

# get values for all nan values
xar_out.values[is_invalid] = griddata(
points_in, values_in, points_out, method=method
)
return xar_out


def fillnan_da(da, ds=None, method="nearest"):
"""Fill not-a-number values in a DataArray.

The fill values are determined using the 'nearest' method of the
scipy.interpolate.griddata function
Note that extrapolation is only applied if method is "nearest".

Parameters
----------
Expand All @@ -375,11 +432,12 @@ def fillnan_da(da, ds=None, method="nearest"):
-----
can be slow if the xar_in is a large raster
"""
if len(da.shape) > 1 and len(da.y) == da.shape[-2] and len(da.x) == da.shape[-1]:
# the dataarray is structured
if is_structured(da):
return fillnan_da_structured_grid(da, method=method)
else:
elif is_vertex(da):
return fillnan_da_vertex_grid(da, ds, method=method)
else:
raise NotImplementedError("Unsupported grid type")


def vertex_da_to_ds(da, ds, method="nearest"):
Expand Down Expand Up @@ -408,13 +466,14 @@ def vertex_da_to_ds(da, ds, method="nearest"):

if "gridtype" in ds.attrs and ds.gridtype == "vertex":
if len(da.dims) == 1:
xi = list(zip(ds.x.values, ds.y.values))
xi = list(zip(ds.x.values, ds.y.values, strict=True))
z = griddata(points, da.values, xi, method=method)
coords = {"icell2d": ds.icell2d}
return xr.DataArray(z, dims="icell2d", coords=coords)
else:
raise NotImplementedError(
"Resampling from multidmensional vertex da to vertex ds not yet supported"
"Resampling from multidimensional vertex da to vertex ds not yet "
"supported"
)

xg, yg = np.meshgrid(ds.x, ds.y)
Expand Down Expand Up @@ -574,7 +633,7 @@ def structured_da_to_ds(da, ds, method="average", nodata=np.nan):
if "grid_mapping" in da_out.encoding:
del da_out.encoding["grid_mapping"]

# remove the long_name, standard_name and units attributes of the x and y coordinates
# remove the long_name, standard_name and units attributes of x and y coordinates
for coord in ["x", "y"]:
if coord not in da_out.coords:
continue
Expand All @@ -586,6 +645,7 @@ def structured_da_to_ds(da, ds, method="average", nodata=np.nan):


def extent_to_polygon(extent):
"""Convert an extent to a shapely Polygon."""
logger.warning(
"nlmod.resample.extent_to_polygon is deprecated. "
"Use nlmod.util.extent_to_polygon instead."
Expand Down
Loading
Loading