Skip to content
Merged
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions nlmod/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def cache_netcdf(
coords_2d=False,
coords_3d=False,
coords_time=False,
attrs_ds=False,
datavars=None,
coords=None,
attrs=None,
Expand Down Expand Up @@ -101,6 +102,8 @@ def cache_netcdf(
Shorthand for adding 3D coordinates. The default is False.
coords_time : bool, optional
Shorthand for adding time coordinates. The default is False.
attrs_ds : bool, optional
Shorthand for adding model dataset attributes. The default is False.
datavars : list, optional
List of data variables to check for. The default is an empty list.
coords : list, optional
Expand Down Expand Up @@ -139,6 +142,7 @@ def wrapper(*args, cachedir=None, cachename=None, **kwargs):
coords_2d=coords_2d,
coords_3d=coords_3d,
coords_time=coords_time,
attrs_ds=attrs_ds,
datavars=datavars,
coords=coords,
attrs=attrs,
Expand All @@ -156,6 +160,7 @@ def wrapper(*args, cachedir=None, cachename=None, **kwargs):
coords_2d=coords_2d,
coords_3d=coords_3d,
coords_time=coords_time,
attrs_ds=attrs_ds,
datavars=datavars,
coords=coords,
attrs=attrs,
Expand Down Expand Up @@ -632,6 +637,7 @@ def ds_contains(
coords_2d=False,
coords_3d=False,
coords_time=False,
attrs_ds=False,
datavars=None,
coords=None,
attrs=None,
Expand All @@ -650,6 +656,8 @@ def ds_contains(
Shorthand for adding 3D coordinates. The default is False.
coords_time : bool, optional
Shorthand for adding time coordinates. The default is False.
attrs_ds : bool, optional
Shorthand for adding model dataset attributes. The default is False.
datavars : list, optional
List of data variables to check for. The default is an empty list.
coords : list, optional
Expand All @@ -666,7 +674,10 @@ def ds_contains(
if ds is None:
msg = "No dataset provided"
raise ValueError(msg)
if not coords_2d and not coords_3d and not datavars and not coords and not attrs:
isdefault_args = not any(
[coords_2d, coords_3d, coords_time, attrs_ds, datavars, coords, attrs]
)
if isdefault_args:
return ds

isvertex = ds.attrs["gridtype"] == "vertex"
Expand Down Expand Up @@ -699,7 +710,9 @@ def ds_contains(
datavars.remove("delc")

if "angrot" in ds.attrs:
attrs.append("angrot")
# set by `nlmod.base.to_model_ds()` and `nlmod.dims.resample._set_angrot_attributes()`
attrs_angrot_required = ["angrot", "xorigin", "yorigin"]
attrs.extend(attrs_angrot_required)

if coords_3d:
coords.append("layer")
Expand All @@ -712,6 +725,11 @@ def ds_contains(
datavars.append("nstp")
datavars.append("tsmult")

if attrs_ds:
# set by `nlmod.base.to_model_ds()` and `nlmod.base.set_ds_attrs()`
attrs_ds_required = ["model_name", "mfversion", "created_on", "exe_name", "model_ws", "figdir", "cachedir", "transport"]
attrs.extend(attrs_ds_required)

# User-friendly error messages if missing from ds
if "northsea" in datavars and "northsea" not in ds.data_vars:
msg = "Northsea not in dataset. Run nlmod.read.rws.add_northsea() first."
Expand All @@ -721,7 +739,7 @@ def ds_contains(
if "time" not in ds.coords:
msg = "time not in dataset. Run nlmod.time.set_ds_time() first."
raise ValueError(msg)

# Check if time-coord is complete
time_attrs_required = ["start", "time_units"]

Expand All @@ -731,6 +749,12 @@ def ds_contains(
"Run nlmod.time.set_ds_time() to set time."
raise ValueError(msg)

if attrs_ds:
for attr in attrs_ds_required:
if attr not in ds.attrs:
msg = f"{attr} not in dataset.attrs. Run nlmod.set_ds_attrs() first."
raise ValueError(msg)

# User-unfriendly error messages
for datavar in datavars:
if datavar not in ds.data_vars:
Expand Down