Skip to content

Commit 8f22154

Browse files
committed
feat: make ColumnDropper dataframe-agnostic
1 parent 8cabda3 commit 8f22154

File tree

3 files changed

+41
-17
lines changed

3 files changed

+41
-17
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ maintainers = [
2020
]
2121

2222
dependencies = [
23+
"narwhals>=0.7.16",
2324
"pandas>=1.1.5",
2425
"scikit-learn>=1.0",
2526
"importlib-metadata >= 1.0; python_version < '3.8'",
@@ -61,6 +62,7 @@ docs = [
6162
]
6263

6364
test-dep = [
65+
"polars",
6466
"pytest>=6.2.5",
6567
"pytest-xdist>=1.34.0",
6668
"pytest-cov>=2.6.1",
@@ -111,4 +113,3 @@ markers = [
111113
"formulaic: tests that require formulaic (deselect with '-m \"not formulaic\"')",
112114
"umap: tests that require umap (deselect with '-m \"not umap\"')"
113115
]
114-

sklego/preprocessing/pandastransformers.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import narwhals as nw
12
import pandas as pd
23
from sklearn.base import BaseEstimator, TransformerMixin
34
from sklearn.utils.validation import check_is_fitted
@@ -106,9 +107,9 @@ def fit(self, X, y=None):
106107
If dropping the specified columns would result in an empty output DataFrame.
107108
"""
108109
self.columns_ = as_list(self.columns)
109-
self._check_X_for_type(X)
110+
X = nw.from_native(X)
110111
self._check_column_names(X)
111-
self.feature_names_ = X.columns.drop(self.columns_).tolist()
112+
self.feature_names_ = [x for x in X.columns if x not in self.columns_]
112113
self._check_column_length()
113114
return self
114115

@@ -131,10 +132,10 @@ def transform(self, X):
131132
If `X` is not a `pd.DataFrame` object.
132133
"""
133134
check_is_fitted(self, ["feature_names_"])
134-
self._check_X_for_type(X)
135+
X = nw.from_native(X)
135136
if self.columns_:
136-
return X.drop(columns=self.columns_)
137-
return X
137+
return nw.to_native(X.drop(self.columns_))
138+
return nw.to_native(X)
138139

139140
def get_feature_names(self):
140141
"""Alias for `.feature_names_` attribute"""
@@ -151,12 +152,6 @@ def _check_column_names(self, X):
151152
if len(non_existent_columns) > 0:
152153
raise KeyError(f"{list(non_existent_columns)} column(s) not in DataFrame")
153154

154-
@staticmethod
155-
def _check_X_for_type(X):
156-
"""Checks if input of the Selector is of the required dtype"""
157-
if not isinstance(X, pd.DataFrame):
158-
raise TypeError("Provided variable X is not of type pandas.DataFrame")
159-
160155

161156
class PandasTypeSelector(BaseEstimator, TransformerMixin):
162157
"""The `PandasTypeSelector` transformer allows to select columns in a pandas DataFrame based on their type.

tests/test_preprocessing/test_columndropper.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import pandas as pd
2+
import polars as pl
23
import pytest
3-
from pandas.testing import assert_frame_equal
4+
from pandas.testing import assert_frame_equal as pandas_assert_frame_equal
5+
from polars.testing import assert_frame_equal as polars_assert_frame_equal
46
from sklearn.pipeline import make_pipeline
57

68
from sklego.preprocessing import ColumnDropper
@@ -19,6 +21,19 @@ def df():
1921
)
2022

2123

24+
@pytest.fixture()
25+
def df_polars():
26+
return pl.DataFrame(
27+
{
28+
"a": [1, 2, 3, 4, 5, 6],
29+
"b": [10, 9, 8, 7, 6, 5],
30+
"c": ["a", "b", "a", "b", "c", "c"],
31+
"d": ["b", "a", "a", "b", "a", "b"],
32+
"e": [0, 1, 0, 1, 0, 1],
33+
}
34+
)
35+
36+
2237
def test_drop_two(df):
2338
result_df = ColumnDropper(["a", "b"]).fit_transform(df)
2439
expected_df = pd.DataFrame(
@@ -29,7 +44,7 @@ def test_drop_two(df):
2944
}
3045
)
3146

32-
assert_frame_equal(result_df, expected_df)
47+
pandas_assert_frame_equal(result_df, expected_df)
3348

3449

3550
def test_drop_one(df):
@@ -43,7 +58,7 @@ def test_drop_one(df):
4358
}
4459
)
4560

46-
assert_frame_equal(result_df, expected_df)
61+
pandas_assert_frame_equal(result_df, expected_df)
4762

4863

4964
def test_drop_all(df):
@@ -53,7 +68,7 @@ def test_drop_all(df):
5368

5469
def test_drop_none(df):
5570
result_df = ColumnDropper([]).fit_transform(df)
56-
assert_frame_equal(result_df, df)
71+
pandas_assert_frame_equal(result_df, df)
5772

5873

5974
def test_drop_not_in_frame(df):
@@ -73,10 +88,23 @@ def test_drop_one_in_pipeline(df):
7388
}
7489
)
7590

76-
assert_frame_equal(result_df, expected_df)
91+
pandas_assert_frame_equal(result_df, expected_df)
7792

7893

7994
def test_get_feature_names():
8095
df = pd.DataFrame({"a": [4, 5, 6], "b": ["4", "5", "6"]})
8196
transformer = ColumnDropper("a").fit(df)
8297
assert transformer.get_feature_names() == ["b"]
98+
99+
100+
def test_drop_two_polars(df_polars):
101+
result_df = ColumnDropper(["a", "b"]).fit_transform(df_polars)
102+
expected_df = pl.DataFrame(
103+
{
104+
"c": ["a", "b", "a", "b", "c", "c"],
105+
"d": ["b", "a", "a", "b", "a", "b"],
106+
"e": [0, 1, 0, 1, 0, 1],
107+
}
108+
)
109+
110+
polars_assert_frame_equal(result_df, expected_df)

0 commit comments

Comments
 (0)