Skip to content

Commit ad9e619

Browse files
committed
Add index accessor
1 parent 1118414 commit ad9e619

File tree

4 files changed

+84
-1
lines changed

4 files changed

+84
-1
lines changed

src/pandas_openscm/accessors/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import pandas as pd
3939

4040
from pandas_openscm.accessors.dataframe import PandasDataFrameOpenSCMAccessor
41+
from pandas_openscm.accessors.index import PandasIndexOpenSCMAccessor
4142
from pandas_openscm.accessors.series import PandasSeriesOpenSCMAccessor
4243

4344

@@ -73,4 +74,4 @@ def register_pandas_accessors(namespace: str = "openscm") -> None:
7374
PandasDataFrameOpenSCMAccessor
7475
)
7576
pd.api.extensions.register_series_accessor(namespace)(PandasSeriesOpenSCMAccessor)
76-
# pd.api.extensions.register_index_accessor(namespace)(PandasIndexOpenSCMAccessor)
77+
pd.api.extensions.register_index_accessor(namespace)(PandasIndexOpenSCMAccessor)
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
"""
2+
Accessor for [pd.Index][pandas.Index] (and sub-classes)
3+
"""
4+
5+
from __future__ import annotations
6+
7+
from typing import TYPE_CHECKING, Any, Generic, TypeVar
8+
9+
import pandas as pd
10+
11+
from pandas_openscm.index_manipulation import ensure_is_multiindex
12+
13+
if TYPE_CHECKING:
14+
# Hmm this is somehow not correct.
15+
# Figuring it out is a job for another day
16+
Idx = TypeVar("Idx", bound=pd.Index[Any])
17+
18+
19+
else:
20+
Idx = TypeVar("Idx")
21+
22+
23+
class PandasIndexOpenSCMAccessor(Generic[Idx]):
24+
"""
25+
[pd.Index][pandas.Index] accessor
26+
27+
For details, see
28+
[pandas' docs](https://pandas.pydata.org/docs/development/extending.html#registering-custom-accessors).
29+
"""
30+
31+
def __init__(self, index: Idx):
32+
"""
33+
Initialise
34+
35+
Parameters
36+
----------
37+
index
38+
[pd.Index][pandas.Index] to use via the accessor
39+
"""
40+
# It is possible to validate here.
41+
# However, it's probably better to do validation closer to the data use.
42+
self._index = index
43+
44+
def ensure_is_multiindex(self) -> pd.MultiIndex:
45+
"""
46+
Ensure that the index is a [pd.MultiIndex][pandas.MultiIndex]
47+
48+
Returns
49+
-------
50+
:
51+
`index` as a [pd.MultiIndex][pandas.MultiIndex]
52+
53+
If the index was already a [pd.MultiIndex][pandas.MultiIndex],
54+
this is a no-op.
55+
"""
56+
res = ensure_is_multiindex(self._index)
57+
58+
return res
59+
60+
def eim(self) -> pd.MultiIndex:
61+
"""
62+
Ensure that the index is a [pd.MultiIndex][pandas.MultiIndex]
63+
64+
Alias for [ensure_is_multiindex][pandas_openscm.index_manipulation.]
65+
66+
Returns
67+
-------
68+
:
69+
`index` as a [pd.MultiIndex][pandas.MultiIndex]
70+
71+
If the index was already a [pd.MultiIndex][pandas.MultiIndex],
72+
this is a no-op (although the value of copy is respected).
73+
"""
74+
return self.ensure_is_multiindex()

tests/conftest.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,7 @@ def setup_pandas_accessors() -> None:
4545
pd.Series._accessors.discard("openscm")
4646
if hasattr(pd.Series, "openscm"):
4747
del pd.Series.openscm
48+
49+
pd.Index._accessors.discard("openscm")
50+
if hasattr(pd.Index, "openscm"):
51+
del pd.Index.openscm

tests/integration/index_manipulation/test_integration_index_manipulation_ensure_index_is_multiindex.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,8 @@ def test_ensure_is_multiindex_accessor_index(setup_pandas_accessors):
178178
start = pd.Index([1, 2, 3], name="id")
179179

180180
res = start.openscm.ensure_is_multiindex()
181+
res_short = start.openscm.eim()
182+
pd.testing.assert_index_equal(res, res_short)
181183

182184
assert isinstance(res, pd.MultiIndex)
183185

@@ -203,6 +205,8 @@ def test_ensure_is_multiindex_accessor_multiindex(setup_pandas_accessors):
203205
)
204206

205207
res = start.openscm.ensure_is_multiindex()
208+
res_short = start.openscm.eim()
209+
pd.testing.assert_index_equal(res, res_short)
206210

207211
# Same object returned
208212
assert id(start) == id(res)

0 commit comments

Comments
 (0)