Skip to content

Commit 94cf506

Browse files
authored
Make RegressionOutlier dataframe-agnostic (#665)
* make regression outlier df-agnostic
* need to use eager-only for this one
* pass native to check_array
* remove cudf, link to check_X_y
1 parent 28c102b commit 94cf506

2 files changed

Lines changed: 31 additions & 9 deletions

File tree

sklego/meta/regression_outlier_detector.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1+
import narwhals as nw
12
import numpy as np
2-
import pandas as pd
33
from sklearn.base import BaseEstimator, OutlierMixin
44
from sklearn.utils.validation import check_array, check_is_fitted
55

@@ -11,8 +11,11 @@ class RegressionOutlierDetector(BaseEstimator, OutlierMixin):
1111
----------
1212
model : scikit-learn compatible regression model
1313
A regression model that will be used for prediction.
14-
column : int
15-
The index of the target column to predict in the input data.
14+
column : int | str
15+
This should be:
16+
17+
- The index of the target column to predict in the input data, when the input is an array.
18+
- The name of the target column to predict in the input data, when the input is a dataframe.
1619
lower : float, default=2.0
1720
Lower threshold for outlier detection. The method used for detection depends on the `method` parameter.
1821
upper : float, default=2.0
@@ -32,6 +35,21 @@ class RegressionOutlierDetector(BaseEstimator, OutlierMixin):
3235
The standard deviation of the differences between true and predicted values.
3336
idx_ : int
3437
The index of the target column in the input data.
38+
39+
Notes
40+
-----
41+
Native cross-dataframe support is achieved using
42+
[Narwhals](https://narwhals-dev.github.io/narwhals/){:target="_blank"}.
43+
Supported dataframes are:
44+
45+
- pandas
46+
- Polars (eager)
47+
- Modin
48+
49+
See [Narwhals docs](https://narwhals-dev.github.io/narwhals/extending/){:target="_blank"} for an up-to-date list
50+
(and to learn how you can add your dataframe library to it!), though note that only those
51+
supported by [sklearn.utils.check_X_y](https://scikit-learn.org/stable/modules/generated/sklearn.utils.check_X_y.html)
52+
will work with this class.
3553
"""
3654

3755
def __init__(self, model, column, lower=2, upper=2, method="sd"):
@@ -112,8 +130,9 @@ def fit(self, X, y=None):
112130
ValueError
113131
If the `model` is not a regression estimator.
114132
"""
115-
self.idx_ = np.argmax([i == self.column for i in X.columns]) if isinstance(X, pd.DataFrame) else self.column
116-
X = check_array(X, estimator=self)
133+
X = nw.from_native(X, eager_only=True, strict=False)
134+
self.idx_ = np.argmax([i == self.column for i in X.columns]) if isinstance(X, nw.DataFrame) else self.column
135+
X = check_array(nw.to_native(X, strict=False), estimator=self)
117136
if not self._is_regression_model():
118137
raise ValueError("Passed model must be regression!")
119138
X, y = self.to_x_y(X)

tests/test_meta/test_regression_outlier.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
import pandas as pd
3+
import polars as pl
34
import pytest
45
from sklearn.linear_model import LinearRegression, LogisticRegression
56

@@ -42,14 +43,15 @@ def test_obvious_example():
4243
assert preds[i] == -1
4344

4445

45-
def test_obvious_example_pandas():
46+
@pytest.mark.parametrize("frame_func", [pd.DataFrame, pl.DataFrame])
47+
def test_obvious_example_dataframe(frame_func):
4648
# generate random data for illustrative example
4749
np.random.seed(42)
4850
x = np.random.normal(0, 1, 100)
4951
y = 1 + x + np.random.normal(0, 0.2, 100)
5052
for i in [20, 25, 50, 80]:
5153
y[i] += 2
52-
X = pd.DataFrame({"x": x, "y": y})
54+
X = frame_func({"x": x, "y": y})
5355

5456
# fit and plot
5557
mod = RegressionOutlierDetector(LinearRegression(), column="y")
@@ -58,14 +60,15 @@ def test_obvious_example_pandas():
5860
assert preds[i] == -1
5961

6062

61-
def test_raises_error():
63+
@pytest.mark.parametrize("frame_func", [pd.DataFrame, pl.DataFrame])
64+
def test_raises_error(frame_func):
6265
# generate random data for illustrative example
6366
np.random.seed(42)
6467
x = np.random.normal(0, 1, 100)
6568
y = 1 + x + np.random.normal(0, 0.2, 100)
6669
for i in [20, 25, 50, 80]:
6770
y[i] += 2
68-
X = pd.DataFrame({"x": x, "y": y})
71+
X = frame_func({"x": x, "y": y})
6972

7073
with pytest.raises(ValueError):
7174
mod = RegressionOutlierDetector(LogisticRegression(), column="y")

0 commit comments

Comments
 (0)