Skip to content
2 changes: 1 addition & 1 deletion nlmod/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

NLMOD_DATADIR = os.path.join(os.path.dirname(__file__), "data")

from . import dims, gis, gwf, gwt, modpath, plot, read, sim, util
from . import config, dims, gis, gwf, gwt, modpath, plot, read, sim, util
from .dims import base, get_ds, grid, layers, resample, time, to_model_ds
from .util import download_mfbinaries
from .version import __version__, show_versions
133 changes: 108 additions & 25 deletions nlmod/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
import pandas as pd
import xarray as xr
from dask.diagnostics import ProgressBar
from xarray.testing import assert_identical

from .config import NLMOD_CACHE_OPTIONS

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -194,6 +197,7 @@ def wrapper(*args, cachedir=None, cachename=None, **kwargs):
with open(fname_pickle_cache, "rb") as f:
func_args_dic_cache = pickle.load(f)
pickle_check = True

except (pickle.UnpicklingError, ModuleNotFoundError):
logger.info("could not read pickle, not using cache")
pickle_check = False
Expand All @@ -216,30 +220,62 @@ def wrapper(*args, cachedir=None, cachename=None, **kwargs):

if pickle_check:
# Ensure that the pickle pairs with the netcdf, see #66.
if nc_hash:
if NLMOD_CACHE_OPTIONS["nc_hash"] and nc_hash:
with open(fname_cache, "rb") as myfile:
cache_bytes = myfile.read()
func_args_dic["_nc_hash"] = hashlib.sha256(
cache_bytes
).hexdigest()

if dataset is not None:
# Check the coords of the dataset argument
func_args_dic["_dataset_coords_hash"] = dask.base.tokenize(
dict(dataset.coords)
)

# Check the data_vars of the dataset argument
func_args_dic["_dataset_data_vars_hash"] = dask.base.tokenize(
dict(dataset.data_vars)
)
if NLMOD_CACHE_OPTIONS["dataset_coords_hash"]:
# Check the coords of the dataset argument
func_args_dic["_dataset_coords_hash"] = dask.base.tokenize(
dict(dataset.coords)
)
else:
func_args_dic_cache.pop("_dataset_coords_hash", None)
logger.warning(
"cache -> dataset coordinates not checked, "
"disabled in global config. See "
"`nlmod.config.NLMOD_CACHE_OPTIONS`."
)
if not NLMOD_CACHE_OPTIONS[
"explicit_dataset_coordinate_comparison"
]:
logger.warning(
"It is recommended to turn on "
"`explicit_dataset_coordinate_comparison` "
"in global config when hash check is turned off!"
)

if NLMOD_CACHE_OPTIONS["dataset_data_vars_hash"]:
# Check the data_vars of the dataset argument
func_args_dic["_dataset_data_vars_hash"] = (
dask.base.tokenize(dict(dataset.data_vars))
)
else:
func_args_dic_cache.pop("_dataset_data_vars_hash", None)
logger.warning(
"cache -> dataset data vars not checked, "
"disabled in global config. See "
"`nlmod.config.NLMOD_CACHE_OPTIONS`."
)

# check if cache was created with same function arguments as
# function call
argument_check = _same_function_arguments(
func_args_dic, func_args_dic_cache
)

# explicit check on input dataset coordinates and cached dataset
if NLMOD_CACHE_OPTIONS[
"explicit_dataset_coordinate_comparison"
] and isinstance(dataset, (xr.DataArray, xr.Dataset)):
b = _explicit_dataset_coordinate_comparison(dataset, cached_ds)
# update argument check
argument_check = argument_check and b

cached_ds = _check_for_data_array(cached_ds)
if modification_check and argument_check and pickle_check:
msg = f"using cached data -> {cachename}"
Expand Down Expand Up @@ -276,19 +312,33 @@ def wrapper(*args, cachedir=None, cachename=None, **kwargs):
result.to_netcdf(fname_cache)

