1414 update_index_levels_from_other_func ,
1515 update_levels_from_other ,
1616)
17+ from pandas_openscm .testing import check_result , convert_to_desired_type
18+
19+ pobj_type = pytest .mark .parametrize (
20+ "pobj_type" ,
21+ ("DataFrame" , "Series" ),
22+ )
23+ """
24+ Parameterisation to use to check handling of both DataFrame and Series
25+ """
1726
1827
1928@pytest .mark .parametrize (
@@ -273,7 +282,7 @@ def test_update_levels_from_other_missing_levels(sources, exp):
273282 update_levels_from_other (start , update_sources = update_sources )
274283
275284
276- def test_doesnt_trip_over_droped_levels (setup_pandas_accessors ):
285+ def test_doesnt_trip_over_dropped_levels (setup_pandas_accessors ):
277286 def update_func (in_v : int ) -> int :
278287 if in_v < 0 :
279288 msg = f"Value must be greater than zero, received { in_v } "
@@ -347,8 +356,36 @@ def update_func(in_v: int) -> int:
347356 update_sources , remove_unused_levels = False
348357 )
349358
359+ # Same thing but from a Series
360+ start_series = start_df [2020 ]
361+
362+ res_series = update_index_levels_from_other_func (
363+ start_series .iloc [:- 1 ], update_sources = update_sources
364+ )
365+
366+ exp_series = pd .Series (np .zeros (exp .shape [0 ]), name = 2020 , index = exp )
367+
368+ pd .testing .assert_series_equal (res_series , exp_series )
369+ with exp_error_no_removal :
370+ update_index_levels_from_other_func (
371+ start_series .iloc [:- 1 ],
372+ update_sources = update_sources ,
373+ remove_unused_levels = False ,
374+ )
375+
376+ # Lastly, test the accessor
377+ pd .testing .assert_series_equal (
378+ start_series .iloc [:- 1 ].openscm .update_index_levels_from_other (update_sources ),
379+ exp_series ,
380+ )
381+ with exp_error_no_removal :
382+ start_series .iloc [:- 1 ].openscm .update_index_levels_from_other (
383+ update_sources , remove_unused_levels = False
384+ )
385+
350386
351- def test_accessor (setup_pandas_accessors ):
387+ @pobj_type
388+ def test_accessor (setup_pandas_accessors , pobj_type ):
352389 start = pd .DataFrame (
353390 np .arange (2 * 4 ).reshape ((4 , 2 )),
354391 columns = [2010 , 2020 ],
@@ -362,6 +399,7 @@ def test_accessor(setup_pandas_accessors):
362399 names = ["scenario" , "variable" , "unit" , "run_id" ],
363400 ),
364401 )
402+ convert_to_desired_type (start , pobj_type )
365403
366404 update_sources = {
367405 # callables single source
@@ -408,17 +446,20 @@ def test_accessor(setup_pandas_accessors):
408446 ],
409447 ),
410448 )
449+ exp = convert_to_desired_type (exp , pobj_type )
411450
412451 res = start .openscm .update_index_levels_from_other (update_sources )
413- pd . testing . assert_frame_equal (res , exp )
452+ check_result (res , exp )
414453
415454 # Test function too
416455 res = update_index_levels_from_other_func (start , update_sources )
417- pd . testing . assert_frame_equal (res , exp )
456+ check_result (res , exp )
418457
419458
420- def test_accessor_not_multiindex (setup_pandas_accessors ):
459+ @pobj_type
460+ def test_accessor_not_multiindex (setup_pandas_accessors , pobj_type ):
421461 start = pd .DataFrame (np .arange (2 * 4 ).reshape ((4 , 2 )))
462+ start = convert_to_desired_type (start , pobj_type )
422463
423464 error_msg = re .escape (
424465 "This function is only intended to be used "
0 commit comments