Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/bug_report.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name: Bug Report
description: File a bug report
title: "[Bug]: "
title: "[bug]: "
labels: ["bug"]

body:
Expand Down
1 change: 1 addition & 0 deletions .github/ISSUE_TEMPLATE/feature_request.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
name: Feature Request
description: Suggest an enhancement or new feature
title: "[feature]: "
labels: ["feature"]

body:
Expand Down
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@
- Custom `__reduce__` methods, allowing solvers to be serialized ([#38](https://github.com/NatLabRockies/scikit-sundae/pull/38))

### Optimizations
None.

### Bug Fixes
- Address memory leak with raised exceptions caused by persisting solver/data objects ([#51](https://github.com/NatLabRockies/scikit-sundae/pull/51))
- Fixes issue where `jacfn` is ignored when using `sparse` linear solver ([#48](https://github.com/NatLabRockies/scikit-sundae/pull/48))
- Ensures exception propagations work correctly with numpy 2.4 release ([#41](https://github.com/NatLabRockies/scikit-sundae/pull/41))

### Breaking Changes
None.

### Chores
- Move to using `ruff` for linting and start tracking `Cython` addition issue ([#45](https://github.com/NatLabRockies/scikit-sundae/pull/45))
Expand Down
2 changes: 1 addition & 1 deletion docs/source/user_guide/linear_solvers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ Switching between SUNDIALS direct linear solvers in scikit-SUNDAE is straightfor
U = 0 # upper bandwidth
solver = CVODE(rhsfn, linsolver='band', lband=L, uband=U)

Ensure you provide both the lower bandwidth `lband` and upper bandwidth `uband` when using the `band` linear solver. Each bandwidth defines the LARGEST distance between a non-zero element and the main diagonal, on either side, as shown in the figure below. Forgetting to set either bandwidth will raise an error. If `lband + uband` matches the dimension `N` of the matrix, the performance of the `band` and `dense` linear solvers will be approximately the same.
Ensure you provide both the lower bandwidth `lband` and upper bandwidth `uband` when using the `band` linear solver. Each bandwidth defines the LARGEST distance between a non-zero element and the main diagonal, on either side, as shown in the figure below. Forgetting to set either bandwidth will raise an error. If `lband + uband + 1` matches the dimension `N` of the matrix, the performance of the `band` and `dense` linear solvers will be approximately the same.

.. figure:: figures/banded_jacobian.png
:width: 50%
Expand Down
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def run_cleanup(_) -> None:
@nox.session(name='linter', python=False)
def run_ruff(session: nox.Session) -> None:
"""
Run ruff to check for linting errors.
Run ruff to check for linting errors

Use the optional 'format' or 'format-unsafe' arguments to run ruff with the
--fix or --unsafe-fixes option prior to the linter. You can also use 'stats'
Expand Down
6 changes: 6 additions & 0 deletions src/sksundae/_cy_common.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ cimport numpy as np
# Extern cdef headers
from .c_sundials cimport * # Access to C types

# Propagate python exceptions or print SUNDIALS error messages
cdef _pyerr_handler()
cdef void _sunerr_handler(
int line, const char* func, const char* file, const char* msg, int err_code,
void* err_user_data, SUNContext ctx) except *

# Convert between N_Vector and numpy array
cdef svec2np(N_Vector nvec, np.ndarray[DTYPE_t, ndim=1] np_array)
cdef np2svec(np.ndarray[DTYPE_t, ndim=1] np_array, N_Vector nvec)
Expand Down
37 changes: 37 additions & 0 deletions src/sksundae/_cy_common.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@
# Dependencies
cimport numpy as np

from cpython.exc cimport (
PyObject, PyErr_Occurred, PyErr_Fetch, PyErr_Restore,
# PyErr_GetRaisedException, PyErr_SetRaisedException,
)

# PyErr_Fetch and PyErr_Restore are deprecated at 3.12. When support for <3.12
# is dropped, replace with PyErr_GetRaisedException/PyErr_SetRaisedException.

# Extern cdef headers
from .c_sundials cimport * # Access to C types
from .c_nvector cimport * # Access to N_Vector functions
Expand Down Expand Up @@ -37,6 +45,35 @@ elif SUNDIALS_INT_TYPE == "long int":
from numpy import int64 as INT_TYPE


cdef _pyerr_handler():
"""Catch and re-raise Python exceptions in Cython code."""
cdef PyObject *errtype, *errvalue, *errtraceback

PyErr_Fetch(&errtype, &errvalue, &errtraceback)
PyErr_Restore(errtype, errvalue, errtraceback)

# Update to the following when support for Python <3.12 is dropped:
# cdef PyObject *exc = PyErr_GetRaisedException()
# PyErr_SetRaisedException(exc)
# raise <object> exc

raise <object> errvalue


cdef void _sunerr_handler(int line, const char* func, const char* file,
const char* msg, int err_code, void* err_user_data,
SUNContext ctx) except *:
"""Custom error handler for shorter messages (no line or file)."""

if PyErr_Occurred():
pass

else:
decoded_func = func.decode("utf-8")
decoded_msg = msg.decode("utf-8").replace(", ,", ",").strip()
print(f"\n[{decoded_func}, Error: {err_code}] {decoded_msg}\n")


cdef svec2np(N_Vector nvec, np.ndarray[DTYPE_t, ndim=1] np_array):
"""Fill a numpy array with values from an N_Vector."""
cdef sunrealtype* nv_ptr
Expand Down
92 changes: 34 additions & 58 deletions src/sksundae/_cy_cvode.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,7 @@ cimport numpy as np

from scipy import sparse as sp
from scipy.optimize._numdiff import group_columns
from cpython.exc cimport (
PyErr_Fetch, PyErr_NormalizeException,
PyObject, PyErr_CheckSignals, PyErr_Occurred, # PyErr_GetRaisedException,
)

# PyErr_Fetch and PyErr_NormalizeException are deprecated at 3.12. When support
# for <3.12 is dropped, replace with PyErr_GetRaisedException.
from cpython.exc cimport PyErr_Occurred

# Extern cdef headers
from .c_cvode cimport *
Expand Down Expand Up @@ -96,7 +90,7 @@ LSMESSAGES = {


cdef int _rhsfn_wrapper(sunrealtype t, N_Vector yy, N_Vector yp,
void* data) except? -1:
void* data) except -1:
"""Wraps 'rhsfn' by converting between N_Vector and ndarray types."""

aux = <AuxData> data
Expand All @@ -114,7 +108,7 @@ cdef int _rhsfn_wrapper(sunrealtype t, N_Vector yy, N_Vector yp,


cdef int _eventsfn_wrapper(sunrealtype t, N_Vector yy, sunrealtype* ee,
void* data) except? -1:
void* data) except -1:
"""Wraps 'eventsfn' by converting between N_Vector and ndarray types."""

aux = <AuxData> data
Expand All @@ -133,7 +127,7 @@ cdef int _eventsfn_wrapper(sunrealtype t, N_Vector yy, sunrealtype* ee,

cdef int _jacfn_wrapper(sunrealtype t, N_Vector yy, N_Vector yp, SUNMatrix JJ,
void* data, N_Vector tmp1, N_Vector tmp2,
N_Vector tmp3) except? -1:
N_Vector tmp3) except -1:
"""Wraps 'jacfn' by converting between N_Vector and ndarray types."""

aux = <AuxData> data
Expand All @@ -153,7 +147,7 @@ cdef int _jacfn_wrapper(sunrealtype t, N_Vector yy, N_Vector yp, SUNMatrix JJ,

cdef int _psetup_wrapper(sunrealtype t, N_Vector yy, N_Vector yp,
sunbooleantype jok, sunbooleantype* jcurPtr,
sunrealtype gamma, void* data) except? -1:
sunrealtype gamma, void* data) except -1:
"""Wraps 'psetup' by converting between N_Vector and ndarray types."""

aux = <AuxData> data
Expand All @@ -176,7 +170,7 @@ cdef int _psetup_wrapper(sunrealtype t, N_Vector yy, N_Vector yp,

cdef int _psolve_wrapper(sunrealtype t, N_Vector yy, N_Vector yp, N_Vector rv,
N_Vector zv, sunrealtype gamma, sunrealtype delta,
int lr, void* data) except? -1:
int lr, void* data) except -1:
"""Wraps 'psolve' by converting between N_Vector and ndarray types."""

aux = <AuxData> data
Expand All @@ -199,7 +193,7 @@ cdef int _psolve_wrapper(sunrealtype t, N_Vector yy, N_Vector yp, N_Vector rv,


cdef int _jvsetup_wrapper(sunrealtype t, N_Vector yy, N_Vector yp,
void* data) except? -1:
void* data) except -1:
"""Wraps 'jvsetup' by converting between N_Vector and ndarray types."""

aux = <AuxData> data
Expand All @@ -217,7 +211,7 @@ cdef int _jvsetup_wrapper(sunrealtype t, N_Vector yy, N_Vector yp,


cdef int _jvsolve_wrapper(N_Vector vv, N_Vector Jv, sunrealtype t, N_Vector yy,
N_Vector yp, void* data, N_Vector tmp) except? -1:
N_Vector yp, void* data, N_Vector tmp) except -1:
"""Wraps 'jvsolve' by converting between N_Vector and ndarray types."""

aux = <AuxData> data
Expand All @@ -237,27 +231,6 @@ cdef int _jvsolve_wrapper(N_Vector vv, N_Vector Jv, sunrealtype t, N_Vector yy,
return 0


cdef void _err_handler(int line, const char* func, const char* file,
const char* msg, int err_code, void* err_user_data,
SUNContext ctx) except *:
"""Custom error handler for shorter messages (no line or file)."""
cdef PyObject *errtype, *errvalue, *errtraceback

if PyErr_Occurred():
aux = <AuxData> err_user_data
# aux.pyerr = <object> PyErr_GetRaisedException()

PyErr_Fetch(&errtype, &errvalue, &errtraceback)
PyErr_NormalizeException(&errtype, &errvalue, &errtraceback)

aux.pyerr = <object> errvalue

else:
decoded_func = func.decode("utf-8")
decoded_msg = msg.decode("utf-8").replace(", ,", ",").strip()
print(f"\n[{decoded_func}, Error: {err_code}] {decoded_msg}\n")


cdef class AuxData:
"""
Auxiliary data.
Expand All @@ -278,7 +251,6 @@ cdef class AuxData:
cdef bint with_userdata
cdef bint is_constrained

cdef object pyerr # Exception
cdef object rhsfn # Callable
cdef object userdata # Any
cdef object eventsfn # Callable
Expand All @@ -289,7 +261,6 @@ cdef class AuxData:
cdef object jactimes # CVODEJacTimes

def __cinit__(self, sunindextype NEQ, object options):
self.pyerr = None
self.np_yy = np.empty(NEQ, DTYPE)
self.np_yp = np.empty(NEQ, DTYPE)

Expand Down Expand Up @@ -364,9 +335,7 @@ cdef class _cvLSSparseDQJac:
"""
cdef void* mem
cdef AuxData aux

cdef object groups # dict[int, np.ndarray[int]]
cdef object sparsity # sparse.csc_matrix, shape(NEQ, NEQ)

def __cinit__(self, AuxData aux):

Expand All @@ -380,7 +349,6 @@ cdef class _cvLSSparseDQJac:

self.aux = aux
self.groups = groups
self.sparsity = aux.sparsity

def __call__(
self,
Expand All @@ -397,7 +365,7 @@ cdef class _cvLSSparseDQJac:
cdef np.ndarray[DTYPE_t, ndim=1] diff, inc, inc_inv, ytemp, yptemp

aux = <AuxData> self.aux
sparsity = self.sparsity
sparsity = aux.sparsity

ytemp = y.copy()
yptemp = yp.copy()
Expand Down Expand Up @@ -454,6 +422,9 @@ cdef class _cvLSSparseDQJac:
"""
self.mem = mem

def __dealloc__(self):
self.mem = NULL


class CVODEResult(RichResult):
_order_keys = ["message", "success", "status", "t", "y", "i_events",
Expand Down Expand Up @@ -756,7 +727,7 @@ cdef class CVODE:

# 16) Set optional inputs
SUNContext_ClearErrHandlers(self.ctx)
SUNContext_PushErrHandler(self.ctx, _err_handler, <void*> self.aux)
SUNContext_PushErrHandler(self.ctx, _sunerr_handler, NULL)

cdef sunrealtype first_step = <sunrealtype> self._options["first_step"]
flag = CVodeSetInitStep(self.mem, first_step)
Expand Down Expand Up @@ -860,6 +831,9 @@ cdef class CVODE:
# 17) Advance solution in time
flag = CVode(self.mem, tt, self.yy, &tout, itask)

if PyErr_Occurred():
_pyerr_handler()

svec2np(self.yy, yy_tmp)

if flag == CV_ROOT_RETURN:
Expand All @@ -884,9 +858,11 @@ cdef class CVODE:

return result

cdef _normal_solve(self, np.ndarray[DTYPE_t, ndim=1] tspan,
np.ndarray[DTYPE_t, ndim=1] y0,
):
cdef _normal_solve(
self,
np.ndarray[DTYPE_t, ndim=1] tspan,
np.ndarray[DTYPE_t, ndim=1] y0,
):

cdef int ind
cdef int flag
Expand Down Expand Up @@ -918,6 +894,9 @@ cdef class CVODE:

flag = CVode(self.mem, tend, self.yy, &tt, CV_NORMAL)

if PyErr_Occurred():
_pyerr_handler()

svec2np(self.yy, yy_tmp)

if flag == CV_ROOT_RETURN:
Expand All @@ -937,11 +916,7 @@ cdef class CVODE:

ind += 1

if self.aux.pyerr is not None:
raise self.aux.pyerr
elif PyErr_CheckSignals() == -1:
return
elif stop:
if stop:
break

if self.aux.eventsfn:
Expand All @@ -963,9 +938,11 @@ cdef class CVODE:

return result

cdef _onestep_solve(self, np.ndarray[DTYPE_t, ndim=1] tspan,
np.ndarray[DTYPE_t, ndim=1] y0,
):
cdef _onestep_solve(
self,
np.ndarray[DTYPE_t, ndim=1] tspan,
np.ndarray[DTYPE_t, ndim=1] y0,
):

cdef int ind
cdef int flag
Expand Down Expand Up @@ -1001,6 +978,9 @@ cdef class CVODE:
while True:
flag = CVode(self.mem, tend, self.yy, &tt, CV_ONE_STEP)

if PyErr_Occurred():
_pyerr_handler()

svec2np(self.yy, yy_tmp)

if flag == CV_ROOT_RETURN:
Expand All @@ -1022,11 +1002,7 @@ cdef class CVODE:

ind += 1

if self.aux.pyerr is not None:
raise self.aux.pyerr
elif PyErr_CheckSignals() == -1:
return
elif stop:
if stop:
break

if self.aux.eventsfn:
Expand Down
Loading
Loading