Skip to content
2 changes: 1 addition & 1 deletion nlmod/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

NLMOD_DATADIR = os.path.join(os.path.dirname(__file__), "data")

from . import dims, gis, gwf, gwt, modpath, plot, read, sim, util
from . import config, dims, gis, gwf, gwt, modpath, plot, read, sim, util
from .dims import base, get_ds, grid, layers, resample, time, to_model_ds
from .util import download_mfbinaries
from .version import __version__, show_versions
121 changes: 97 additions & 24 deletions nlmod/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
import pandas as pd
import xarray as xr
from dask.diagnostics import ProgressBar
from xarray.testing import assert_equal

from .config import NLMOD_CACHE_OPTIONS

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -194,6 +197,7 @@ def wrapper(*args, cachedir=None, cachename=None, **kwargs):
with open(fname_pickle_cache, "rb") as f:
func_args_dic_cache = pickle.load(f)
pickle_check = True

except (pickle.UnpicklingError, ModuleNotFoundError):
logger.info("could not read pickle, not using cache")
pickle_check = False
Expand All @@ -216,30 +220,60 @@ def wrapper(*args, cachedir=None, cachename=None, **kwargs):

if pickle_check:
# Ensure that the pickle pairs with the netcdf, see #66.
if nc_hash:
if NLMOD_CACHE_OPTIONS["nc_hash"] and nc_hash:
with open(fname_cache, "rb") as myfile:
cache_bytes = myfile.read()
func_args_dic["_nc_hash"] = hashlib.sha256(
cache_bytes
).hexdigest()

if dataset is not None:
# Check the coords of the dataset argument
func_args_dic["_dataset_coords_hash"] = dask.base.tokenize(
dict(dataset.coords)
)

# Check the data_vars of the dataset argument
func_args_dic["_dataset_data_vars_hash"] = dask.base.tokenize(
dict(dataset.data_vars)
)
if NLMOD_CACHE_OPTIONS["dataset_coords_hash"]:
# Check the coords of the dataset argument
func_args_dic["_dataset_coords_hash"] = dask.base.tokenize(
dict(dataset.coords)
)
else:
func_args_dic_cache.pop("_dataset_coords_hash", None)
logger.warning(
"cache -> dataset coordinates not checked, "
"disabled in global config. See "
"`nlmod.config.NLMOD_CACHE_OPTIONS`."
)
if not NLMOD_CACHE_OPTIONS[
"explicit_dataset_coordinate_comparison"
]:
logger.warning(
"It is recommended to turn on "
"`explicit_dataset_coordinate_comparison` "
"in global config when hash check is turned off!"
)

if NLMOD_CACHE_OPTIONS["dataset_data_vars_hash"]:
# Check the data_vars of the dataset argument
func_args_dic["_dataset_data_vars_hash"] = (
dask.base.tokenize(dict(dataset.data_vars))
)
else:
func_args_dic_cache.pop("_dataset_data_vars_hash", None)
logger.warning(
"cache -> dataset data vars not checked, "
"disabled in global config. See "
"`nlmod.config.NLMOD_CACHE_OPTIONS`."
)

# check if cache was created with same function arguments as
# function call
argument_check = _same_function_arguments(
func_args_dic, func_args_dic_cache
)

# explicit check on input dataset coordinates and cached dataset
if NLMOD_CACHE_OPTIONS[
"explicit_dataset_coordinate_comparison"
] and isinstance(dataset, (xr.DataArray, xr.Dataset)):
_explicit_dataset_coordinate_comparison(dataset, cached_ds)

cached_ds = _check_for_data_array(cached_ds)
if modification_check and argument_check and pickle_check:
msg = f"using cached data -> {cachename}"
Expand Down Expand Up @@ -276,19 +310,33 @@ def wrapper(*args, cachedir=None, cachename=None, **kwargs):
result.to_netcdf(fname_cache)

# add netcdf hash to function arguments dic, see #66
if nc_hash:
if NLMOD_CACHE_OPTIONS["nc_hash"] and nc_hash:
with open(fname_cache, "rb") as myfile:
cache_bytes = myfile.read()
func_args_dic["_nc_hash"] = hashlib.sha256(cache_bytes).hexdigest()

