Skip to content

Commit 0773db9

Browse files
committed
feat: Make InformationFilter dataframe-agnostic
1 parent 94cf506 commit 0773db9

File tree

2 files changed

+50
-48
lines changed

2 files changed

+50
-48
lines changed

sklego/preprocessing/projections.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1+
import narwhals as nw
12
import numpy as np
2-
import pandas as pd
33
from sklearn.base import BaseEstimator, TransformerMixin
44
from sklearn.utils import check_array
55
from sklearn.utils.validation import check_is_fitted
@@ -181,16 +181,17 @@ def __init__(self, columns, alpha=1):
181181

182182
def _check_coltype(self, X):
183183
"""Check if the `columns` type(s) are compatible with `X` type."""
184+
X_ = nw.from_native(X, strict=False, eager_only=True)
184185
for col in as_list(self.columns):
185186
if isinstance(col, str):
186-
if isinstance(X, np.ndarray):
187+
if isinstance(X_, np.ndarray):
187188
raise ValueError(f"column {col} is a string but datatype receive is numpy.")
188-
if isinstance(X, pd.DataFrame):
189-
if col not in X.columns:
190-
raise ValueError(f"column {col} is not in {X.columns}")
189+
if isinstance(X_, nw.DataFrame):
190+
if col not in X_.columns:
191+
raise ValueError(f"column {col} is not in {X_.columns}")
191192
if isinstance(col, int):
192-
if col not in range(np.atleast_2d(np.array(X)).shape[1]):
193-
raise ValueError(f"column {col} is out of bounds for input shape {X.shape}")
193+
if col not in range(np.atleast_2d(np.array(X_)).shape[1]):
194+
raise ValueError(f"column {col} is out of bounds for input shape {X_.shape}")
194195

195196
def _col_idx(self, X, name):
196197
"""Get the column index of a column name."""

tests/test_preprocessing/test_informationfilter.py

Lines changed: 42 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import narwhals as nw
12
import numpy as np
23
import pandas as pd
4+
import polars as pl
35
import pytest
46
from sklearn.datasets import fetch_openml
57
from sklearn.linear_model import LinearRegression
@@ -49,57 +51,56 @@ def test_alpha_param1():
4951
assert np.isclose(ifilter.fit_transform(X), X_removed).all()
5052

5153

52-
def test_alpha_param2():
54+
@pytest.mark.parametrize("frame_func", [pd.DataFrame, pl.DataFrame])
55+
def test_alpha_param2(frame_func):
5356
X, y = fetch_openml(data_id=531, return_X_y=True, as_frame=False, parser="liac-arff")
54-
df = pd.DataFrame(
55-
X,
56-
columns=[
57-
"crim",
58-
"zn",
59-
"indus",
60-
"chas",
61-
"nox",
62-
"rm",
63-
"age",
64-
"dis",
65-
"rad",
66-
"tax",
67-
"ptratio",
68-
"b",
69-
"lstat",
70-
],
71-
)
57+
cols = [
58+
"crim",
59+
"zn",
60+
"indus",
61+
"chas",
62+
"nox",
63+
"rm",
64+
"age",
65+
"dis",
66+
"rad",
67+
"tax",
68+
"ptratio",
69+
"b",
70+
"lstat",
71+
]
72+
df = frame_func(dict(zip(cols, X.T)))
7273
ifilter = InformationFilter(columns=["b", "lstat"], alpha=0.0)
73-
X_removed = df.drop(columns=["b", "lstat"]).values
74+
X_removed = nw.from_native(df).drop(["b", "lstat"]).to_numpy()
7475
assert np.isclose(ifilter.fit_transform(df), X_removed).all()
7576

7677

77-
def test_output_orthogonal_pandas():
78+
@pytest.mark.parametrize("frame_func", [pd.DataFrame, pl.DataFrame])
79+
def test_output_orthogonal_frame(frame_func):
7880
X, y = fetch_openml(data_id=531, return_X_y=True, as_frame=False, parser="liac-arff")
79-
df = pd.DataFrame(
80-
X,
81-
columns=[
82-
"crim",
83-
"zn",
84-
"indus",
85-
"chas",
86-
"nox",
87-
"rm",
88-
"age",
89-
"dis",
90-
"rad",
91-
"tax",
92-
"ptratio",
93-
"b",
94-
"lstat",
95-
],
96-
)
81+
cols = [
82+
"crim",
83+
"zn",
84+
"indus",
85+
"chas",
86+
"nox",
87+
"rm",
88+
"age",
89+
"dis",
90+
"rad",
91+
"tax",
92+
"ptratio",
93+
"b",
94+
"lstat",
95+
]
96+
df = frame_func(dict(zip(cols, X.T)))
9797
X_fair = InformationFilter(columns=["b", "lstat"]).fit_transform(df)
9898
assert all([(c * df["b"]).sum() < 1e-5 for c in X_fair.T])
9999
assert all([(c * df["lstat"]).sum() < 1e-5 for c in X_fair.T])
100100

101101

102-
def test_output_orthogonal_general_cols():
102+
@pytest.mark.parametrize("frame_func", [pd.DataFrame, pl.DataFrame])
103+
def test_output_orthogonal_general_cols(frame_func):
103104
X, y = fetch_openml(data_id=531, return_X_y=True, as_frame=False, parser="liac-arff")
104105
cols = [
105106
"crim",
@@ -116,7 +117,7 @@ def test_output_orthogonal_general_cols():
116117
"b",
117118
"lstat",
118119
]
119-
df = pd.DataFrame(X, columns=cols)
120+
df = frame_func(dict(zip(cols, X.T)))
120121
for col in cols:
121122
X_fair = InformationFilter(columns=col).fit_transform(df)
122123
assert all([(c * df[col]).sum() < 1e-5 for c in X_fair.T])

0 commit comments

Comments
 (0)