Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions mapie/estimator/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np
from joblib import Parallel, delayed
from sklearn.base import ClassifierMixin, clone
from sklearn.base import BaseEstimator, ClassifierMixin, clone
from sklearn.model_selection import (BaseCrossValidator, BaseShuffleSplit)
from sklearn.utils import _safe_indexing
from sklearn.utils.validation import _num_samples, check_is_fitted
Expand All @@ -13,7 +13,7 @@
from mapie.utils import _check_no_agg_cv, _fit_estimator, _fix_number_of_classes


class EnsembleClassifier:
class EnsembleClassifier(BaseEstimator):
"""
This class implements methods to handle the training and usage of the
estimator. This estimator can be unique or composed by cross validated
Expand Down Expand Up @@ -402,7 +402,6 @@ def predict_proba_calib(
NDArray of shape (n_samples_test, 1)
The predictions.
"""
check_is_fitted(self, self.fit_attributes)

if self.cv == "prefit":
y_pred_proba = self.single_estimator_.predict_proba(X)
Expand Down
5 changes: 2 additions & 3 deletions mapie/estimator/regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np
from joblib import Parallel, delayed
from sklearn.base import RegressorMixin, clone
from sklearn.base import BaseEstimator, RegressorMixin, clone
from sklearn.model_selection import BaseCrossValidator
from sklearn.utils import _safe_indexing
from sklearn.utils.validation import _num_samples, check_is_fitted
Expand All @@ -15,7 +15,7 @@
_fit_estimator)


class EnsembleRegressor:
class EnsembleRegressor(BaseEstimator):
"""
This class implements methods to handle the training and usage of the
estimator. This estimator can be unique or composed by cross validated
Expand Down Expand Up @@ -364,7 +364,6 @@ def predict_calib(
NDArray of shape (n_samples_test, 1)
The predictions.
"""
check_is_fitted(self, self.fit_attributes)

if self.cv == "prefit":
y_pred = self.single_estimator_.predict(X)
Expand Down
2 changes: 1 addition & 1 deletion mapie/risk_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from .utils import _check_alpha, _check_n_jobs, _check_verbose


class PrecisionRecallController(BaseEstimator, ClassifierMixin):
class PrecisionRecallController(ClassifierMixin, BaseEstimator):
"""
Prediction sets for multilabel-classification.

Expand Down
3 changes: 2 additions & 1 deletion mapie/tests/risk_control/test_precision_recall_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import pandas as pd
import pytest
from sklearn.base import BaseEstimator
from sklearn.compose import ColumnTransformer
from sklearn.datasets import make_multilabel_classification
from sklearn.impute import SimpleImputer
Expand Down Expand Up @@ -152,7 +153,7 @@ def predict(self, *args: Any):
"""Dummy predict."""


class ArrayOutputModel:
class ArrayOutputModel(BaseEstimator):

def __init__(self):
self.trained_ = True
Expand Down
12 changes: 6 additions & 6 deletions mapie/tests/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import pandas as pd
import pytest
from sklearn.base import ClassifierMixin
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.compose import ColumnTransformer
from sklearn.datasets import make_classification
from sklearn.dummy import DummyClassifier
Expand Down Expand Up @@ -811,7 +811,7 @@ def early_stopping_monitor(i, est, locals):
}


class CumulatedScoreClassifier:
class CumulatedScoreClassifier(BaseEstimator):

def __init__(self) -> None:
self.X_calib = np.array([0, 1, 2]).reshape(-1, 1)
Expand Down Expand Up @@ -845,7 +845,7 @@ def predict_proba(self, X: ArrayLike) -> NDArray:
)


class ImageClassifier:
class ImageClassifier(BaseEstimator):

def __init__(self, X_calib: ArrayLike, X_test: ArrayLike) -> None:
self.X_calib = X_calib
Expand Down Expand Up @@ -874,7 +874,7 @@ def predict_proba(self, X: ArrayLike) -> NDArray:
)


class WrongOutputModel:
class WrongOutputModel(BaseEstimator):

def __init__(self, proba_out: NDArray):
self.trained_ = True
Expand All @@ -888,7 +888,7 @@ def predict_proba(self, *args: Any) -> NDArray:
return self.proba_out


class Float32OuputModel:
class Float32OutputModel(BaseEstimator):

def __init__(self, prefit: bool = True):
self.trained_ = prefit
Expand Down Expand Up @@ -1713,7 +1713,7 @@ def test_classif_float32(cv) -> None:
n_classes=3
)
alpha = .9
dummy_classif = Float32OuputModel()
dummy_classif = Float32OutputModel()

mapie = _MapieClassifier(
estimator=dummy_classif, conformity_score=NaiveConformityScore(),
Expand Down
2 changes: 1 addition & 1 deletion mapie/tests/test_non_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def filter_params(
return {k: v for k, v in params.items() if k in model_params}


class DummyClassifierWithFitAndPredictParams(BaseEstimator, ClassifierMixin):
class DummyClassifierWithFitAndPredictParams(ClassifierMixin, BaseEstimator):
def __init__(self):
self.classes_ = None
self._dummy_fit_param = None
Expand Down
3 changes: 2 additions & 1 deletion mapie/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
import pytest
from numpy.random import RandomState
from sklearn.base import BaseEstimator
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import (BaseCrossValidator, KFold, LeaveOneOut,
Expand Down Expand Up @@ -321,7 +322,7 @@ def test_does_nothing_when_not_in_prefit_mode(self):
}


class DumbEstimator:
class DumbEstimator(BaseEstimator):
def fit(
self,
X: ArrayLike,
Expand Down
Loading