diff --git a/pygmt/clib/session.py b/pygmt/clib/session.py
index 0640ba6ae20..42d74a815a2 100644
--- a/pygmt/clib/session.py
+++ b/pygmt/clib/session.py
@@ -1747,6 +1747,8 @@ def virtualfile_to_dataset(
         vfname: str,
         output_type: Literal["pandas", "numpy", "file"] = "pandas",
         column_names: list[str] | None = None,
+        dtype: type | dict[str, type] | None = None,
+        index_col: str | int | None = None,
     ) -> pd.DataFrame | np.ndarray | None:
         """
         Output a tabular dataset stored in a virtual file to a different format.
@@ -1766,6 +1768,11 @@ def virtualfile_to_dataset(
             - ``"file"`` means the result was saved to a file and will return ``None``.
         column_names
             The column names for the :class:`pandas.DataFrame` output.
+        dtype
+            Data type for the columns of the :class:`pandas.DataFrame` output. Can be a
+            single type for all columns or a dictionary mapping column names to types.
+        index_col
+            Column to set as the index of the :class:`pandas.DataFrame` output.
 
         Returns
         -------
@@ -1854,13 +1861,13 @@ def virtualfile_to_dataset(
             return None
 
         # Read the virtual file as a GMT dataset and convert to pandas.DataFrame
-        result = self.read_virtualfile(vfname, kind="dataset").contents.to_dataframe()
+        result = self.read_virtualfile(vfname, kind="dataset").contents.to_dataframe(
+            column_names=column_names,
+            dtype=dtype,
+            index_col=index_col,
+        )
         if output_type == "numpy":  # numpy.ndarray output
             return result.to_numpy()
-
-        # Assign column names
-        if column_names is not None:
-            result.columns = column_names
         return result  # pandas.DataFrame output
 
     def extract_region(self):
diff --git a/pygmt/datatypes/dataset.py b/pygmt/datatypes/dataset.py
index a0d0547f3ca..274d2fee97c 100644
--- a/pygmt/datatypes/dataset.py
+++ b/pygmt/datatypes/dataset.py
@@ -143,7 +143,12 @@ class _GMT_DATASEGMENT(ctp.Structure):  # noqa: N801
         ("hidden", ctp.c_void_p),
     ]
 
-    def to_dataframe(self) -> pd.DataFrame:
+    def to_dataframe(
+        self,
+        column_names: list[str] | None = None,
+        dtype: type | dict[str, type] | None = None,
+        index_col: str | int | None = None,
+    ) -> pd.DataFrame:
         """
         Convert a _GMT_DATASET object to a :class:`pandas.DataFrame` object.
 
@@ -151,6 +156,16 @@ def to_dataframe(self) -> pd.DataFrame:
         the same. The same column in all segments of all tables are concatenated. The
         trailing text column is also concatenated as a single string column.
 
+        Parameters
+        ----------
+        column_names
+            A list of column names.
+        dtype
+            Data type. Can be a single type for all columns or a dictionary mapping
+            column names to types.
+        index_col
+            Column to set as index.
+
         Returns
         -------
         df
@@ -211,5 +226,11 @@ def to_dataframe(self) -> pd.DataFrame:
                 pd.Series(data=np.char.decode(textvector), dtype=pd.StringDtype())
             )
 
-        df = pd.concat(objs=vectors, axis=1)
+        df = pd.concat(objs=vectors, axis="columns")
+        if column_names is not None:  # Assign column names
+            df.columns = column_names
+        if dtype is not None:
+            df = df.astype(dtype)
+        if index_col is not None:
+            df = df.set_index(index_col)
         return df
diff --git a/pygmt/src/grdhisteq.py b/pygmt/src/grdhisteq.py
index b0285e4e3d5..44d191a417e 100644
--- a/pygmt/src/grdhisteq.py
+++ b/pygmt/src/grdhisteq.py
@@ -238,18 +238,14 @@ def compute_bins(
                     module="grdhisteq", args=build_arg_string(kwargs, infile=vingrd)
                 )
 
-                result = lib.virtualfile_to_dataset(
+                return lib.virtualfile_to_dataset(
                     vfname=vouttbl,
                     output_type=output_type,
                     column_names=["start", "stop", "bin_id"],
+                    dtype={
+                        "start": np.float32,
+                        "stop": np.float32,
+                        "bin_id": np.uint32,
+                    },
+                    index_col="bin_id" if output_type == "pandas" else None,
                 )
-                if output_type == "pandas":
-                    result = result.astype(
-                        {
-                            "start": np.float32,
-                            "stop": np.float32,
-                            "bin_id": np.uint32,
-                        }
-                    )
-                    return result.set_index("bin_id")
-                return result
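
Note for reviewers: the snippet below is a minimal, self-contained sketch (pandas and numpy only, no GMT bindings required) of the post-processing that to_dataframe now performs when column_names, dtype, and index_col are passed through virtualfile_to_dataset. The sample vectors and their values are invented purely for illustration; only the column names and dtypes mirror the compute_bins call in grdhisteq.py.

import numpy as np
import pandas as pd

# Stand-ins for the per-column vectors collected from the virtual file;
# the numbers are made up for illustration.
vectors = [
    pd.Series(data=np.array([0.0, 1.0, 2.0])),  # start
    pd.Series(data=np.array([1.0, 2.0, 3.0])),  # stop
    pd.Series(data=np.array([0.0, 1.0, 2.0])),  # bin_id
]
df = pd.concat(objs=vectors, axis="columns")

# The same keyword values that grdhisteq.compute_bins now forwards.
column_names = ["start", "stop", "bin_id"]
dtype = {"start": np.float32, "stop": np.float32, "bin_id": np.uint32}
index_col = "bin_id"

if column_names is not None:  # Assign column names
    df.columns = column_names
if dtype is not None:  # Cast columns to the requested types
    df = df.astype(dtype)
if index_col is not None:  # Promote one column to the DataFrame index
    df = df.set_index(index_col)

print(df.dtypes)  # start/stop are float32; bin_id (uint32) is now the index
print(df)

With this in place, Session.virtualfile_to_dataset simply forwards the three keywords to to_dataframe, so callers such as grdhisteq.compute_bins no longer need their own astype/set_index post-processing.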