Skip to content
5 changes: 4 additions & 1 deletion pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@
np.datetime64: "GMT_DATETIME",
}

# Load the GMT library outside the Session class to avoid repeated loading.
_libgmt = load_libgmt()


class Session:
"""
Expand Down Expand Up @@ -308,7 +311,7 @@ def get_libgmt_func(self, name, argtypes=None, restype=None):
<class 'ctypes.CDLL.__init__.<locals>._FuncPtr'>
"""
if not hasattr(self, "_libgmt"):
self._libgmt = load_libgmt()
self._libgmt = _libgmt
function = getattr(self._libgmt, name)
if argtypes is not None:
function.argtypes = argtypes
Expand Down
39 changes: 39 additions & 0 deletions pygmt/tests/test_clib_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import pytest
from pygmt.clib.loading import check_libgmt, clib_full_names, clib_names, load_libgmt
from pygmt.clib.session import Session
from pygmt.exceptions import GMTCLibError, GMTCLibNotFoundError, GMTOSError


Expand Down Expand Up @@ -207,6 +208,44 @@ def test_brokenlib_brokenlib_workinglib(self):
assert check_libgmt(load_libgmt(lib_fullnames=lib_fullnames)) is None


class TestLibgmtCount:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe better to put this unit test in test_session_management.py?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Prefer to keep it in test_clib_loading.py because this unit test is actually not related to session management.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I kinda wanted to move the test since test_clib_loading.py has 300+ lines of code, while test_session_management.py has <100 lines, and we're kinda checking that Session() doesn't reload libgmt here, but up to you 🙂

"""
Test that the GMT library is not repeatedly loaded in every session.
"""

loaded_libgmt = load_libgmt() # Load the GMT library and reuse it when necessary
counter = 0 # Global counter for how many times ctypes.CDLL is called

def _mock_ctypes_cdll_return(self, libname): # noqa: ARG002
"""
Mock ctypes.CDLL to count how many times the function is called.

If ctypes.CDLL is called, the counter increases by one.
"""
self.counter += 1 # Increase the counter
return self.loaded_libgmt

def test_libgmt_load_counter(self, monkeypatch):
"""
Make sure that the GMT library is not loaded in every session.
"""
# Monkeypatch the ctypes.CDLL function
monkeypatch.setattr(ctypes, "CDLL", self._mock_ctypes_cdll_return)

# Create two sessions and check the global counter
with Session() as lib:
_ = lib
with Session() as lib:
_ = lib
assert self.counter == 0 # ctypes.CDLL is not called after two sessions.

# Explicitly calling load_libgmt to make sure the mock function is correct
load_libgmt()
assert self.counter == 1
load_libgmt()
assert self.counter == 2


###############################################################################
# Test clib_full_names
@pytest.fixture(scope="module", name="gmt_lib_names")
Expand Down