Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 33 additions & 32 deletions pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@

REGISTRATIONS = ["GMT_GRID_NODE_REG", "GMT_GRID_PIXEL_REG"]

DTYPES = {
# Dictionary for mapping numpy dtypes to GMT data types.
DTYPES_NUMERIC = {
np.int8: "GMT_CHAR",
np.int16: "GMT_SHORT",
np.int32: "GMT_INT",
Expand All @@ -93,10 +94,14 @@
np.uint64: "GMT_ULONG",
np.float32: "GMT_FLOAT",
np.float64: "GMT_DOUBLE",
np.timedelta64: "GMT_LONG",
}
DTYPES_TEXT = {
np.str_: "GMT_TEXT",
np.datetime64: "GMT_DATETIME",
np.timedelta64: "GMT_LONG",
}
DTYPES = DTYPES_NUMERIC | DTYPES_TEXT

# Dictionary for storing the values of GMT constants.
GMT_CONSTANTS = {}

Expand Down Expand Up @@ -879,63 +884,59 @@ def _parse_constant(
integer_value = sum(self[part] for part in parts)
return integer_value

def _check_dtype_and_dim(self, array, ndim):
def _check_dtype_and_dim(self, array: np.ndarray, ndim: int) -> int:
"""
Check that a numpy array has the given number of dimensions and is a valid data
type.

Parameters
----------
array : numpy.ndarray
array
The array to be tested.
ndim : int
ndim
The desired number of array dimensions.

Returns
-------
gmt_type : int
gmt_type
The GMT constant value representing this data type.

Raises
------
GMTInvalidInput
If the array has the wrong number of dimensions or
is an unsupported data type.
If the array has the wrong number of dimensions or is an unsupported data
type.

Examples
--------

>>> import numpy as np
>>> data = np.array([1, 2, 3], dtype="float64")
>>> with Session() as ses:
... gmttype = ses._check_dtype_and_dim(data, ndim=1)
... gmttype == ses["GMT_DOUBLE"]
>>> with Session() as lib:
... gmttype = lib._check_dtype_and_dim(data, ndim=1)
... gmttype == lib["GMT_DOUBLE"]
True
>>> data = np.ones((5, 2), dtype="float32")
>>> with Session() as ses:
... gmttype = ses._check_dtype_and_dim(data, ndim=2)
... gmttype == ses["GMT_FLOAT"]
>>> with Session() as lib:
... gmttype = lib._check_dtype_and_dim(data, ndim=2)
... gmttype == lib["GMT_FLOAT"]
True
"""
# Check that the array has the given number of dimensions
# Check that the array has the given number of dimensions.
if array.ndim != ndim:
raise GMTInvalidInput(
f"Expected a numpy {ndim}-D array, got {array.ndim}-D."
)
msg = f"Expected a numpy {ndim}-D array, got {array.ndim}-D."
raise GMTInvalidInput(msg)

# Check that the array has a valid/known data type
if array.dtype.type not in DTYPES:
try:
if array.dtype.type is np.object_:
# Try to convert unknown object type to np.datetime64
array = array_to_datetime(array)
else:
raise ValueError
except ValueError as e:
raise GMTInvalidInput(
f"Unsupported numpy data type '{array.dtype.type}'."
) from e
return self[DTYPES[array.dtype.type]]
# For 1-D arrays, try to convert unknown object type to np.datetime64.
if ndim == 1 and array.dtype.type is np.object_:
with contextlib.suppress(ValueError):
array = array_to_datetime(array)

# 1-D arrays can be numeric or text, 2-D arrays can only be numeric.
valid_dtypes = DTYPES if ndim == 1 else DTYPES_NUMERIC
if (dtype := array.dtype.type) not in valid_dtypes:
msg = f"Unsupported numpy data type '{dtype}'."
raise GMTInvalidInput(msg)
return self[DTYPES[dtype]]

def put_vector(self, dataset: ctp.c_void_p, column: int, vector: np.ndarray):
r"""
Expand Down