# add netcdf hash to function arguments dic, see #66
if nc_hash:
if NLMOD_CACHE_OPTIONS["nc_hash"] and nc_hash:
with open(fname_cache, "rb") as myfile:
cache_bytes = myfile.read()
func_args_dic["_nc_hash"] = hashlib.sha256(cache_bytes).hexdigest()

# Add dataset argument hash to function arguments dic
if dataset is not None:
func_args_dic["_dataset_coords_hash"] = dask.base.tokenize(
dict(dataset.coords)
)
func_args_dic["_dataset_data_vars_hash"] = dask.base.tokenize(
dict(dataset.data_vars)
)
if NLMOD_CACHE_OPTIONS["dataset_coords_hash"]:
func_args_dic["_dataset_coords_hash"] = dask.base.tokenize(
dict(dataset.coords)
)
else:
logger.warning(
"cache -> not writing dataset coordinates hash to "
"pickle file, disabled in global config. See "
"`nlmod.config.NLMOD_CACHE_OPTIONS`."
)
if NLMOD_CACHE_OPTIONS["dataset_data_vars_hash"]:
func_args_dic["_dataset_data_vars_hash"] = dask.base.tokenize(
dict(dataset.data_vars)
)
else:
logger.warning(
"cache -> not writing dataset data vars hash to "
"pickle file, disabled in global config. See "
"`nlmod.config.NLMOD_CACHE_OPTIONS`."
)

# pickle function arguments
with open(fname_pickle_cache, "wb") as fpklz:
Expand Down Expand Up @@ -422,15 +472,14 @@ def decorator(*args, cachedir=None, cachename=None, **kwargs):


def _same_function_arguments(func_args_dic, func_args_dic_cache):
"""Checks if two dictionaries with function arguments are identical by
checking:
"""Checks if two dictionaries with function arguments are identical.

The following items are checked:
1. if they have the same keys
2. if the items have the same type
3. if the items have the same values (only possible for the types: int,
float, bool, str, bytes, list,
tuple, dict, np.ndarray,
xr.DataArray,
flopy.mf6.ModflowGwf).
3. if the items have the same values (only implemented for the types: int,
float, bool, str, bytes, list, tuple, dict, np.ndarray, xr.DataArray,
flopy.mf6.ModflowGwf).

