Skip to content

Commit ee7da35

Browse files
committed
Add test of updated load_timeseries_csv
1 parent 740069a commit ee7da35

File tree

2 files changed

+41
-8
lines changed

2 files changed

+41
-8
lines changed

src/pandas_openscm/io.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def load_timeseries_csv(
1313
fp: Path,
1414
lower_column_names: bool = True,
1515
index_columns: list[str] | None = None,
16-
out_column_type: type | None = None,
16+
out_columns_type: type | None = None,
1717
out_columns_name: str | None = None,
1818
) -> pd.DataFrame:
1919
"""
@@ -43,7 +43,7 @@ def load_timeseries_csv(
4343
In future, if not provided, we will try and infer the columns
4444
based on whether they look like time columns or not.
4545
46-
out_column_type
46+
out_columns_type
4747
The type to apply to the output columns that are not part of the index.
4848
4949
If not supplied, the raw type returned by pandas is returned.
@@ -73,8 +73,8 @@ def load_timeseries_csv(
7373

7474
out = out.set_index(index_columns)
7575

76-
if out_column_type is not None:
77-
out.columns = out.columns.astype(out_column_type)
76+
if out_columns_type is not None:
77+
out.columns = out.columns.astype(out_columns_type)
7878

7979
if out_columns_name is not None:
8080
out = out.rename_axis(out_columns_name, axis="columns")

tests/integration/test_io.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,16 +85,16 @@ def test_load_timeseries_csv_lower_column_names(tmp_path, lower_column_names):
8585
@pytest.mark.parametrize(
8686
# Column type and value type are not the same
8787
# because columns are held as numpy arrays.
88-
"out_column_type, exp_column_value_type",
88+
"out_columns_type, exp_column_value_type",
8989
(
9090
(int, np.int64),
9191
(float, np.float64),
9292
(np.float64, np.float64),
9393
(np.float32, np.float32),
9494
),
9595
)
96-
def test_load_timeseries_csv_basic_out_column_type(
97-
tmp_path, out_column_type, exp_column_value_type
96+
def test_load_timeseries_csv_basic_out_columns_type(
97+
tmp_path, out_columns_type, exp_column_value_type
9898
):
9999
out_path = tmp_path / "test_load_timeseries_csv.csv"
100100

@@ -113,13 +113,46 @@ def test_load_timeseries_csv_basic_out_column_type(
113113
index_columns = ["variable", "scenario", "run", "unit"]
114114

115115
loaded = load_timeseries_csv(
116-
out_path, index_columns=index_columns, out_column_type=out_column_type
116+
out_path, index_columns=index_columns, out_columns_type=out_columns_type
117117
)
118118

119119
assert loaded.index.names == index_columns
120120
assert all(isinstance(c, exp_column_value_type) for c in loaded.columns.values)
121121

122122

123+
@pytest.mark.parametrize(
124+
"out_columns_name, exp_columns_name",
125+
(
126+
(None, None),
127+
("hi", "hi"),
128+
("time", "time"),
129+
),
130+
)
131+
def test_load_timeseries_csv_basic_out_columns_name(
132+
tmp_path, out_columns_name, exp_columns_name
133+
):
134+
out_path = tmp_path / "test_load_timeseries_csv.csv"
135+
136+
timepoints = np.arange(1990.0, 2010.0 + 1.0, dtype=int)
137+
start = create_test_df(
138+
variables=[(f"variable_{i}", "Mt") for i in range(5)],
139+
n_scenarios=3,
140+
n_runs=6,
141+
timepoints=timepoints,
142+
)
143+
assert start.columns.name is None
144+
145+
start.to_csv(out_path)
146+
147+
index_columns = ["variable", "scenario", "run", "unit"]
148+
149+
loaded = load_timeseries_csv(
150+
out_path, index_columns=index_columns, out_columns_name=out_columns_name
151+
)
152+
153+
assert loaded.columns.name == exp_columns_name
154+
155+
123156
@pytest.mark.xfail(reason="Not implemented")
124157
def test_load_timeseries_csv_infer_index_cols(tmp_path):
125158
# Suggested cases here:

0 commit comments

Comments
 (0)