Skip to content

Commit 0151dcd

Browse files
FIX: avoid double inference when predicting in split conformal classification (#721)
* FIX: avoid double inference when predicting in split conformal classification:
  - move the probability prediction to the main class instead of the conformity score classes
  - use those probabilities to compute y_pred
  - rename the EnsembleClassifier .predict method to .predict_agg_proba
  - improve check_proba_normalized following Copilot's suggestion
1 parent 5158a88 commit 0151dcd

File tree

11 files changed

+178
-175
lines changed

11 files changed

+178
-175
lines changed

mapie/classification.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,12 @@
2424
check_target, check_and_select_conformity_score,
2525
)
2626
from mapie.estimator.classifier import EnsembleClassifier
27-
from mapie.utils import (_check_alpha, _check_alpha_and_n_samples, _check_cv,
28-
_check_estimator_classification, _check_n_features_in,
29-
_check_n_jobs, _check_null_weight, _check_predict_params,
30-
_check_verbose)
27+
from mapie.utils import (
28+
_check_alpha, _check_alpha_and_n_samples, _check_cv,
29+
_check_estimator_classification, _check_n_features_in,
30+
_check_n_jobs, _check_null_weight, _check_predict_params,
31+
_check_verbose, check_proba_normalized,
32+
)
3133
from mapie.utils import (
3234
_transform_confidence_level_to_alpha_list,
3335
_raise_error_if_fit_called_in_prefit_mode,
@@ -1055,7 +1057,12 @@ def predict(
10551057
alpha = cast(Optional[NDArray], _check_alpha(alpha))
10561058

10571059
# Estimate predictions
1058-
y_pred = self.estimator_.single_estimator_.predict(X, **predict_params)
1060+
y_pred_proba = self.estimator_.single_estimator_.predict_proba(
1061+
X,
1062+
**predict_params
1063+
)
1064+
y_pred_proba = check_proba_normalized(y_pred_proba, axis=1)
1065+
y_pred = self.label_encoder_.inverse_transform(np.argmax(y_pred_proba, axis=1))
10591066
if alpha is None:
10601067
return y_pred
10611068

@@ -1067,9 +1074,17 @@ def predict(
10671074
_check_alpha_and_n_samples(alpha_np, n)
10681075

10691076
# Estimate prediction sets
1077+
if self.estimator_.cv != "prefit":
1078+
y_pred_proba = self.estimator_.predict_agg_proba(
1079+
X,
1080+
agg_scores,
1081+
**predict_params
1082+
)
1083+
10701084
prediction_sets = self.conformity_score_function_.predict_set(
10711085
X, alpha_np,
1072-
estimator=self.estimator_,
1086+
y_pred_proba=y_pred_proba,
1087+
cv=self.estimator_.cv,
10731088
conformity_scores=self.conformity_scores_,
10741089
include_last_label=include_last_label,
10751090
agg_scores=agg_scores,

mapie/conformity_scores/classification.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55

66
from mapie.conformity_scores.interface import BaseConformityScore
7-
from mapie.estimator.classifier import EnsembleClassifier
7+
from sklearn.model_selection import BaseCrossValidator
88

99
from numpy.typing import ArrayLike, NDArray
1010

@@ -53,7 +53,8 @@ def get_predictions(
5353
self,
5454
X: NDArray,
5555
alpha_np: NDArray,
56-
estimator: EnsembleClassifier,
56+
y_pred_proba: NDArray,
57+
cv: Optional[Union[int, str, BaseCrossValidator]],
5758
**kwargs
5859
) -> NDArray:
5960
"""
@@ -84,7 +85,7 @@ def get_conformity_score_quantiles(
8485
self,
8586
conformity_scores: NDArray,
8687
alpha_np: NDArray,
87-
estimator: EnsembleClassifier,
88+
cv: Optional[Union[int, str, BaseCrossValidator]],
8889
**kwargs
8990
) -> NDArray:
9091
"""
@@ -116,7 +117,7 @@ def get_prediction_sets(
116117
y_pred_proba: NDArray,
117118
conformity_scores: NDArray,
118119
alpha_np: NDArray,
119-
estimator: EnsembleClassifier,
120+
cv: Optional[Union[int, str, BaseCrossValidator]],
120121
**kwargs
121122
) -> NDArray:
122123
"""
@@ -150,13 +151,14 @@ def get_sets(
150151
self,
151152
X: NDArray,
152153
alpha_np: NDArray,
153-
estimator: EnsembleClassifier,
154+
y_pred_proba: NDArray,
155+
cv: Optional[Union[int, str, BaseCrossValidator]],
154156
conformity_scores: NDArray,
155157
**kwargs
156158
) -> NDArray:
157159
"""
158160
Compute classes of the prediction sets from the observed values,
159-
the estimator of type ``EnsembleClassifier`` and the conformity scores.
161+
the predicted probabilities and the conformity scores.
160162
161163
Parameters
162164
----------
@@ -167,8 +169,11 @@ def get_sets(
167169
NDArray of floats between 0 and 1, representing the uncertainty
168170
of the confidence set.
169171
170-
estimator: EnsembleClassifier
171-
Estimator that is fitted to predict y from X.
172+
y_pred_proba: NDArray
173+
Predicted probabilities from the estimator.
174+
175+
cv: Optional[Union[int, str, BaseCrossValidator]]
176+
Cross-validation strategy used by the estimator.
172177
173178
conformity_scores: NDArray of shape (n_samples,)
174179
Conformity scores.
@@ -178,19 +183,19 @@ def get_sets(
178183
NDArray of shape (n_samples, n_classes, n_alpha)
179184
Prediction sets (Booleans indicate whether classes are included).
180185
"""
186+
# Choice of the quantile
181187
# Predict probabilities
182188
y_pred_proba = self.get_predictions(
183-
X, alpha_np, estimator, **kwargs
189+
X, alpha_np, y_pred_proba, cv, **kwargs
184190
)
185191

186-
# Choice of the quantile
187192
self.quantiles_ = self.get_conformity_score_quantiles(
188-
conformity_scores, alpha_np, estimator, **kwargs
193+
conformity_scores, alpha_np, cv, **kwargs
189194
)
190195

191196
# Build prediction sets
192197
prediction_sets = self.get_prediction_sets(
193-
y_pred_proba, conformity_scores, alpha_np, estimator, **kwargs
198+
y_pred_proba, conformity_scores, alpha_np, cv, **kwargs
194199
)
195200

196201
return prediction_sets
@@ -213,6 +218,12 @@ def predict_set(
213218
alpha_np: NDArray of shape (n_alpha, )
214219
Represents the uncertainty of the confidence set to produce.
215220
221+
y_pred_proba: NDArray
222+
Predicted probabilities from the estimator.
223+
224+
cv: Optional[Union[int, str, BaseCrossValidator]]
225+
Cross-validation strategy used by the estimator.
226+
216227
**kwargs: dict
217228
Additional keyword arguments.
218229
@@ -221,4 +232,6 @@ def predict_set(
221232
The output structure depends on the ``get_sets`` method.
222233
The prediction sets for each sample and each alpha level.
223234
"""
224-
return self.get_sets(X=X, alpha_np=alpha_np, **kwargs)
235+
return self.get_sets(
236+
X=X, alpha_np=alpha_np, **kwargs
237+
)

mapie/conformity_scores/sets/aps.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
import numpy as np
44
from sklearn.utils import check_random_state
55
from sklearn.preprocessing import label_binarize
6+
from sklearn.model_selection import BaseCrossValidator
67

78
from mapie.conformity_scores.sets.naive import NaiveConformityScore
89
from mapie.conformity_scores.sets.utils import (
9-
check_include_last_label, check_proba_normalized
10+
check_include_last_label
1011
)
11-
from mapie.estimator.classifier import EnsembleClassifier
1212

1313
from mapie._machine_precision import EPSILON
1414
from numpy.typing import ArrayLike, NDArray
@@ -46,24 +46,28 @@ def get_predictions(
4646
self,
4747
X: NDArray,
4848
alpha_np: NDArray,
49-
estimator: EnsembleClassifier,
49+
y_pred_proba: NDArray,
50+
cv: Optional[Union[int, str, BaseCrossValidator]],
5051
agg_scores: Optional[str] = "mean",
5152
**kwargs
5253
) -> NDArray:
5354
"""
54-
Get predictions from an EnsembleClassifier.
55+
Just processes the passed y_pred_proba.
5556
5657
Parameters
5758
-----------
5859
X: NDArray of shape (n_samples, n_features)
59-
Observed feature values.
60+
Observed feature values (not used since predictions are passed).
6061
6162
alpha_np: NDArray of shape (n_alpha,)
6263
NDArray of floats between ``0`` and ``1``, represents the
6364
uncertainty of the confidence interval.
6465
65-
estimator: EnsembleClassifier
66-
Estimator that is fitted to predict y from X.
66+
y_pred_proba: NDArray
67+
Predicted probabilities from the estimator.
68+
69+
cv: Optional[Union[int, str, BaseCrossValidator]]
70+
Cross-validation strategy used by the estimator.
6771
6872
agg_scores: Optional[str]
6973
Method to aggregate the scores from the base estimators.
@@ -77,8 +81,6 @@ def get_predictions(
7781
NDArray
7882
Array of predictions.
7983
"""
80-
y_pred_proba = estimator.predict(X, agg_scores)
81-
y_pred_proba = check_proba_normalized(y_pred_proba, axis=1)
8284
if agg_scores != "crossval":
8385
y_pred_proba = np.repeat(
8486
y_pred_proba[:, :, np.newaxis], len(alpha_np), axis=2
@@ -171,7 +173,7 @@ def get_conformity_score_quantiles(
171173
self,
172174
conformity_scores: NDArray,
173175
alpha_np: NDArray,
174-
estimator: EnsembleClassifier,
176+
cv: Optional[Union[int, str, BaseCrossValidator]],
175177
agg_scores: Optional[str] = "mean",
176178
**kwargs
177179
) -> NDArray:
@@ -187,8 +189,8 @@ def get_conformity_score_quantiles(
187189
NDArray of floats between 0 and 1, representing the uncertainty
188190
of the confidence interval.
189191
190-
estimator: EnsembleClassifier
191-
Estimator that is fitted to predict y from X.
192+
cv: Optional[Union[int, str, BaseCrossValidator]]
193+
Cross-validation strategy used by the estimator.
192194
193195
agg_scores: Optional[str]
194196
Method to aggregate the scores from the base estimators.
@@ -204,7 +206,7 @@ def get_conformity_score_quantiles(
204206
"""
205207
n = len(conformity_scores)
206208

207-
if estimator.cv == "prefit" or agg_scores in ["mean"]:
209+
if cv == "prefit" or agg_scores in ["mean"]:
208210
quantiles_ = _compute_quantiles(conformity_scores, alpha_np)
209211
else:
210212
quantiles_ = (n + 1) * (1 - alpha_np)
@@ -328,7 +330,7 @@ def get_prediction_sets(
328330
y_pred_proba: NDArray,
329331
conformity_scores: NDArray,
330332
alpha_np: NDArray,
331-
estimator: EnsembleClassifier,
333+
cv: Optional[Union[int, str, BaseCrossValidator]],
332334
agg_scores: Optional[str] = "mean",
333335
include_last_label: Optional[Union[bool, str]] = True,
334336
**kwargs
@@ -349,8 +351,8 @@ def get_prediction_sets(
349351
NDArray of floats between 0 and 1, representing the uncertainty
350352
of the confidence interval.
351353
352-
estimator: EnsembleClassifier
353-
Estimator that is fitted to predict y from X.
354+
cv: Optional[Union[int, str, BaseCrossValidator]]
355+
Cross-validation strategy used by the estimator.
354356
355357
agg_scores: Optional[str]
356358
Method to aggregate the scores from the base estimators.
@@ -398,7 +400,7 @@ def get_prediction_sets(
398400
include_last_label = check_include_last_label(include_last_label)
399401

400402
# specify which thresholds will be used
401-
if estimator.cv == "prefit" or agg_scores in ["mean"]:
403+
if cv == "prefit" or agg_scores in ["mean"]:
402404
thresholds = self.quantiles_
403405
else:
404406
thresholds = conformity_scores.ravel()
@@ -414,7 +416,7 @@ def get_prediction_sets(
414416
)
415417
)
416418
# get the prediction set by taking all probabilities above the last one
417-
if estimator.cv == "prefit" or agg_scores in ["mean"]:
419+
if cv == "prefit" or agg_scores in ["mean"]:
418420
y_pred_included = np.greater_equal(
419421
y_pred_proba - y_pred_proba_last, -EPSILON
420422
)
@@ -432,7 +434,7 @@ def get_prediction_sets(
432434
thresholds,
433435
**kwargs
434436
)
435-
if estimator.cv == "prefit" or agg_scores in ["mean"]:
437+
if cv == "prefit" or agg_scores in ["mean"]:
436438
prediction_sets = y_pred_included
437439
else:
438440
# compute the number of times the inequality is verified

0 commit comments

Comments
 (0)