diff --git a/nlmod/plot/plot.py b/nlmod/plot/plot.py index 176dc81b..0b5f4fc5 100644 --- a/nlmod/plot/plot.py +++ b/nlmod/plot/plot.py @@ -40,7 +40,11 @@ def modelgrid(ds, ax=None, **kwargs): _, ax = plt.subplots(figsize=(10, 10)) ax.axis("scaled") modelgrid = modelgrid_from_ds(ds) + extent = None if ax.get_autoscale_on() else ax.axis() modelgrid.plot(ax=ax, **kwargs) + if extent is not None: + ax.axis(extent) + return ax