# Add dataset argument hash to function arguments dic
if dataset is not None:
func_args_dic["_dataset_coords_hash"] = dask.base.tokenize(
dict(dataset.coords)
)
func_args_dic["_dataset_data_vars_hash"] = dask.base.tokenize(
dict(dataset.data_vars)
)
if NLMOD_CACHE_OPTIONS["dataset_coords_hash"]:
func_args_dic["_dataset_coords_hash"] = dask.base.tokenize(
dict(dataset.coords)
)
else:
logger.warning(
"cache -> not writing dataset coordinates hash to "
"pickle file, disabled in global config. See "
"`nlmod.config.NLMOD_CACHE_OPTIONS`."
)
if NLMOD_CACHE_OPTIONS["dataset_data_vars_hash"]:
func_args_dic["_dataset_data_vars_hash"] = dask.base.tokenize(
dict(dataset.data_vars)
)
else:
logger.warning(
"cache -> not writing dataset data vars hash to "
"pickle file, disabled in global config. See "
"`nlmod.config.NLMOD_CACHE_OPTIONS`."
)

# pickle function arguments
with open(fname_pickle_cache, "wb") as fpklz:
Expand Down Expand Up @@ -422,15 +470,14 @@ def decorator(*args, cachedir=None, cachename=None, **kwargs):


def _same_function_arguments(func_args_dic, func_args_dic_cache):
"""Checks if two dictionaries with function arguments are identical by
checking:
"""Checks if two dictionaries with function arguments are identical.

The following items are checked:
1. if they have the same keys
2. if the items have the same type
3. if the items have the same values (only possible for the types: int,
float, bool, str, bytes, list,
tuple, dict, np.ndarray,
xr.DataArray,
flopy.mf6.ModflowGwf).
3. if the items have the same values (only implemented for the types: int,
float, bool, str, bytes, list, tuple, dict, np.ndarray, xr.DataArray,
flopy.mf6.ModflowGwf).

Parameters
----------
Expand Down Expand Up @@ -832,3 +879,29 @@ def ds_contains(
coords={k: ds.coords[k] for k in coords},
attrs={k: ds.attrs[k] for k in attrs},
)


def _explicit_dataset_coordinate_comparison(ds_in, ds_cache):
    """Perform explicit dataset coordinate comparison.

    Every coordinate present in the cached dataset is compared, one by
    one, against the corresponding coordinate of the input dataset using
    ``xarray.testing.assert_equal``.

    Parameters
    ----------
    ds_in : xr.Dataset
        Input dataset.
    ds_cache : xr.Dataset
        Cached dataset.

    Returns
    -------
    None

    Raises
    ------
    AssertionError
        If the coordinates are not equal.
    """
    logger.info("cache -> performing explicit dataset coordinate comparison")
    for name in list(ds_cache.coords):
        logger.debug("cache -> comparing coordinate: %s", name)
        assert_equal(ds_in[name], ds_cache[name])
    logger.debug("cache -> all coordinates equal")
79 changes: 79 additions & 0 deletions nlmod/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from contextlib import contextmanager

# Global cache-validation options for nlmod. This dict is mutated in
# place by set_options so that modules holding a reference to it
# (e.g. nlmod.cache) always observe the current values.
NLMOD_CACHE_OPTIONS = {
    # compare hash for stored netcdf, default is True:
    "nc_hash": True,
    # compare hash for dataset coordinates, default is True:
    "dataset_coords_hash": True,
    # compare hash for dataset data variables, default is True:
    "dataset_data_vars_hash": True,
    # perform explicit comparison of dataset coordinates, default is False:
    "explicit_dataset_coordinate_comparison": False,
}

# Snapshot of the defaults used by reset_options. Derived from the dict
# above (instead of a hand-maintained copy) so the two cannot drift apart.
_DEFAULT_CACHE_OPTIONS = dict(NLMOD_CACHE_OPTIONS)


