3
3
import numpy as np
4
4
from sklearn .utils import check_random_state
5
5
from sklearn .preprocessing import label_binarize
6
+ from sklearn .model_selection import BaseCrossValidator
6
7
7
8
from mapie .conformity_scores .sets .naive import NaiveConformityScore
8
9
from mapie .conformity_scores .sets .utils import (
9
- check_include_last_label , check_proba_normalized
10
+ check_include_last_label
10
11
)
11
- from mapie .estimator .classifier import EnsembleClassifier
12
12
13
13
from mapie ._machine_precision import EPSILON
14
14
from numpy .typing import ArrayLike , NDArray
@@ -46,24 +46,28 @@ def get_predictions(
46
46
self ,
47
47
X : NDArray ,
48
48
alpha_np : NDArray ,
49
- estimator : EnsembleClassifier ,
49
+ y_pred_proba : NDArray ,
50
+ cv : Optional [Union [int , str , BaseCrossValidator ]],
50
51
agg_scores : Optional [str ] = "mean" ,
51
52
** kwargs
52
53
) -> NDArray :
53
54
"""
54
- Get predictions from an EnsembleClassifier .
55
+ Just processes the passed y_pred_proba .
55
56
56
57
Parameters
57
58
-----------
58
59
X: NDArray of shape (n_samples, n_features)
59
- Observed feature values.
60
+ Observed feature values (not used since predictions are passed) .
60
61
61
62
alpha_np: NDArray of shape (n_alpha,)
62
63
NDArray of floats between ``0`` and ``1``, represents the
63
64
uncertainty of the confidence interval.
64
65
65
- estimator: EnsembleClassifier
66
- Estimator that is fitted to predict y from X.
66
+ y_pred_proba: NDArray
67
+ Predicted probabilities from the estimator.
68
+
69
+ cv: Optional[Union[int, str, BaseCrossValidator]]
70
+ Cross-validation strategy used by the estimator.
67
71
68
72
agg_scores: Optional[str]
69
73
Method to aggregate the scores from the base estimators.
@@ -77,8 +81,6 @@ def get_predictions(
77
81
NDArray
78
82
Array of predictions.
79
83
"""
80
- y_pred_proba = estimator .predict (X , agg_scores )
81
- y_pred_proba = check_proba_normalized (y_pred_proba , axis = 1 )
82
84
if agg_scores != "crossval" :
83
85
y_pred_proba = np .repeat (
84
86
y_pred_proba [:, :, np .newaxis ], len (alpha_np ), axis = 2
@@ -171,7 +173,7 @@ def get_conformity_score_quantiles(
171
173
self ,
172
174
conformity_scores : NDArray ,
173
175
alpha_np : NDArray ,
174
- estimator : EnsembleClassifier ,
176
+ cv : Optional [ Union [ int , str , BaseCrossValidator ]] ,
175
177
agg_scores : Optional [str ] = "mean" ,
176
178
** kwargs
177
179
) -> NDArray :
@@ -187,8 +189,8 @@ def get_conformity_score_quantiles(
187
189
NDArray of floats between 0 and 1, representing the uncertainty
188
190
of the confidence interval.
189
191
190
- estimator: EnsembleClassifier
191
- Estimator that is fitted to predict y from X .
192
+ cv: Optional[Union[int, str, BaseCrossValidator]]
193
+ Cross-validation strategy used by the estimator .
192
194
193
195
agg_scores: Optional[str]
194
196
Method to aggregate the scores from the base estimators.
@@ -204,7 +206,7 @@ def get_conformity_score_quantiles(
204
206
"""
205
207
n = len (conformity_scores )
206
208
207
- if estimator . cv == "prefit" or agg_scores in ["mean" ]:
209
+ if cv == "prefit" or agg_scores in ["mean" ]:
208
210
quantiles_ = _compute_quantiles (conformity_scores , alpha_np )
209
211
else :
210
212
quantiles_ = (n + 1 ) * (1 - alpha_np )
@@ -328,7 +330,7 @@ def get_prediction_sets(
328
330
y_pred_proba : NDArray ,
329
331
conformity_scores : NDArray ,
330
332
alpha_np : NDArray ,
331
- estimator : EnsembleClassifier ,
333
+ cv : Optional [ Union [ int , str , BaseCrossValidator ]] ,
332
334
agg_scores : Optional [str ] = "mean" ,
333
335
include_last_label : Optional [Union [bool , str ]] = True ,
334
336
** kwargs
@@ -349,8 +351,8 @@ def get_prediction_sets(
349
351
NDArray of floats between 0 and 1, representing the uncertainty
350
352
of the confidence interval.
351
353
352
- estimator: EnsembleClassifier
353
- Estimator that is fitted to predict y from X .
354
+ cv: Optional[Union[int, str, BaseCrossValidator]]
355
+ Cross-validation strategy used by the estimator .
354
356
355
357
agg_scores: Optional[str]
356
358
Method to aggregate the scores from the base estimators.
@@ -398,7 +400,7 @@ def get_prediction_sets(
398
400
include_last_label = check_include_last_label (include_last_label )
399
401
400
402
# specify which thresholds will be used
401
- if estimator . cv == "prefit" or agg_scores in ["mean" ]:
403
+ if cv == "prefit" or agg_scores in ["mean" ]:
402
404
thresholds = self .quantiles_
403
405
else :
404
406
thresholds = conformity_scores .ravel ()
@@ -414,7 +416,7 @@ def get_prediction_sets(
414
416
)
415
417
)
416
418
# get the prediction set by taking all probabilities above the last one
417
- if estimator . cv == "prefit" or agg_scores in ["mean" ]:
419
+ if cv == "prefit" or agg_scores in ["mean" ]:
418
420
y_pred_included = np .greater_equal (
419
421
y_pred_proba - y_pred_proba_last , - EPSILON
420
422
)
@@ -432,7 +434,7 @@ def get_prediction_sets(
432
434
thresholds ,
433
435
** kwargs
434
436
)
435
- if estimator . cv == "prefit" or agg_scores in ["mean" ]:
437
+ if cv == "prefit" or agg_scores in ["mean" ]:
436
438
prediction_sets = y_pred_included
437
439
else :
438
440
# compute the number of times the inequality is verified
0 commit comments