Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,5 @@ Contributors
* Faustin Pulvéric <[email protected]>
* Chaoqi Zhang <[email protected]>
* Leena Kamran Qidwai
* Aman Vishnoi <[email protected]>
To be continued ...
1 change: 1 addition & 0 deletions mapie/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,7 @@ def fit(
X=X,
sample_weight=sample_weight,
groups=groups,
predict_params=predict_params,
)
return self

Expand Down
12 changes: 9 additions & 3 deletions mapie/conformity_scores/sets/raps.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,15 @@ def get_conformity_scores(
Conformity scores.
"""
# Compute y_pred and position on the RAPS validation dataset
self.y_pred_proba_raps = self.predictor.single_estimator_.predict_proba(
self.X_raps
)
predict_params = kwargs.pop("predict_params", None)
if predict_params is not None and len(predict_params) > 0:
self.y_pred_proba_raps = self.predictor.single_estimator_.predict_proba(
self.X_raps, **predict_params
)
else:
self.y_pred_proba_raps = self.predictor.single_estimator_.predict_proba(
self.X_raps
)
self.position_raps = get_true_label_position(
self.y_pred_proba_raps, self.y_raps
)
Expand Down
2 changes: 1 addition & 1 deletion mapie/estimator/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def predict_proba_calib(
check_is_fitted(self, self.fit_attributes)

if self.cv == "prefit":
y_pred_proba = self.single_estimator_.predict_proba(X)
y_pred_proba = self.single_estimator_.predict_proba(X, **predict_params)
y_pred_proba = self._check_proba_normalized(y_pred_proba)
else:
X = cast(NDArray, X)
Expand Down
36 changes: 36 additions & 0 deletions mapie/tests/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -1764,6 +1764,42 @@ def test_predict_parameters_passing() -> None:
np.testing.assert_equal(y_pred, 0)


def test_raps_with_predict_params() -> None:
"""Test that predict_params are correctly passed when using RAPS."""
X, y = make_classification(
n_samples=500,
n_features=10,
n_informative=3,
n_classes=3,
random_state=random_state,
)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=random_state
)
estimator = CustomGradientBoostingClassifier(random_state=random_state)
predict_params = {"check_predict_params": True}
mapie_clf = _MapieClassifier(
estimator=estimator,
conformity_score=RAPSConformityScore(size_raps=0.2),
cv="split",
test_size=0.2,
random_state=random_state,
)

mapie_clf.fit(X_train, y_train, predict_params=predict_params)

y_pred, y_ps = mapie_clf.predict(
X_test,
alpha=0.1,
include_last_label="randomized",
agg_scores="mean",
**predict_params,
)
# Ensure the output shapes are correct
assert y_pred.shape == (X_test.shape[0],)
assert y_ps.shape == (X_test.shape[0], len(np.unique(y)), 1)


def test_with_no_predict_parameters_passing() -> None:
"""
Test passing with no predict parameters from the
Expand Down
Loading