@contextmanager
def cache_options(**kwargs):
    """Context manager that temporarily sets nlmod cache options.

    On exit the overridden options are restored to the values they had
    *before* entering the context (not to the package defaults), so any
    earlier ``set_options`` calls and nested usage remain intact.

    Parameters
    ----------
    **kwargs : dict
        Options to set for the duration of the ``with`` block.

    Yields
    ------
    dict
        The global cache options dictionary with the temporary values
        applied.

    Raises
    ------
    ValueError
        If an unknown option name is passed.
    """
    # Remember the current values of the options we are about to change.
    # set_options validates all names before mutating anything, so a bad
    # name raises here without leaving partial state behind.
    previous = {key: NLMOD_CACHE_OPTIONS[key] for key in kwargs if key in NLMOD_CACHE_OPTIONS}
    set_options(**kwargs)
    try:
        yield get_options()
    finally:
        # Restore the pre-existing values, not the defaults.
        set_options(**previous)


def set_options(**kwargs):
    """
    Set options for the nlmod package.

    All option names are validated before any option is changed, so an
    invalid call leaves the global options untouched.

    Parameters
    ----------
    **kwargs : dict
        Options to set.

    Raises
    ------
    ValueError
        If an option name is not one of the known cache options.
    """
    unknown = [key for key in kwargs if key not in NLMOD_CACHE_OPTIONS]
    if unknown:
        raise ValueError(
            f"Unknown option: {unknown[0]}. Options are: "
            f"{list(NLMOD_CACHE_OPTIONS.keys())}"
        )
    NLMOD_CACHE_OPTIONS.update(kwargs)


def get_options(key=None):
    """
    Get options for the nlmod package.

    Parameters
    ----------
    key : str, optional
        Name of a single option to get. The default is None, which
        returns the full options dictionary.

    Returns
    -------
    dict or value
        The options dictionary, or the value of the requested option.

    Raises
    ------
    ValueError
        If `key` is not a known option name.
    """
    if key is None:
        return NLMOD_CACHE_OPTIONS
    if key not in NLMOD_CACHE_OPTIONS:
        raise ValueError(
            f"Unknown option: {key}. Options are: "
            f"{list(NLMOD_CACHE_OPTIONS.keys())}"
        )
    # Return the value itself (as documented), not a formatted string.
    return NLMOD_CACHE_OPTIONS[key]


def reset_options(options=None):
    """Reset options to their default values.

    Parameters
    ----------
    options : iterable of str, optional
        Names of the options to reset. The default is None, which resets
        every option.
    """
    if options is None:
        options = _DEFAULT_CACHE_OPTIONS
    set_options(**{opt: _DEFAULT_CACHE_OPTIONS[opt] for opt in options})
