Skip to content
43 changes: 43 additions & 0 deletions nlmod/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from matplotlib.colors import ListedColormap, Normalize
from matplotlib.patches import Patch
from mpl_toolkits.axes_grid1 import make_axes_locatable
from geopandas import GeoDataFrame
from shapely.geometry import Polygon

from ..dims.grid import modelgrid_from_ds
from ..dims.resample import get_affine_mod_to_world, get_extent
Expand Down Expand Up @@ -48,6 +50,47 @@ def modelgrid(ds, ax=None, **kwargs):
return ax


def modelextent(ds, ax=None, **kwargs):
"""Plot model extent.

Parameters
----------
ds : xarray.Dataset
The dataset containing the data.
ax : matplotlib.axes.Axes, optional
The axes object to plot on. If not provided, a new figure and axes will be
created.
**kwargs
Additional keyword arguments to pass to the boundary plot.

Returns
-------
ax : matplotlib.axes.Axes
axes object
"""
extent = xmin, xmax, ymin, ymax = get_extent(ds, rotated=True)
dx = 0.05 * (xmax - xmin)
dy = 0.05 * (ymax - ymin)
if ax is None:
_, ax = plt.subplots(figsize=(10, 10))
ax.axis("scaled")

ax.axis([xmin - dx, xmax + dx, ymin - dy, ymax + dy])
xy = [
(xmin, ymin),
(xmax, ymin),
(xmax, ymax),
(xmin, ymax),
(xmin, ymin),
]
gdf = GeoDataFrame(geometry=[Polygon(xy)])
extent = None if ax.get_autoscale_on() else ax.axis()
gdf.boundary.plot(ax=ax, **kwargs)
if extent is not None:
ax.axis(extent)
return ax


def facet_plot(
gwf,
ds,
Expand Down