-
Notifications
You must be signed in to change notification settings - Fork 6
Speed up for fillnan #511
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Speed up for fillnan #511
Changes from all commits
3667d7e
965f2e6
b087c96
5fab9ae
cacb0a2
24e57ff
27c61be
5ccbc09
ba25cbc
b335200
dd33c29
8916c77
d235715
6133f37
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,11 +4,15 @@ | |
| import numpy as np | ||
| import rasterio | ||
| import xarray as xr | ||
| from flopy.discretization.vertexgrid import VertexGrid | ||
| from scipy.interpolate import griddata | ||
| from scipy.ndimage import binary_dilation, distance_transform_edt | ||
| from scipy.spatial import cKDTree | ||
|
|
||
| import nlmod | ||
|
|
||
| from ..util import get_da_from_da_ds | ||
| from .shared import get_area | ||
| from .shared import get_area, is_structured, is_vertex | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
@@ -102,8 +106,7 @@ def ds_to_structured_grid( | |
| angrot=0.0, | ||
| method="nearest", | ||
| ): | ||
| """Resample a dataset (xarray) from a structured grid to a new dataset from a | ||
| different structured grid. | ||
| """Resample a dataset from a structured grid to a different structured grid. | ||
|
|
||
| Parameters | ||
| ---------- | ||
|
|
@@ -233,9 +236,11 @@ def _set_angrot_attributes(extent, xorigin, yorigin, angrot, attrs): | |
| def fillnan_da_structured_grid(xar_in, method="nearest"): | ||
| """Fill not-a-number values in a structured grid, DataArray. | ||
|
|
||
| The fill values are determined using the 'nearest' method of the | ||
| scipy.interpolate.griddata function | ||
| The nans are replaced with the interpolated values using distance_transform_edt when | ||
| all cells have the same shape and method is nearest, otherwise the values are | ||
| interpolated using griddata. | ||
|
|
||
| Note that extrapolation is only applied if method is "nearest". | ||
|
|
||
| Parameters | ||
| ---------- | ||
|
|
@@ -251,50 +256,69 @@ def fillnan_da_structured_grid(xar_in, method="nearest"): | |
| xar_out : xarray DataArray | ||
| DataArray without nan values. DataArray has 2 dimensions | ||
| (y and x) | ||
|
|
||
| Notes | ||
| ----- | ||
| can be slow if the xar_in is a large raster | ||
| """ | ||
| # check dimensions | ||
| # if "x" not in xar_in.dims or "y" not in xar_in.dims: | ||
| if xar_in.dims != ("y", "x"): | ||
| raise ValueError( | ||
| f"expected dataarray with dimensions ('y' and 'x'), got dimensions -> {xar_in.dims}" | ||
| "expected dataarray with dimensions ('y' and 'x'), got dimensions -> " | ||
| f"{xar_in.dims}" | ||
| ) | ||
|
|
||
| # get list of coordinates from all points in raster | ||
| mg = np.meshgrid(xar_in.x.data, xar_in.y.data) | ||
| points_all = np.vstack((mg[0].ravel(), mg[1].ravel())).T | ||
| # Create a deep copy to avoid modifying the original data | ||
| # Using manual DataArray construction because copy(deep=True) still shares memory | ||
| # in some xarray contexts when data is loaded from netCDF files | ||
| xar_out = xr.DataArray( | ||
| xar_in.values.copy(), | ||
| coords=xar_in.coords, | ||
| dims=xar_in.dims, | ||
| attrs=xar_in.attrs, | ||
| ) | ||
|
|
||
| if method == "nearest": | ||
| y = xar_in.coords["y"].values | ||
| x = xar_in.coords["x"].values | ||
| dy = np.abs(y[1:] - y[:-1]) | ||
| dx = np.abs(x[1:] - x[:-1]) | ||
|
|
||
| # get all values in DataArray | ||
| values_all = xar_in.data.flatten() | ||
| if np.allclose(dy, dy[0]) and np.allclose(dx, dx[0]): | ||
| sampling = None if np.isclose(dy[0], dx[0]) else (dy[0], dx[0]) | ||
|
|
||
| # get 1d arrays with only values where DataArray is not nan | ||
| mask1 = ~np.isnan(values_all) | ||
| points_in = points_all[np.where(mask1)[0]] | ||
| values_in = values_all[np.where(mask1)[0]] | ||
| idx = distance_transform_edt( | ||
| input=xar_in.isnull(), | ||
| sampling=sampling, | ||
| return_distances=False, | ||
| return_indices=True, | ||
| ) | ||
| xar_out.values = xar_in.values[tuple(idx)] | ||
|
|
||
| # get value for all nan values | ||
| values_out = griddata(points_in, values_in, points_all, method=method) | ||
| arr_out = values_out.reshape(xar_in.shape) | ||
| return xar_out | ||
|
|
||
| # create DataArray without nan values | ||
| xar_out = xr.DataArray( | ||
| arr_out, | ||
| dims=("y", "x"), | ||
| coords={"x": xar_in.x.data, "y": xar_in.y.data}, | ||
| ) | ||
| # xar_out = xar_in.rio.interpolate_na(method=method) | ||
| xg, yg = np.meshgrid(xar_in.x.values, xar_in.y.values) | ||
|
|
||
| is_invalid = np.isnan(xar_in) | ||
| points_out = np.column_stack((xg[is_invalid], yg[is_invalid])) | ||
|
|
||
| # Get coordinates and values of the valid cells to pass to griddata | ||
| if method in ("nearest", "linear"): | ||
| # We can cheaply isolate the neighboring cells and only pass that | ||
| # to griddata. Spline uses thicker outline of nan areas. | ||
| is_valid = binary_dilation(is_invalid) & ~is_invalid | ||
| else: | ||
| is_valid = ~is_invalid | ||
|
|
||
| points_in = np.column_stack((xg[is_valid], yg[is_valid])) | ||
| values_in = xar_in.values[is_valid] | ||
|
|
||
| xar_out.values[is_invalid] = griddata( | ||
| points_in, values_in, points_out, method=method | ||
| ) | ||
| return xar_out | ||
|
|
||
|
|
||
| def fillnan_da_vertex_grid(xar_in, ds=None, x=None, y=None, method="nearest"): | ||
| """Fill not-a-number values in a vertex grid, DataArray. | ||
|
|
||
| The fill values are determined using the 'nearest' method of the | ||
| scipy.interpolate.griddata function | ||
| Note that extrapolation is only applied if method is "nearest". | ||
|
|
||
| Parameters | ||
| ---------- | ||
|
|
@@ -319,43 +343,76 @@ def fillnan_da_vertex_grid(xar_in, ds=None, x=None, y=None, method="nearest"): | |
|
|
||
| Notes | ||
| ----- | ||
| can be slow if the xar_in is a large raster | ||
| If x is provided, x will be used over ds and the x coordinate part of xar_in. | ||
| If x is not provided, ds will be used to get the x coordinates. | ||
| If x is not provided and "x" is not in xar_in.coords, an error will be raised. | ||
| """ | ||
| if xar_in.dims != ("icell2d",): | ||
| if not is_vertex(xar_in): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. So this will allow layered vertex arrays to pass, which is not what you want (and is avoided for the structured grid). I did not think of this when recommending the functions in the
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hi @OnnoEbbens, I don't follow you here. The code raises an error for everything that is not a vertex grid, which is desired here in this vertex function, right? |
||
| raise ValueError( | ||
| f"expected dataarray with dimensions ('icell2d'), got dimensions -> {xar_in.dims}" | ||
| "expected dataarray with dimensions ('icell2d'), got dimensions -> " | ||
| f"{xar_in.dims}" | ||
| ) | ||
|
|
||
| # get list of coordinates from all points in raster | ||
| if x is None: | ||
| x = ds["x"].data | ||
| if y is None: | ||
| y = ds["y"].data | ||
|
|
||
| xyi = np.column_stack((x, y)) | ||
|
|
||
| # fill nan values in DataArray | ||
| values_all = xar_in.data | ||
|
|
||
| # get 1d arrays with only values where DataArray is not nan | ||
| mask1 = ~np.isnan(values_all) | ||
| xyi_in = xyi[mask1] | ||
| values_in = values_all[mask1] | ||
| # Create a deep copy to avoid modifying the original data | ||
| # Using manual DataArray construction because copy(deep=True) still shares memory | ||
| # in some xarray contexts when data is loaded from netCDF files | ||
| xar_out = xr.DataArray( | ||
| xar_in.values.copy(), | ||
| coords=xar_in.coords, | ||
| dims=xar_in.dims, | ||
| attrs=xar_in.attrs, | ||
| ) | ||
|
|
||
| # get value for all nan values | ||
| values_out = griddata(xyi_in, values_in, xyi, method=method) | ||
| if x is not None: | ||
| pass | ||
| elif x is None and ds is not None: | ||
| x = ds["x"].values | ||
| elif x is None and "x" in xar_in.coords: | ||
| x = xar_in.coords["x"].values | ||
| else: | ||
| raise ValueError("x or ds must be provided to get x coordinates") | ||
| if y is not None: | ||
| pass | ||
| elif y is None and ds is not None: | ||
| y = ds["y"].values | ||
| elif y is None and "y" in xar_in.coords: | ||
| y = xar_in.coords["y"].values | ||
| else: | ||
| raise ValueError("y or ds must be provided to get y coordinates") | ||
|
|
||
| is_invalid = np.isnan(xar_out) | ||
| points_out = np.column_stack((x[is_invalid], y[is_invalid])) | ||
|
|
||
| if method in ("nearest", "linear") and ds is not None: | ||
| # We can cheaply isolate the neighboring cells and only pass those | ||
| # to griddata. This is similar to the approach for structured grids: | ||
| # is_valid = binary_dilation(is_invalid) & ~is_invalid | ||
| vertices = nlmod.grid.get_vertices_from_ds(ds) | ||
| cell2d = nlmod.grid.get_cell2d_from_ds(ds) | ||
| mg = VertexGrid(vertices=vertices, cell2d=cell2d) | ||
| _cell_connections = mg.neighbors() | ||
| cell_connections = {k: _cell_connections[k] for k in np.where(is_invalid)[0]} | ||
| unicons = set().union(*cell_connections.values()) - set(cell_connections.keys()) | ||
|
|
||
| is_valid = np.zeros_like(xar_in, dtype=bool) | ||
| is_valid[list(unicons)] = True | ||
| else: | ||
| is_valid = ~is_invalid | ||
|
|
||
| # create DataArray without nan values | ||
| xar_out = xr.DataArray(values_out, dims=("icell2d")) | ||
| points_in = np.column_stack((x[is_valid], y[is_valid])) | ||
| values_in = xar_out.values[is_valid] | ||
|
|
||
| # get values for all nan cells | ||
| xar_out.values[is_invalid] = griddata( | ||
| points_in, values_in, points_out, method=method | ||
| ) | ||
| return xar_out | ||
|
|
||
|
|
||
| def fillnan_da(da, ds=None, method="nearest"): | ||
| """Fill not-a-number values in a DataArray. | ||
|
|
||
| The fill values are determined using the 'nearest' method of the | ||
| scipy.interpolate.griddata function | ||
| Note that extrapolation is only applied if method is "nearest". | ||
|
|
||
| Parameters | ||
| ---------- | ||
|
|
@@ -375,11 +432,12 @@ def fillnan_da(da, ds=None, method="nearest"): | |
| ----- | ||
| can be slow if the xar_in is a large raster | ||
| """ | ||
| if len(da.shape) > 1 and len(da.y) == da.shape[-2] and len(da.x) == da.shape[-1]: | ||
| # the dataraary is structured | ||
| if is_structured(da): | ||
| return fillnan_da_structured_grid(da, method=method) | ||
| else: | ||
| elif is_vertex(da): | ||
| return fillnan_da_vertex_grid(da, ds, method=method) | ||
| else: | ||
| raise NotImplementedError("Unsupported grid type") | ||
|
|
||
|
|
||
| def vertex_da_to_ds(da, ds, method="nearest"): | ||
|
|
@@ -408,13 +466,14 @@ def vertex_da_to_ds(da, ds, method="nearest"): | |
|
|
||
| if "gridtype" in ds.attrs and ds.gridtype == "vertex": | ||
| if len(da.dims) == 1: | ||
| xi = list(zip(ds.x.values, ds.y.values)) | ||
| xi = list(zip(ds.x.values, ds.y.values, strict=True)) | ||
| z = griddata(points, da.values, xi, method=method) | ||
| coords = {"icell2d": ds.icell2d} | ||
| return xr.DataArray(z, dims="icell2d", coords=coords) | ||
| else: | ||
| raise NotImplementedError( | ||
| "Resampling from multidmensional vertex da to vertex ds not yet supported" | ||
| "Resampling from multidimensional vertex da to vertex ds not yet " | ||
| "supported" | ||
| ) | ||
|
|
||
| xg, yg = np.meshgrid(ds.x, ds.y) | ||
|
|
@@ -574,7 +633,7 @@ def structured_da_to_ds(da, ds, method="average", nodata=np.nan): | |
| if "grid_mapping" in da_out.encoding: | ||
| del da_out.encoding["grid_mapping"] | ||
|
|
||
| # remove the long_name, standard_name and units attributes of the x and y coordinates | ||
| # remove the long_name, standard_name and units attributes of x and y coordinates | ||
| for coord in ["x", "y"]: | ||
| if coord not in da_out.coords: | ||
| continue | ||
|
|
@@ -586,6 +645,7 @@ def structured_da_to_ds(da, ds, method="average", nodata=np.nan): | |
|
|
||
|
|
||
| def extent_to_polygon(extent): | ||
| """Convert an extent to a shapely Polygon.""" | ||
| logger.warning( | ||
| "nlmod.resample.extent_to_polygon is deprecated. " | ||
| "Use nlmod.util.extent_to_polygon instead." | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.