6 changes: 3 additions & 3 deletions nlmod/dims/attributes_encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,9 @@ def get_encodings(
if np.issubdtype(da.dtype, np.character):
continue

assert (
"_FillValue" not in da.attrs
), f"Custom fillvalues are not supported. {varname} has a fillvalue set."
assert "_FillValue" not in da.attrs, (
f"Custom fillvalues are not supported. {varname} has a fillvalue set."
)

encoding = {
"zlib": True,
Expand Down
6 changes: 3 additions & 3 deletions nlmod/dims/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -1564,9 +1564,9 @@ def aggregate_vector_per_cell(gdf, fields_methods, modelgrid=None):
else:
raise TypeError("cannot aggregate geometries of different types")
if bool({"length_weighted", "max_length"} & set(fields_methods.values())):
assert (
geom_types[0] == "LineString"
), "can only use length methods with line geometries"
assert geom_types[0] == "LineString", (
"can only use length methods with line geometries"
)
if bool({"area_weighted", "max_area"} & set(fields_methods.values())):
if ("Polygon" in geom_types) or ("MultiPolygon" in geom_types):
pass
Expand Down
6 changes: 3 additions & 3 deletions nlmod/dims/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,9 +269,9 @@ def split_layers_ds(
split_dict[lay0] = [1 / split_dict[lay0]] * split_dict[lay0]
elif hasattr(split_dict[lay0], "__iter__"):
# make sure the fractions add up to 1
assert np.isclose(
np.sum(split_dict[lay0]), 1
), f"Fractions for splitting layer '{lay0}' do not add up to 1."
assert np.isclose(np.sum(split_dict[lay0]), 1), (
f"Fractions for splitting layer '{lay0}' do not add up to 1."
)
split_dict[lay0] = split_dict[lay0] / np.sum(split_dict[lay0])
else:
raise ValueError(
Expand Down
6 changes: 3 additions & 3 deletions nlmod/dims/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,9 +707,9 @@ def dataframe_to_flopy_timeseries(
append=False,
):
assert not df.isna().any(axis=None)
assert (
ds.time.dtype.kind == "M"
), "get recharge requires a datetime64[ns] time index"
assert ds.time.dtype.kind == "M", (
"get recharge requires a datetime64[ns] time index"
)
if ds is not None:
# set index to days after the start of the simulation
df = df.copy()
Expand Down
14 changes: 7 additions & 7 deletions nlmod/gis.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ def vertex_da_to_gdf(
gdf : geopandas.GeoDataframe
geodataframe of one or more DataArrays.
"""
assert (
model_ds.gridtype == "vertex"
), f"expected model dataset with gridtype vertex, got {model_ds.gridtype}"
assert model_ds.gridtype == "vertex", (
f"expected model dataset with gridtype vertex, got {model_ds.gridtype}"
)

if isinstance(data_variables, str):
data_variables = [data_variables]
Expand Down Expand Up @@ -126,9 +126,9 @@ def struc_da_to_gdf(
gdf : geopandas.GeoDataframe
geodataframe of one or more DataArrays.
"""
assert (
model_ds.gridtype == "structured"
), f"expected model dataset with gridtype vertex, got {model_ds.gridtype}"
assert model_ds.gridtype == "structured", (
f"expected model dataset with gridtype vertex, got {model_ds.gridtype}"
)

if isinstance(data_variables, str):
data_variables = [data_variables]
Expand Down Expand Up @@ -474,7 +474,7 @@ def _break_down_dimension(
if add_dim_name:
name = f"{name}_{dim}"
if add_one_based_index:
name = f"{name}_{i+1}"
name = f"{name}_{i + 1}"
else:
name = f"{name}_{value}"

Expand Down
6 changes: 3 additions & 3 deletions nlmod/gwf/horizontal_flow_barrier.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def get_hfb_spd(gwf, linestrings, hydchr=1 / 100, depth=None, elevation=None):
spd : List of Tuple
Stress period data used to configure the hfb package of Flopy.
"""
assert (
sum([depth is None, elevation is None]) == 1
), "Use either depth or elevation argument"
assert sum([depth is None, elevation is None]) == 1, (
"Use either depth or elevation argument"
)

tops = np.concatenate((gwf.disv.top.array[None], gwf.disv.botm.array))
thick = tops[:-1] - tops[1:]
Expand Down
2 changes: 1 addition & 1 deletion nlmod/gwf/lake.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def lake_from_gdf(
lakeout = gdf.loc[mask, "lakeno"].iloc[0]
if not (gdf.loc[mask, "lakeno"] == lakeout).all():
raise ValueError(
f'expected single value of lakeno for lakeout {boundnameout}, got {gdf.loc[mask, "lakeno"]}'
f"expected single value of lakeno for lakeout {boundnameout}, got {gdf.loc[mask, 'lakeno']}"
)
assert lakeno != lakeout, "lakein and lakeout cannot be the same"

Expand Down
2 changes: 1 addition & 1 deletion nlmod/gwf/surface_water.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ def distribute_cond_over_lays(
try:
first_active = np.where(idomain > 0)[0][0]
except IndexError:
warnings.warn(f"No active layers in {cellid}, " "returning NaNs.")
warnings.warn(f"No active layers in {cellid}, returning NaNs.")
return np.nan, np.nan
else:
first_active = 0
Expand Down
Loading