@@ -85,16 +85,16 @@ def test_load_timeseries_csv_lower_column_names(tmp_path, lower_column_names):
8585@pytest .mark .parametrize (
8686 # Column type and value type are not the same
8787 # because columns are held as numpy arrays.
88- "out_column_type , exp_column_value_type" ,
88+ "out_columns_type , exp_column_value_type" ,
8989 (
9090 (int , np .int64 ),
9191 (float , np .float64 ),
9292 (np .float64 , np .float64 ),
9393 (np .float32 , np .float32 ),
9494 ),
9595)
96- def test_load_timeseries_csv_basic_out_column_type (
97- tmp_path , out_column_type , exp_column_value_type
96+ def test_load_timeseries_csv_basic_out_columns_type (
97+ tmp_path , out_columns_type , exp_column_value_type
9898):
9999 out_path = tmp_path / "test_load_timeseries_csv.csv"
100100
@@ -113,13 +113,46 @@ def test_load_timeseries_csv_basic_out_column_type(
113113 index_columns = ["variable" , "scenario" , "run" , "unit" ]
114114
115115 loaded = load_timeseries_csv (
116- out_path , index_columns = index_columns , out_column_type = out_column_type
116+ out_path , index_columns = index_columns , out_columns_type = out_columns_type
117117 )
118118
119119 assert loaded .index .names == index_columns
120120 assert all (isinstance (c , exp_column_value_type ) for c in loaded .columns .values )
121121
122122
123+ @pytest .mark .parametrize (
124+ "out_columns_name, exp_columns_name" ,
125+ (
126+ (None , None ),
127+ ("hi" , "hi" ),
128+ ("time" , "time" ),
129+ ),
130+ )
131+ def test_load_timeseries_csv_basic_out_columns_name (
132+ tmp_path , out_columns_name , exp_columns_name
133+ ):
134+ out_path = tmp_path / "test_load_timeseries_csv.csv"
135+
136+ timepoints = np .arange (1990.0 , 2010.0 + 1.0 , dtype = int )
137+ start = create_test_df (
138+ variables = [(f"variable_{ i } " , "Mt" ) for i in range (5 )],
139+ n_scenarios = 3 ,
140+ n_runs = 6 ,
141+ timepoints = timepoints ,
142+ )
143+ assert start .columns .name is None
144+
145+ start .to_csv (out_path )
146+
147+ index_columns = ["variable" , "scenario" , "run" , "unit" ]
148+
149+ loaded = load_timeseries_csv (
150+ out_path , index_columns = index_columns , out_columns_name = out_columns_name
151+ )
152+
153+ assert loaded .columns .name == exp_columns_name
154+
155+
123156@pytest .mark .xfail (reason = "Not implemented" )
124157def test_load_timeseries_csv_infer_index_cols (tmp_path ):
125158 # Suggested cases here:
0 commit comments