Parameters
----------
Expand Down Expand Up @@ -744,7 +793,6 @@ def ds_contains(
if coords_2d or coords_3d:
coords.append("x")
coords.append("y")
datavars.append("area")
attrs.append("extent")
attrs.append("gridtype")

Expand Down Expand Up @@ -832,3 +880,38 @@ def ds_contains(
coords={k: ds.coords[k] for k in coords},
attrs={k: ds.attrs[k] for k in attrs},
)


def _explicit_dataset_coordinate_comparison(ds_in, ds_cache):
    """Perform explicit dataset coordinate comparison.

    Uses `xarray.testing.assert_identical()` on each coordinate of the
    cached dataset. Note that only coordinates present in `ds_cache` are
    compared; a coordinate that exists only in `ds_in` is not checked.

    Parameters
    ----------
    ds_in : xr.Dataset
        Input dataset.
    ds_cache : xr.Dataset
        Cached dataset.

    Returns
    -------
    bool
        True if all coordinates of the cached dataset are identical to
        those in the input dataset, else False. Any AssertionError raised
        by the comparison is caught and reported as False; this function
        does not raise.
    """
    logger.debug("cache -> performing explicit dataset coordinate comparison")
    for coord in ds_cache.coords:
        # Lazy %-style args: the message is only formatted when DEBUG
        # logging is actually enabled.
        logger.debug("cache -> comparing coordinate %s", coord)
        try:
            assert_identical(ds_in[coord], ds_cache[coord])
        except AssertionError as err:
            # First mismatch short-circuits the comparison.
            logger.debug("cache -> coordinate %s not equal", coord)
            logger.debug(err)
            return False
    logger.debug("cache -> all coordinates equal")
    return True
79 changes: 79 additions & 0 deletions nlmod/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from contextlib import contextmanager

NLMOD_CACHE_OPTIONS = {
# compare hash for stored netcdf, default is True:
"nc_hash": True,
# compare hash for dataset coordinates, default is True:
"dataset_coords_hash": True,
# compare hash for dataset data variables, default is True:
"dataset_data_vars_hash": True,
# perform explicit comparison of dataset coordinates, default is False:
"explicit_dataset_coordinate_comparison": False,
}

_DEFAULT_CACHE_OPTIONS = {
"nc_hash": True,
"dataset_coords_hash": True,
"dataset_data_vars_hash": True,
"explicit_dataset_coordinate_comparison": False,
}


@contextmanager
def cache_options(**kwargs):
    """Context manager for nlmod cache options.

    Temporarily applies the given cache options and restores the
    previously active values on exit.

    Parameters
    ----------
    **kwargs : dict
        Cache options to set for the duration of the context. See
        ``NLMOD_CACHE_OPTIONS`` for the valid keys.

    Yields
    ------
    dict
        The current cache options (as returned by ``get_options()``).

    Raises
    ------
    ValueError
        If an unknown option name is passed (raised by ``set_options``).
    """
    # Snapshot the options active *before* entry so they can be restored
    # exactly. Previously the keys were reset to package defaults on exit,
    # which clobbered earlier set_options() calls and nested contexts.
    previous = dict(NLMOD_CACHE_OPTIONS)
    try:
        # Inside the try-block: if set_options raises on an unknown key
        # after having applied earlier (valid) keys, the finally-clause
        # still rolls those back.
        set_options(**kwargs)
        yield get_options()
    finally:
        NLMOD_CACHE_OPTIONS.clear()
        NLMOD_CACHE_OPTIONS.update(previous)


def set_options(**kwargs):
    """
    Set options for the nlmod package.

    Parameters
    ----------
    **kwargs : dict
        Options to set.

    Raises
    ------
    ValueError
        If an option name is not a known cache option.
    """
    for name, setting in kwargs.items():
        # Guard clause: reject unknown option names before storing.
        if name not in NLMOD_CACHE_OPTIONS:
            raise ValueError(
                f"Unknown option: {name}. Options are: "
                f"{list(NLMOD_CACHE_OPTIONS.keys())}"
            )
        NLMOD_CACHE_OPTIONS[name] = setting


def get_options(key=None):
    """
    Get options for the nlmod package.

    Parameters
    ----------
    key : str, optional
        Name of a single option to get. If None (default), all options
        are returned.

    Returns
    -------
    dict or value
        The options dict, or the value of the requested option.

    Raises
    ------
    KeyError
        If `key` is not a known cache option.
    """
    if key is None:
        return NLMOD_CACHE_OPTIONS
    # Return the value itself rather than a formatted "key: value" string,
    # so the result is usable programmatically. This matches the contract
    # documented above ("the value of the requested option").
    return NLMOD_CACHE_OPTIONS[key]


def reset_options(options=None):
    """Reset cache options to their default values.

    Parameters
    ----------
    options : list of str, optional
        Names of the options to reset. If None (default), every option
        is reset to its default.
    """
    if options is None:
        # No selection given: restore the full set of defaults at once.
        set_options(**_DEFAULT_CACHE_OPTIONS)
        return
    for name in options:
        # Reset each requested option individually.
        set_options(**{name: _DEFAULT_CACHE_OPTIONS[name]})
6 changes: 3 additions & 3 deletions nlmod/dims/attributes_encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,9 @@ def get_encodings(
if np.issubdtype(da.dtype, np.character):
continue

assert (
"_FillValue" not in da.attrs
), f"Custom fillvalues are not supported. {varname} has a fillvalue set."
assert "_FillValue" not in da.attrs, (
f"Custom fillvalues are not supported. {varname} has a fillvalue set."
)

encoding = {
"zlib": True,
Expand Down
Loading