6 changes: 6 additions & 0 deletions pygmt/clib/session.py
@@ -1746,6 +1746,7 @@ def virtualfile_to_dataset(
self,
vfname: str,
output_type: Literal["pandas", "numpy", "file"] = "pandas",
header: int | None = None,
column_names: list[str] | None = None,
dtype: type | dict[str, type] | None = None,
index_col: str | int | None = None,
@@ -1766,6 +1767,10 @@ def virtualfile_to_dataset(
- ``"pandas"`` will return a :class:`pandas.DataFrame` object.
- ``"numpy"`` will return a :class:`numpy.ndarray` object.
- ``"file"`` means the result was saved to a file and will return ``None``.
header
Row number containing column names for the :class:`pandas.DataFrame` output.
``header=None`` means not to parse column names from the table header.
Ignored if the row number is equal to or larger than the number of header lines in the table.
column_names
The column names for the :class:`pandas.DataFrame` output.
dtype
@@ -1862,6 +1867,7 @@ def virtualfile_to_dataset(

# Read the virtual file as a GMT dataset and convert to pandas.DataFrame
result = self.read_virtualfile(vfname, kind="dataset").contents.to_dataframe(
header=header,
column_names=column_names,
dtype=dtype,
index_col=index_col,
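For orientation beyond the diff itself, here is a minimal usage sketch of the new ``header`` argument going through the ``Session`` API, mirroring the helper used in the tests further down; the input file name and its column labels are illustrative assumptions, not part of this changeset.

    # Minimal sketch (assumed file "table.txt" whose first header line is
    # "# lon lat z name"): read the table into a virtual file, then convert it
    # to a pandas.DataFrame, parsing column names from header line 0.
    from pygmt.clib import Session

    with Session() as lib:
        with lib.virtualfile_out(kind="dataset") as vouttbl:
            lib.call_module("read", f"table.txt {vouttbl} -Td")
            df = lib.virtualfile_to_dataset(vfname=vouttbl, header=0)
    # If the header line exists, df.columns becomes ["lon", "lat", "z", "name"];
    # otherwise the default integer column labels are kept.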
26 changes: 20 additions & 6 deletions pygmt/datatypes/dataset.py
@@ -26,6 +26,7 @@ class _GMT_DATASET(ctp.Structure): # noqa: N801
>>> with GMTTempFile(suffix=".txt") as tmpfile:
... # Prepare the sample data file
... with Path(tmpfile.name).open(mode="w") as fp:
... print("# x y z name", file=fp)
... print(">", file=fp)
... print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp)
... print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp)
@@ -42,7 +43,8 @@ class _GMT_DATASET(ctp.Structure): # noqa: N801
... print(ds.min[: ds.n_columns], ds.max[: ds.n_columns])
... # The table
... tbl = ds.table[0].contents
... print(tbl.n_columns, tbl.n_segments, tbl.n_records)
... print(tbl.n_columns, tbl.n_segments, tbl.n_records, tbl.n_headers)
... print(tbl.header[: tbl.n_headers])
... print(tbl.min[: tbl.n_columns], ds.max[: tbl.n_columns])
... for i in range(tbl.n_segments):
... seg = tbl.segment[i].contents
@@ -51,7 +53,8 @@ class _GMT_DATASET(ctp.Structure): # noqa: N801
... print(seg.text[: seg.n_rows])
1 3 2
[1.0, 2.0, 3.0] [10.0, 11.0, 12.0]
3 2 4
3 2 4 1
[b'x y z name']
[1.0, 2.0, 3.0] [10.0, 11.0, 12.0]
[1.0, 4.0]
[2.0, 5.0]
@@ -144,8 +147,9 @@ class _GMT_DATASEGMENT(ctp.Structure): # noqa: N801
("hidden", ctp.c_void_p),
]

def to_dataframe(
def to_dataframe( # noqa: PLR0912
self,
header: int | None = None,
column_names: pd.Index | None = None,
dtype: type | Mapping[Any, type] | None = None,
index_col: str | int | None = None,
@@ -164,6 +168,10 @@ def to_dataframe(
----------
column_names
A list of column names.
header
Row number containing column names. ``header=None`` means not to parse column
names from the table header. Ignored if the row number is equal to or larger
than the number of header lines in the table.
dtype
Data type. Can be a single type for all columns or a dictionary mapping
column names to types.
@@ -184,6 +192,7 @@ def to_dataframe(
>>> with GMTTempFile(suffix=".txt") as tmpfile:
... # prepare the sample data file
... with Path(tmpfile.name).open(mode="w") as fp:
... print("# col1 col2 col3 colstr", file=fp)
... print(">", file=fp)
... print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp)
... print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp)
@@ -194,9 +203,9 @@ def to_dataframe(
... with lib.virtualfile_out(kind="dataset") as vouttbl:
... lib.call_module("read", f"{tmpfile.name} {vouttbl} -Td")
... ds = lib.read_virtualfile(vouttbl, kind="dataset")
... df = ds.contents.to_dataframe()
... df = ds.contents.to_dataframe(header=0)
>>> df
0 1 2 3
col1 col2 col3 colstr
0 1.0 2.0 3.0 TEXT1 TEXT23
1 4.0 5.0 6.0 TEXT4 TEXT567
2 7.0 8.0 9.0 TEXT8 TEXT90
@@ -230,14 +239,19 @@ def to_dataframe(
pd.Series(data=np.char.decode(textvector), dtype=pd.StringDtype())
)

if header is not None:
tbl = self.table[0].contents # Use the first table!
if header < tbl.n_headers:
column_names = tbl.header[header].decode().split()

if len(vectors) == 0:
# Return an empty DataFrame if no columns are found.
df = pd.DataFrame(columns=column_names)
else:
# Create a DataFrame object by concatenating multiple columns
df = pd.concat(objs=vectors, axis="columns")
if column_names is not None: # Assign column names
df.columns = column_names
df.columns = column_names[: df.shape[1]]
if dtype is not None: # Set dtype for the whole dataset or individual columns
df = df.astype(dtype)
if index_col is not None: # Use a specific column as index
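As a side note on the truncation logic above (``column_names[: df.shape[1]]``), a self-contained sketch with assumed inputs and no GMT structures shows how a raw table header line ends up as DataFrame column labels, with surplus names dropped:

    import pandas as pd

    # Assumed raw GMT table header with one more name than there are columns.
    header_line = b"lon lat z name extra"
    column_names = header_line.decode().split()  # ["lon", "lat", "z", "name", "extra"]

    df = pd.DataFrame([[1.0, 2.0, 3.0, "TEXT1"], [4.0, 5.0, 6.0, "TEXT4"]])
    df.columns = column_names[: df.shape[1]]  # keep only as many names as columns
    print(df.columns.tolist())  # ['lon', 'lat', 'z', 'name']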
61 changes: 59 additions & 2 deletions pygmt/tests/test_datatypes_dataset.py
@@ -39,14 +39,14 @@ def dataframe_from_pandas(filepath_or_buffer, sep=r"\s+", comment="#", header=No
return df


def dataframe_from_gmt(fname):
def dataframe_from_gmt(fname, **kwargs):
"""
Read tabular data as pandas.DataFrame using GMT virtual file.
"""
with Session() as lib:
with lib.virtualfile_out(kind="dataset") as vouttbl:
lib.call_module("read", f"{fname} {vouttbl} -Td")
df = lib.virtualfile_to_dataset(vfname=vouttbl)
df = lib.virtualfile_to_dataset(vfname=vouttbl, **kwargs)
return df


@@ -81,3 +81,60 @@ def test_dataset_empty():
assert df.empty # Empty DataFrame
expected_df = dataframe_from_pandas(tmpfile.name)
pd.testing.assert_frame_equal(df, expected_df)


def test_dataset_header():
"""
Test parsing column names from dataset header.
"""
with GMTTempFile(suffix=".txt") as tmpfile:
with Path(tmpfile.name).open(mode="w") as fp:
print("# lon lat z text", file=fp)
print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp)
print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp)

# Parse column names from the first header line.
df = dataframe_from_gmt(tmpfile.name, header=0)
assert df.columns.tolist() == ["lon", "lat", "z", "text"]
# pd.read_csv() can't parse the header line with a leading '#'.
# So, we need to skip the header line and manually set the column names.
expected_df = dataframe_from_pandas(tmpfile.name, header=None)
expected_df.columns = df.columns.tolist()
pd.testing.assert_frame_equal(df, expected_df)


def test_dataset_header_greater_than_nheaders():
"""
Test passing a header line number that is equal to or greater than the number of header lines.
"""
with GMTTempFile(suffix=".txt") as tmpfile:
with Path(tmpfile.name).open(mode="w") as fp:
print("# lon lat z text", file=fp)
print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp)
print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp)

# Try to parse column names from the second header line.
df = dataframe_from_gmt(tmpfile.name, header=1)
# There is only one header line, so the default column names are kept.
assert df.columns.tolist() == [0, 1, 2, 3]
expected_df = dataframe_from_pandas(tmpfile.name, header=None)
pd.testing.assert_frame_equal(df, expected_df)


def test_dataset_header_too_many_names():
"""
Test passing a header line with more column names than the number of columns.
"""
with GMTTempFile(suffix=".txt") as tmpfile:
with Path(tmpfile.name).open(mode="w") as fp:
print("# lon lat z text1 text2", file=fp)
print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp)
print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp)

df = dataframe_from_gmt(tmpfile.name, header=0)
assert df.columns.tolist() == ["lon", "lat", "z", "text1"]
# pd.read_csv() can't parse the header line with a leading '#'.
# So, we need to skip the header line and manually set the column names.
expected_df = dataframe_from_pandas(tmpfile.name, header=None)
expected_df.columns = df.columns.tolist()
pd.testing.assert_frame_equal(df, expected_df)
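
The comments in the tests above note that ``pd.read_csv()`` treats a leading ``#`` as a comment rather than a header; a small sketch with assumed inline data shows the workaround of skipping that line and assigning the names afterwards:

    from io import StringIO

    import pandas as pd

    # Assumed inline table whose header line starts with "#".
    text = "# lon lat z\n1.0 2.0 3.0\n4.0 5.0 6.0\n"
    expected_df = pd.read_csv(StringIO(text), sep=r"\s+", comment="#", header=None)
    expected_df.columns = ["lon", "lat", "z"]  # names as parsed from the GMT header
    print(expected_df)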