Skip to content

Commit 47ffad6

Browse files
committed
Add failing tests
1 parent d896575 commit 47ffad6

File tree

2 files changed

+89
-9
lines changed

2 files changed

+89
-9
lines changed

tests/integration/index_manipulation/test_integration_index_manipulation_update_levels.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,15 @@
1111
import pytest
1212

1313
from pandas_openscm.index_manipulation import update_index_levels_func, update_levels
14+
from pandas_openscm.testing import check_result, convert_to_desired_type
15+
16+
pobj_type = pytest.mark.parametrize(
17+
"pobj_type",
18+
("DataFrame", "Series"),
19+
)
20+
"""
21+
Parameterisation to use to check handling of both DataFrame and Series
22+
"""
1423

1524

1625
@pytest.mark.parametrize(
@@ -190,8 +199,34 @@ def update_func(in_v: int) -> int:
190199
updates, remove_unused_levels=False
191200
)
192201

202+
# Same thing but from a Series
203+
start_series = start_df[2020]
204+
205+
res_series = update_index_levels_func(start_series.iloc[:-1], updates=updates)
206+
207+
exp_series = pd.Series(np.zeros(exp.shape[0]), name=2020, index=exp)
208+
209+
pd.testing.assert_series_equal(res_series, exp_series)
210+
with exp_error_no_removal:
211+
update_index_levels_func(
212+
start_series.iloc[:-1],
213+
updates=updates,
214+
remove_unused_levels=False,
215+
)
216+
217+
# Lastly, test the accessor
218+
pd.testing.assert_series_equal(
219+
start_series.iloc[:-1].openscm.update_index_levels(updates),
220+
exp_series,
221+
)
222+
with exp_error_no_removal:
223+
start_series.iloc[:-1].openscm.update_index_levels(
224+
updates, remove_unused_levels=False
225+
)
226+
193227

194-
def test_accessor(setup_pandas_accessors):
228+
@pobj_type
229+
def test_accessor(setup_pandas_accessors, pobj_type):
195230
start = pd.DataFrame(
196231
np.arange(2 * 4).reshape((4, 2)),
197232
columns=[2010, 2020],
@@ -205,6 +240,7 @@ def test_accessor(setup_pandas_accessors):
205240
names=["scenario", "variable", "unit", "run_id"],
206241
),
207242
)
243+
convert_to_desired_type(start, pobj_type)
208244

209245
updates = {
210246
"variable": lambda x: x.replace("v", "vv"),
@@ -224,17 +260,20 @@ def test_accessor(setup_pandas_accessors):
224260
names=["scenario", "variable", "unit", "run_id"],
225261
),
226262
)
263+
exp = convert_to_desired_type(exp, pobj_type)
227264

228265
res = start.openscm.update_index_levels(updates)
229-
pd.testing.assert_frame_equal(res, exp)
266+
check_result(res, exp)
230267

231268
# Test function too
232269
res = update_index_levels_func(start, updates)
233-
pd.testing.assert_frame_equal(res, exp)
270+
check_result(res, exp)
234271

235272

236-
def test_accessor_not_multiindex(setup_pandas_accessors):
273+
@pobj_type
274+
def test_accessor_not_multiindex(setup_pandas_accessors, pobj_type):
237275
start = pd.DataFrame(np.arange(2 * 4).reshape((4, 2)))
276+
start = convert_to_desired_type(start, pobj_type)
238277

239278
error_msg = re.escape(
240279
"This function is only intended to be used "

tests/integration/index_manipulation/test_integration_index_manipulation_update_levels_from_other.py

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,15 @@
1414
update_index_levels_from_other_func,
1515
update_levels_from_other,
1616
)
17+
from pandas_openscm.testing import check_result, convert_to_desired_type
18+
19+
pobj_type = pytest.mark.parametrize(
20+
"pobj_type",
21+
("DataFrame", "Series"),
22+
)
23+
"""
24+
Parameterisation to use to check handling of both DataFrame and Series
25+
"""
1726

1827

1928
@pytest.mark.parametrize(
@@ -273,7 +282,7 @@ def test_update_levels_from_other_missing_levels(sources, exp):
273282
update_levels_from_other(start, update_sources=update_sources)
274283

275284

276-
def test_doesnt_trip_over_droped_levels(setup_pandas_accessors):
285+
def test_doesnt_trip_over_dropped_levels(setup_pandas_accessors):
277286
def update_func(in_v: int) -> int:
278287
if in_v < 0:
279288
msg = f"Value must be greater than zero, received {in_v}"
@@ -347,8 +356,36 @@ def update_func(in_v: int) -> int:
347356
update_sources, remove_unused_levels=False
348357
)
349358

359+
# Same thing but from a Series
360+
start_series = start_df[2020]
361+
362+
res_series = update_index_levels_from_other_func(
363+
start_series.iloc[:-1], update_sources=update_sources
364+
)
365+
366+
exp_series = pd.Series(np.zeros(exp.shape[0]), name=2020, index=exp)
367+
368+
pd.testing.assert_series_equal(res_series, exp_series)
369+
with exp_error_no_removal:
370+
update_index_levels_from_other_func(
371+
start_series.iloc[:-1],
372+
update_sources=update_sources,
373+
remove_unused_levels=False,
374+
)
375+
376+
# Lastly, test the accessor
377+
pd.testing.assert_series_equal(
378+
start_series.iloc[:-1].openscm.update_index_levels_from_other(update_sources),
379+
exp_series,
380+
)
381+
with exp_error_no_removal:
382+
start_series.iloc[:-1].openscm.update_index_levels_from_other(
383+
update_sources, remove_unused_levels=False
384+
)
385+
350386

351-
def test_accessor(setup_pandas_accessors):
387+
@pobj_type
388+
def test_accessor(setup_pandas_accessors, pobj_type):
352389
start = pd.DataFrame(
353390
np.arange(2 * 4).reshape((4, 2)),
354391
columns=[2010, 2020],
@@ -362,6 +399,7 @@ def test_accessor(setup_pandas_accessors):
362399
names=["scenario", "variable", "unit", "run_id"],
363400
),
364401
)
402+
convert_to_desired_type(start, pobj_type)
365403

366404
update_sources = {
367405
# callables single source
@@ -408,17 +446,20 @@ def test_accessor(setup_pandas_accessors):
408446
],
409447
),
410448
)
449+
exp = convert_to_desired_type(exp, pobj_type)
411450

412451
res = start.openscm.update_index_levels_from_other(update_sources)
413-
pd.testing.assert_frame_equal(res, exp)
452+
check_result(res, exp)
414453

415454
# Test function too
416455
res = update_index_levels_from_other_func(start, update_sources)
417-
pd.testing.assert_frame_equal(res, exp)
456+
check_result(res, exp)
418457

419458

420-
def test_accessor_not_multiindex(setup_pandas_accessors):
459+
@pobj_type
460+
def test_accessor_not_multiindex(setup_pandas_accessors, pobj_type):
421461
start = pd.DataFrame(np.arange(2 * 4).reshape((4, 2)))
462+
start = convert_to_desired_type(start, pobj_type)
422463

423464
error_msg = re.escape(
424465
"This function is only intended to be used "

0 commit comments

Comments
 (0)