from typing import Any, Optional, Union

import numpy as np
from numpy.typing import ArrayLike, NDArray
from sklearn.utils import check_random_state

from mapie.control_risk.ltt import ltt_procedure
from mapie.utils import _check_n_jobs, _check_verbose

# General TODOs:
# TODO: maybe use type float instead of float32?
# TODO: in calibration and prediction,
# use _transform_pred_proba or a function adapted to binary
# to get the probabilities depending on the classifier


class BinaryClassificationController:  # pragma: no cover
    # TODO: test that this works with a sklearn pipeline
    # TODO: test that this works with pandas DataFrames
    """
    Controller that calibrates the decision threshold of a fitted binary
    classifier to satisfy a target performance level.

    Parameters
    ----------
    fitted_binary_classifier: Any
        Any fitted object that provides a `predict_proba` method.

    metric: str
        The performance metric to control (e.g. "precision").

    target_level: float
        The target performance level to achieve (e.g. 0.8).

    confidence_level: float
        The probability with which the target level must be guaranteed
        (e.g. 0.9). Equivalently, 1 - confidence_level is the maximum
        acceptable probability of the controlled metric falling below the
        target level.

    Attributes
    ----------
    precision_per_threshold: NDArray
        Precision of the binary classifier on the calibration set for each
        threshold from self._thresholds.

    valid_thresholds: NDArray
        Thresholds that meet the target precision with the desired confidence.

    best_threshold: float
        Valid threshold that maximizes the recall, i.e. the smallest valid
        threshold.
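
    Examples
    --------
    A minimal usage sketch; ``clf``, ``X_calib``, ``y_calib`` and ``X_test``
    are placeholders for a fitted binary classifier and held-out data::

        controller = BinaryClassificationController(
            fitted_binary_classifier=clf,
            metric="precision",
            target_level=0.8,
            confidence_level=0.9,
        )
        controller.calibrate(X_calib, y_calib)
        y_pred = controller.predict(X_test)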
| 50 | + """ |
| 51 | + |
| 52 | + def __init__( |
| 53 | + self, |
| 54 | + fitted_binary_classifier: Any, |
| 55 | + metric: str, |
| 56 | + target_level: float, |
| 57 | + confidence_level: float = 0.9, |
| 58 | + n_jobs: Optional[int] = None, |
| 59 | + random_state: Optional[Union[int, np.random.RandomState]] = None, |
| 60 | + verbose: int = 0 |
| 61 | + ): |
| 62 | + _check_n_jobs(n_jobs) |
| 63 | + _check_verbose(verbose) |
| 64 | + check_random_state(random_state) |
| 65 | + |
| 66 | + self._classifier = fitted_binary_classifier |
| 67 | + self._alpha = 1 - target_level |
| 68 | + self._delta = 1 - confidence_level |
| 69 | + self._n_jobs = n_jobs # TODO : use this in the class or delete |
| 70 | + self._random_state = random_state # TODO : use this in the class or delete |
| 71 | + self._verbose = verbose # TODO : use this in the class or delete |
| 72 | + |
| 73 | + self._thresholds: NDArray[np.float32] = np.arange(0, 1, 0.01) |
| 74 | + # TODO: add a _is_calibrated attribute to check at prediction time |
| 75 | + |
| 76 | + self.valid_thresholds: Optional[NDArray[np.float32]] = None |
| 77 | + self.best_threshold: Optional[float] = None |
| 78 | + |
    def calibrate(self, X_calibrate: ArrayLike, y_calibrate: ArrayLike) -> None:
        """
        Find the threshold that statistically guarantees the desired precision
        level while maximizing the recall.

        Parameters
        ----------
        X_calibrate: ArrayLike
            Features of the calibration set.

        y_calibrate: ArrayLike
            True labels of the calibration set.

        Raises
        ------
        ValueError
            If no thresholds that meet the target precision with the desired
            confidence level are found.
        """
        # TODO: Make sure this works with sklearn train_test_split/Series
        y_calibrate_ = np.asarray(y_calibrate)

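        # Probability of the positive class for each calibration sample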
        predictions_proba = self._classifier.predict_proba(X_calibrate)[:, 1]

        risk_per_threshold = 1 - self._compute_precision(
            predictions_proba, y_calibrate_
        )

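        # Learn Then Test procedure: keep only the thresholds whose observed
        # risk (1 - precision) is statistically below alpha at confidence
        # level 1 - delta, with family-wise error rate control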
        valid_thresholds_index, _ = ltt_procedure(
            risk_per_threshold,
            np.array([self._alpha]),
            self._delta,
            len(y_calibrate_),
            True,
        )
        self.valid_thresholds = self._thresholds[valid_thresholds_index[0]]
        if len(self.valid_thresholds) == 0:
            # TODO: just warn, and raise error at prediction if no valid thresholds
            raise ValueError("No valid thresholds found")

        # Minimum in case of precision control only
        self.best_threshold = min(self.valid_thresholds)

    def predict(self, X_test: ArrayLike) -> NDArray:
        """
        Predict binary labels on the test set, using the best threshold found
        during calibration.

        Parameters
        ----------
        X_test: ArrayLike
            Features of the test set.

        Returns
        -------
        NDArray
            Predicted labels (0 or 1) for each sample in the test set.
        """
        predictions_proba = self._classifier.predict_proba(X_test)[:, 1]
        return (predictions_proba >= self.best_threshold).astype(int)

    def _compute_precision(  # TODO: use sklearn or MAPIE?
        self, predictions_proba: NDArray[np.float32], y_cal: NDArray[np.float32]
    ) -> NDArray[np.float32]:
        """
        Compute the precision for each threshold.
        """
        predictions_per_threshold = (
            predictions_proba[:, np.newaxis] >= self._thresholds
        ).astype(int)

        true_positives = np.sum(
            (predictions_per_threshold == 1) & (y_cal[:, np.newaxis] == 1),
            axis=0,
        )
        false_positives = np.sum(
            (predictions_per_threshold == 1) & (y_cal[:, np.newaxis] == 0),
            axis=0,
        )

        positive_predictions = true_positives + false_positives

        # Avoid division by zero: precision defaults to 1 for thresholds that
        # yield no positive predictions
        precision_per_threshold = np.ones_like(self._thresholds, dtype=float)
        nonzero_mask = positive_predictions > 0
        precision_per_threshold[nonzero_mask] = (
            true_positives[nonzero_mask] / positive_predictions[nonzero_mask]
        )

        return precision_per_threshold