diff --git a/pygmt/datatypes/dataset.py b/pygmt/datatypes/dataset.py index 274d2fee97c..7a61b7f3d91 100644 --- a/pygmt/datatypes/dataset.py +++ b/pygmt/datatypes/dataset.py @@ -3,7 +3,8 @@ """ import ctypes as ctp -from typing import ClassVar +from collections.abc import Mapping +from typing import Any, ClassVar import numpy as np import pandas as pd @@ -13,8 +14,8 @@ class _GMT_DATASET(ctp.Structure): # noqa: N801 """ GMT dataset structure for holding multiple tables (files). - This class is only meant for internal use by PyGMT and is not exposed to users. - See the GMT source code gmt_resources.h for the original C struct definitions. + This class is only meant for internal use and is not exposed to users. See the GMT + source code ``gmt_resources.h`` for the original C struct definitions. Examples -------- @@ -145,8 +146,8 @@ class _GMT_DATASEGMENT(ctp.Structure): # noqa: N801 def to_dataframe( self, - column_names: list[str] | None = None, - dtype: type | dict[str, type] | None = None, + column_names: pd.Index | None = None, + dtype: type | Mapping[Any, type] | None = None, index_col: str | int | None = None, ) -> pd.DataFrame: """ @@ -156,6 +157,9 @@ def to_dataframe( the same. The same column in all segments of all tables are concatenated. The trailing text column is also concatenated as a single string column. + If the object contains no data, an empty DataFrame will be returned (with the + column names and dtypes set if provided). 
+ Parameters ---------- column_names @@ -200,8 +204,8 @@ def to_dataframe( >>> df.dtypes.to_list() [dtype('float64'), dtype('float64'), dtype('float64'), string[python]] """ - # Deal with numeric columns vectors = [] + # Deal with numeric columns for icol in range(self.n_columns): colvector = [] for itbl in range(self.n_tables): @@ -226,11 +230,16 @@ def to_dataframe( pd.Series(data=np.char.decode(textvector), dtype=pd.StringDtype()) ) - df = pd.concat(objs=vectors, axis="columns") - if column_names is not None: # Assign column names - df.columns = column_names - if dtype is not None: + if len(vectors) == 0: + # Return an empty DataFrame if no columns are found. + df = pd.DataFrame(columns=column_names) + else: + # Create a DataFrame object by concatenating multiple columns + df = pd.concat(objs=vectors, axis="columns") + if column_names is not None: # Assign column names + df.columns = column_names + if dtype is not None: # Set dtype for the whole dataset or individual columns df = df.astype(dtype) - if index_col is not None: + if index_col is not None: # Use a specific column as index df = df.set_index(index_col) return df diff --git a/pygmt/tests/test_datatypes_dataset.py b/pygmt/tests/test_datatypes_dataset.py new file mode 100644 index 00000000000..7861b6b3119 --- /dev/null +++ b/pygmt/tests/test_datatypes_dataset.py @@ -0,0 +1,83 @@ +""" +Tests for GMT_DATASET data type. +""" + +from pathlib import Path + +import pandas as pd +import pytest +from pygmt.clib import Session +from pygmt.helpers import GMTTempFile + + +def dataframe_from_pandas(filepath_or_buffer, sep=r"\s+", comment="#", header=None): + """ + Read tabular data as pandas.DataFrame object using pandas.read_csv(). + + The parameters have the same meaning as in ``pandas.read_csv()``. 
+ """ + try: + df = pd.read_csv(filepath_or_buffer, sep=sep, comment=comment, header=header) + except pd.errors.EmptyDataError: + # Return an empty DataFrame if the file contains no data + return pd.DataFrame() + + # By default, pandas reads text strings with whitespaces as multiple columns, but + # GMT concatenates all trailing text as a single string column. Need to find all + # string columns (with dtype="object") and combine them into a single string column. + string_columns = df.select_dtypes(include=["object"]).columns + if len(string_columns) > 1: + df[string_columns[0]] = df[string_columns].apply(lambda x: " ".join(x), axis=1) + df = df.drop(string_columns[1:], axis=1) + # Convert 'object' to 'string' type + df = df.convert_dtypes( + convert_string=True, + convert_integer=False, + convert_boolean=False, + convert_floating=False, + ) + return df + + +def dataframe_from_gmt(fname): + """ + Read tabular data as pandas.DataFrame using GMT virtual file. + """ + with Session() as lib: + with lib.virtualfile_out(kind="dataset") as vouttbl: + lib.call_module("read", f"{fname} {vouttbl} -Td") + df = lib.virtualfile_to_dataset(vfname=vouttbl) + return df + + +@pytest.mark.benchmark +def test_dataset(): + """ + Test the basic functionality of GMT_DATASET. + """ + with GMTTempFile(suffix=".txt") as tmpfile: + with Path(tmpfile.name).open(mode="w") as fp: + print(">", file=fp) + print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp) + print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp) + print(">", file=fp) + print("7.0 8.0 9.0 TEXT8 TEXT90", file=fp) + print("10.0 11.0 12.0 TEXT123 TEXT456789", file=fp) + + df = dataframe_from_gmt(tmpfile.name) + expected_df = dataframe_from_pandas(tmpfile.name, comment=">") + pd.testing.assert_frame_equal(df, expected_df) + + +def test_dataset_empty(): + """ + Make sure that an empty DataFrame is returned if a file contains no data. 
+ """ + with GMTTempFile(suffix=".txt") as tmpfile: + with Path(tmpfile.name).open(mode="w") as fp: + print("# This is a comment line.", file=fp) + + df = dataframe_from_gmt(tmpfile.name) + assert df.empty # Empty DataFrame + expected_df = dataframe_from_pandas(tmpfile.name) + pd.testing.assert_frame_equal(df, expected_df)