1+ import narwhals as nw
12import numpy as np
2- import pandas as pd
33from sklearn .base import BaseEstimator , OutlierMixin
44from sklearn .utils .validation import check_array , check_is_fitted
55
@@ -11,8 +11,11 @@ class RegressionOutlierDetector(BaseEstimator, OutlierMixin):
1111 ----------
1212 model : scikit-learn compatible regression model
1313 A regression model that will be used for prediction.
14- column : int
15- The index of the target column to predict in the input data.
14+ column : int | str
15+ This should be:
16+
17+ - The index of the target column to predict in the input data, when the input is an array.
18+ - The name of the target column to predict in the input data, when the input is a dataframe.
1619 lower : float, default=2.0
1720 Lower threshold for outlier detection. The method used for detection depends on the `method` parameter.
1821 upper : float, default=2.0
@@ -32,6 +35,21 @@ class RegressionOutlierDetector(BaseEstimator, OutlierMixin):
3235 The standard deviation of the differences between true and predicted values.
3336 idx_ : int
3437 The index of the target column in the input data.
38+
39+ Notes
40+ -----
41+ Native cross-dataframe support is achieved using
42+ [Narwhals](https://narwhals-dev.github.io/narwhals/){:target="_blank"}.
43+ Supported dataframes are:
44+
45+ - pandas
46+ - Polars (eager)
47+ - Modin
48+
49+ See [Narwhals docs](https://narwhals-dev.github.io/narwhals/extending/){:target="_blank"} for an up-to-date list
50+ (and to learn how you can add your dataframe library to it!), though note that only those
51+ supported by [sklearn.utils.check_X_y](https://scikit-learn.org/stable/modules/generated/sklearn.utils.check_X_y.html)
52+ will work with this class.
3553 """
3654
3755 def __init__ (self , model , column , lower = 2 , upper = 2 , method = "sd" ):
@@ -112,8 +130,9 @@ def fit(self, X, y=None):
112130 ValueError
113131 If the `model` is not a regression estimator.
114132 """
115- self .idx_ = np .argmax ([i == self .column for i in X .columns ]) if isinstance (X , pd .DataFrame ) else self .column
116- X = check_array (X , estimator = self )
133+ X = nw .from_native (X , eager_only = True , strict = False )
134+ self .idx_ = np .argmax ([i == self .column for i in X .columns ]) if isinstance (X , nw .DataFrame ) else self .column
135+ X = check_array (nw .to_native (X , strict = False ), estimator = self )
117136 if not self ._is_regression_model ():
118137 raise ValueError ("Passed model must be regression!" )
119138 X , y = self .to_x_y (X )
0 commit comments