
Commit 37c79dc

ENH: merge prototype with segmentation risk control code (#726)
Parent: afc234d

File tree: 4 files changed (+195, -17 lines)


mapie/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,7 @@
     regression,
     utils,
     risk_control,
+    risk_control_draft,
     calibration,
     subsample,
 )
@@ -13,6 +14,7 @@
     "regression",
     "classification",
     "risk_control",
+    "risk_control_draft",
     "calibration",
     "metrics",
     "utils",

mapie/control_risk/ltt.py

Lines changed: 8 additions & 6 deletions
@@ -9,11 +9,12 @@


 def ltt_procedure(
-    r_hat: NDArray,
-    alpha_np: NDArray,
+    r_hat: NDArray[np.float32],
+    alpha_np: NDArray[np.float32],
     delta: Optional[float],
-    n_obs: int
-) -> Tuple[List[List[Any]], NDArray]:
+    n_obs: int,
+    binary: bool = False,  # TODO: maybe should pass a p_values function instead
+) -> Tuple[List[List[Any]], NDArray[np.float32]]:
     """
     Apply the Learn-Then-Test procedure for risk control.
     Note that we will do a multiple test for ``r_hat`` that are
@@ -63,13 +64,14 @@ def ltt_procedure(
             "Invalid delta: delta cannot be None while"
             + " controlling precision with LTT. "
         )
-    p_values = compute_hoeffdding_bentkus_p_value(r_hat, n_obs, alpha_np)
+    p_values = compute_hoeffdding_bentkus_p_value(r_hat, n_obs, alpha_np, binary)
     N = len(p_values)
     valid_index = []
     for i in range(len(alpha_np)):
         l_index = np.where(p_values[:, i] <= delta/N)[0].tolist()
         valid_index.append(l_index)
-    return valid_index, p_values
+    return valid_index, p_values  # TODO: p_values is not used, we could remove it
+    # Or return corrected p_values


 def find_lambda_control_star(
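
A hypothetical usage sketch of the updated signature (the risk values and sizes below are illustrative, not from the commit): ltt_procedure takes the empirical risk of each candidate lambda, one or more target risk levels alpha, the tolerated failure probability delta, the calibration set size, and now the binary flag forwarded to the p-value computation; each p-value is then compared against the Bonferroni-corrected threshold delta / n_lambdas.

    import numpy as np
    from mapie.control_risk.ltt import ltt_procedure

    r_hat = np.array([0.30, 0.20, 0.12, 0.05])  # empirical risk per candidate lambda
    alpha = np.array([0.2])                     # target risk level(s)
    delta = 0.1                                 # tolerated failure probability
    n_obs = 500                                 # calibration set size

    valid_index, p_values = ltt_procedure(r_hat, alpha, delta, n_obs, binary=True)
    # valid_index[0] lists the indices of the lambdas whose p-value passes
    # the corrected threshold delta / len(p_values)
    print(valid_index[0])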

mapie/control_risk/p_values.py

Lines changed: 17 additions & 11 deletions
@@ -8,10 +8,11 @@


 def compute_hoeffdding_bentkus_p_value(
-    r_hat: NDArray,
+    r_hat: NDArray[np.float32],
     n_obs: int,
-    alpha: Union[float, NDArray]
-) -> NDArray:
+    alpha: Union[float, NDArray[np.float32]],
+    binary: bool = False,
+) -> NDArray[np.float32]:
     """
     The method computes the p_values according to
     the Hoeffding_Bentkus inequality for each
@@ -63,18 +64,19 @@ def compute_hoeffdding_bentkus_p_value(
     )
     hoeffding_p_value = np.exp(
         -n_obs * _h1(
-            np.where(
+            np.where(  # TODO: shouldn't we use np.minimum?
                 r_hat_repeat > alpha_repeat,
                 alpha_repeat,
                 r_hat_repeat
             ),
             alpha_repeat
         )
     )
-    bentkus_p_value = np.e * binom.cdf(
+    factor = 1 if binary else np.e
+    bentkus_p_value = factor * binom.cdf(
         np.ceil(n_obs * r_hat_repeat), n_obs, alpha_repeat
     )
-    hb_p_value = np.where(
+    hb_p_value = np.where(  # TODO: shouldn't we use np.minimum?
         bentkus_p_value > hoeffding_p_value,
         hoeffding_p_value,
         bentkus_p_value
@@ -83,9 +85,8 @@ def compute_hoeffdding_bentkus_p_value(


 def _h1(
-    r_hats: NDArray,
-    alphas: NDArray
-) -> NDArray:
+    r_hats: NDArray[np.float32], alphas: NDArray[np.float32]
+) -> NDArray[np.float32]:
     """
     This function allow us to compute
     the tighter version of hoeffding inequality.
@@ -114,6 +115,11 @@ def _h1(
     -------
     NDArray of shape a(n_lambdas, n_alpha).
     """
-    elt1 = r_hats * np.log(r_hats/alphas)
-    elt2 = (1-r_hats) * np.log((1-r_hats)/(1-alphas))
+    elt1 = np.zeros_like(r_hats, dtype=float)
+
+    # Compute only where r_hats != 0 to avoid log(0)
+    # TODO: check Angelopoulos implementation
+    mask = r_hats != 0
+    elt1[mask] = r_hats[mask] * np.log(r_hats[mask] / alphas[mask])
+    elt2 = (1 - r_hats) * np.log((1 - r_hats) / (1 - alphas))
     return elt1 + elt2
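
For reference, a standalone scalar sketch of the bound combined above, assuming only numpy and scipy (it mirrors the diff but is not the library function itself):

    import numpy as np
    from scipy.stats import binom

    def hb_p_value(r_hat: float, n_obs: int, alpha: float, binary: bool = False) -> float:
        # Tighter Hoeffding bound, evaluated at min(r_hat, alpha); the r == 0
        # branch avoids log(0), as in the masked implementation above
        r = min(r_hat, alpha)
        kl = (1 - r) * np.log((1 - r) / (1 - alpha))
        if r > 0:
            kl += r * np.log(r / alpha)
        hoeffding = np.exp(-n_obs * kl)
        # Bentkus bound; the new `binary` flag drops the factor e for binary losses
        factor = 1.0 if binary else np.e
        bentkus = factor * binom.cdf(np.ceil(n_obs * r_hat), n_obs, alpha)
        return min(hoeffding, bentkus)

    print(hb_p_value(r_hat=0.12, n_obs=500, alpha=0.2, binary=True))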

mapie/risk_control_draft.py

Lines changed: 168 additions & 0 deletions
@@ -0,0 +1,168 @@
+from typing import Any, Optional, Union
+
+import numpy as np
+from numpy._typing import ArrayLike, NDArray
+from sklearn.utils import check_random_state
+
+from mapie.control_risk.ltt import ltt_procedure
+from mapie.utils import _check_n_jobs, _check_verbose
+
+# General TODOs:
+# TODO: maybe use type float instead of float32?
+# TODO: in calibration and prediction,
+# use _transform_pred_proba or a function adapted to binary
+# to get the probabilities depending on the classifier
+
+
+class BinaryClassificationController:  # pragma: no cover
+    # TODO: test that this is working with a sklearn pipeline
+    # TODO: test that this is working with pandas DataFrames
+    """
+    Controller for the calibration of our binary classifier.
+
+    Parameters
+    ----------
+    fitted_binary_classifier: Any
+        Any object that provides a `predict_proba` method.
+
+    metric: str
+        The performance metric we want to control (ex: "precision")
+
+    target_level: float
+        The target performance level we want to achieve (ex: 0.8)
+
+    confidence_level: float
+        The probability with which the precision must stay above the target
+        level (ex: 0.9); 1 - confidence_level is the maximum acceptable failure probability
+
+    Attributes
+    ----------
+    precision_per_threshold: NDArray
+        Precision of the binary classifier on the calibration set for each
+        threshold from self._thresholds.
+
+    valid_thresholds: NDArray
+        Thresholds that meet the target precision with the desired confidence.
+
+    best_threshold: float
+        Valid threshold that maximizes the recall, i.e. the smallest valid
+        threshold.
+    """
+
+    def __init__(
+        self,
+        fitted_binary_classifier: Any,
+        metric: str,
+        target_level: float,
+        confidence_level: float = 0.9,
+        n_jobs: Optional[int] = None,
+        random_state: Optional[Union[int, np.random.RandomState]] = None,
+        verbose: int = 0
+    ):
+        _check_n_jobs(n_jobs)
+        _check_verbose(verbose)
+        check_random_state(random_state)
+
+        self._classifier = fitted_binary_classifier
+        self._alpha = 1 - target_level
+        self._delta = 1 - confidence_level
+        self._n_jobs = n_jobs  # TODO: use this in the class or delete
+        self._random_state = random_state  # TODO: use this in the class or delete
+        self._verbose = verbose  # TODO: use this in the class or delete
+
+        self._thresholds: NDArray[np.float32] = np.arange(0, 1, 0.01)
+        # TODO: add a _is_calibrated attribute to check at prediction time
+
+        self.valid_thresholds: Optional[NDArray[np.float32]] = None
+        self.best_threshold: Optional[float] = None
+
+    def calibrate(self, X_calibrate: ArrayLike, y_calibrate: ArrayLike) -> None:
+        """
+        Find the threshold that statistically guarantees the desired precision
+        level while maximizing the recall.
+
+        Parameters
+        ----------
+        X_calibrate: ArrayLike
+            Features of the calibration set.
+
+        y_calibrate: ArrayLike
+            True labels of the calibration set.
+
+        Raises
+        ------
+        ValueError
+            If no thresholds that meet the target precision with the desired
+            confidence level are found.
+        """
+        # TODO: Make sure this works with sklearn train_test_split/Series
+        y_calibrate_ = np.asarray(y_calibrate)
+
+        predictions_proba = self._classifier.predict_proba(X_calibrate)[:, 1]
+
+        risk_per_threshold = 1 - self._compute_precision(
+            predictions_proba, y_calibrate_
+        )
+
+        valid_thresholds_index, _ = ltt_procedure(
+            risk_per_threshold,
+            np.array([self._alpha]),
+            self._delta,
+            len(y_calibrate_),
+            True,
+        )
+        self.valid_thresholds = self._thresholds[valid_thresholds_index[0]]
+        if len(self.valid_thresholds) == 0:
+            # TODO: just warn, and raise error at prediction if no valid thresholds
+            raise ValueError("No valid thresholds found")
+
+        # Minimum in case of precision control only
+        self.best_threshold = min(self.valid_thresholds)
+
+    def predict(self, X_test: ArrayLike) -> NDArray:
+        """
+        Predict binary labels on the test set, using the best threshold found
+        during calibration.
+
+        Parameters
+        ----------
+        X_test: ArrayLike
+            Features of the test set.
+
+        Returns
+        -------
+        NDArray
+            Predicted labels (0 or 1) for each sample in the test set.
+        """
+        predictions_proba = self._classifier.predict_proba(X_test)[:, 1]
+        return (predictions_proba >= self.best_threshold).astype(int)
+
+    def _compute_precision(  # TODO: use sklearn or MAPIE?
+        self, predictions_proba: NDArray[np.float32], y_cal: NDArray[np.float32]
+    ) -> NDArray[np.float32]:
+        """
+        Compute the precision for each threshold.
+        """
+        predictions_per_threshold = (
+            predictions_proba[:, np.newaxis] >= self._thresholds
+        ).astype(int)
+
+        true_positives = np.sum(
+            (predictions_per_threshold == 1) & (y_cal[:, np.newaxis] == 1),
+            axis=0,
+        )
+        false_positives = np.sum(
+            (predictions_per_threshold == 1) & (y_cal[:, np.newaxis] == 0),
+            axis=0,
+        )
+
+        positive_predictions = true_positives + false_positives
+
+        # Avoid division by zero
+        precision_per_threshold = np.ones_like(self._thresholds, dtype=float)
+        nonzero_mask = positive_predictions > 0
+        precision_per_threshold[nonzero_mask] = (
+            true_positives[nonzero_mask] / positive_predictions[nonzero_mask]
+        )
+
+        return precision_per_threshold
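
A hedged end-to-end sketch of the draft API (the dataset, split sizes, and estimator below are illustrative and not part of the commit):

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import train_test_split
    from mapie.risk_control_draft import BinaryClassificationController

    X, y = make_classification(n_samples=2000, random_state=0)
    X_train, X_rest, y_train, y_rest = train_test_split(X, y, test_size=0.5, random_state=0)
    X_calib, X_test, y_calib, y_test = train_test_split(X_rest, y_rest, test_size=0.5, random_state=0)

    clf = LogisticRegression().fit(X_train, y_train)

    controller = BinaryClassificationController(
        fitted_binary_classifier=clf,
        metric="precision",      # only precision control is handled in this draft
        target_level=0.8,        # precision to guarantee
        confidence_level=0.9,    # probability with which the guarantee must hold
    )
    controller.calibrate(X_calib, y_calib)  # raises ValueError if no threshold is valid
    y_pred = controller.predict(X_test)
    print(controller.best_threshold, y_pred[:10])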
