Skip to content

Commit 0c8c12d

Browse files
ENH & MTN
- Use BinaryClassificationRisk to compute risk - Use warning instead of error when risk is not controlled. Throw error when predicting - Remove useless check on lambda=None in ltt_procedure - Remove useless p_values from ltt_procedure outputs - Add possibility to pass an array of n_obs to ltt_procedure and subsequent p-values calculations (needed for binary classification)
1 parent 065921c commit 0c8c12d

File tree

6 files changed

+98
-179
lines changed

6 files changed

+98
-179
lines changed

mapie/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
regression,
55
utils,
66
risk_control,
7-
risk_control_draft,
87
calibration,
98
subsample,
109
)
@@ -14,7 +13,6 @@
1413
"regression",
1514
"classification",
1615
"risk_control",
17-
"risk_control_draft",
1816
"calibration",
1917
"metrics",
2018
"utils",

mapie/control_risk/ltt.py

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import warnings
2-
from typing import Any, List, Optional, Tuple
2+
from typing import Any, List, Optional, Tuple, Union
33

44
import numpy as np
55

@@ -9,29 +9,26 @@
99

1010

1111
def ltt_procedure(
12-
r_hat: NDArray[np.float32],
13-
alpha_np: NDArray[np.float32],
14-
delta: Optional[float],
15-
n_obs: int,
16-
binary: bool = False, # TODO: maybe should pass p_values fonction instead
17-
) -> Tuple[List[List[Any]], NDArray[np.float32]]:
12+
r_hat: NDArray[float],
13+
alpha_np: NDArray[float],
14+
delta: float,
15+
n_obs: Union[int, NDArray],
16+
binary: bool = False,
17+
) -> List[List[Any]]:
1818
"""
1919
Apply the Learn-Then-Test procedure for risk control.
2020
Note that we will do a multiple test for ``r_hat`` that are
2121
less than level ``alpha_np``.
2222
The procedure follows the instructions in [1]:
23-
- Calculate p-values for each lambdas descretized
24-
- Apply a family wise error rate algorithm,
25-
here Bonferonni correction
26-
- Return the index lambdas that give you the control
27-
at alpha level
23+
- Calculate p-values for each lambdas discretized
24+
- Apply a family-wise error rate algorithm, here Bonferroni correction
25+
- Return the index lambdas that give you the control at alpha level
2826
2927
Parameters
3028
----------
3129
r_hat: NDArray of shape (n_lambdas, ).
32-
Empirical risk with respect
33-
to the lambdas.
34-
Here lambdas are thresholds that impact decision making,
30+
Empirical risk with respect to the lambdas.
31+
Here lambdas are thresholds that impact decision-making,
3532
therefore empirical risk.
3633
3734
alpha_np: NDArray of shape (n_alpha, ).
@@ -44,34 +41,34 @@ def ltt_procedure(
4441
Correspond to proportion of failure we don't
4542
want to exceed.
4643
44+
n_obs: Union[int, NDArray]
45+
Corresponds to the number of observations used to compute the risk.
46+
In the case of a conditional loss, n_obs must be the
47+
number of effective observations used to compute the empirical risk
48+
for each lambda, hence of shape (n_lambdas, ).
49+
50+
binary: bool, default=False
51+
Must be True if the loss associated to the risk is binary.
52+
4753
Returns
4854
-------
4955
valid_index: List[List[Any]].
50-
Contain the valid index that satisfy fwer control
56+
Contains the valid indices that satisfy FWER control
5157
for each alpha (length aren't the same for each alpha).
5258
53-
p_values: NDArray of shape (n_lambda, n_alpha).
54-
Contains the values of p_value for different alpha.
55-
5659
References
5760
----------
5861
[1] Angelopoulos, A. N., Bates, S., Candès, E. J., Jordan,
5962
M. I., & Lei, L. (2021). Learn then test:
6063
"Calibrating predictive algorithms to achieve risk control".
6164
"""
62-
if delta is None:
63-
raise ValueError(
64-
"Invalid delta: delta cannot be None while"
65-
+ " controlling precision with LTT. "
66-
)
6765
p_values = compute_hoeffdding_bentkus_p_value(r_hat, n_obs, alpha_np, binary)
6866
N = len(p_values)
6967
valid_index = []
7068
for i in range(len(alpha_np)):
7169
l_index = np.where(p_values[:, i] <= delta/N)[0].tolist()
7270
valid_index.append(l_index)
73-
return valid_index, p_values # TODO : p_values is not used, we could remove it
74-
# Or return corrected p_values
71+
return valid_index
7572

7673

7774
def find_lambda_control_star(

mapie/control_risk/p_values.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88

99

1010
def compute_hoeffdding_bentkus_p_value(
11-
r_hat: NDArray[np.float32],
12-
n_obs: int,
13-
alpha: Union[float, NDArray[np.float32]],
11+
r_hat: NDArray[float],
12+
n_obs: Union[int, NDArray],
13+
alpha: Union[float, NDArray[float]],
1414
binary: bool = False,
15-
) -> NDArray[np.float32]:
15+
) -> NDArray[float]:
1616
"""
1717
The method computes the p_values according to
1818
the Hoeffding_Bentkus inequality for each
@@ -30,16 +30,23 @@ def compute_hoeffdding_bentkus_p_value(
3030
Here lambdas are thresholds that impact decision
3131
making and therefore empirical risk.
3232
33-
n_obs: int.
34-
Correspond to the number of observations in
35-
dataset.
33+
n_obs: Union[int, NDArray]
34+
Corresponds to the number of observations used to compute the risk.
35+
In the case of a conditional loss, n_obs must be the
36+
number of effective observations used to compute the empirical risk
37+
for each lambda, hence of shape (n_lambdas, ).
3638
3739
alpha: Union[float, Iterable[float]].
3840
Contains the different alphas control level.
3941
The empirical risk must be less than alpha.
4042
If it is a iterable, it is a NDArray of shape
4143
(n_alpha, ).
4244
45+
binary: bool, default=False
46+
Must be True if the loss associated to the risk is binary.
47+
If True, we use a tighter version of the Bentkus p-value, valid when the
48+
loss associated to the risk is binary. See section 3.2 of [1].
49+
4350
Returns
4451
-------
4552
hb_p_values: NDArray of shape (n_lambda, n_alpha).
@@ -62,9 +69,17 @@ def compute_hoeffdding_bentkus_p_value(
6269
len(r_hat),
6370
axis=0
6471
)
72+
if isinstance(n_obs, int):
73+
n_obs = np.full_like(r_hat, n_obs, dtype=float)
74+
n_obs_repeat = np.repeat(
75+
np.expand_dims(n_obs, axis=1),
76+
len(alpha_np),
77+
axis=1
78+
)
79+
6580
hoeffding_p_value = np.exp(
66-
-n_obs * _h1(
67-
np.where( # TODO : shouldn't we use np.minimum ?
81+
-n_obs_repeat * _h1(
82+
np.where(
6883
r_hat_repeat > alpha_repeat,
6984
alpha_repeat,
7085
r_hat_repeat
@@ -74,9 +89,9 @@ def compute_hoeffdding_bentkus_p_value(
7489
)
7590
factor = 1 if binary else np.e
7691
bentkus_p_value = factor * binom.cdf(
77-
np.ceil(n_obs * r_hat_repeat), n_obs, alpha_repeat
92+
np.ceil(n_obs_repeat * r_hat_repeat), n_obs, alpha_repeat
7893
)
79-
hb_p_value = np.where( # TODO : shouldn't we use np.minimum ?
94+
hb_p_value = np.where(
8095
bentkus_p_value > hoeffding_p_value,
8196
hoeffding_p_value,
8297
bentkus_p_value
@@ -85,8 +100,8 @@ def compute_hoeffdding_bentkus_p_value(
85100

86101

87102
def _h1(
88-
r_hats: NDArray[np.float32], alphas: NDArray[np.float32]
89-
) -> NDArray[np.float32]:
103+
r_hats: NDArray[float], alphas: NDArray[float]
104+
) -> NDArray[float]:
90105
"""
91106
This function allow us to compute
92107
the tighter version of hoeffding inequality.
@@ -113,7 +128,7 @@ def _h1(
113128
114129
Returns
115130
-------
116-
NDArray of shape a(n_lambdas, n_alpha).
131+
NDArray of shape (n_lambdas, n_alpha).
117132
"""
118133
elt1 = np.zeros_like(r_hats, dtype=float)
119134

mapie/risk_control.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -681,8 +681,8 @@ def predict(
681681
if self.metric_control == 'precision':
682682
self.n_obs = len(self.risks)
683683
self.r_hat = self.risks.mean(axis=0)
684-
self.valid_index, self.p_values = ltt_procedure(
685-
self.r_hat, alpha_np, delta, self.n_obs
684+
self.valid_index = ltt_procedure(
685+
self.r_hat, alpha_np, cast(float, delta), self.n_obs
686686
)
687687
self._check_valid_index(alpha_np)
688688
self.lambdas_star, self.r_star = find_lambda_control_star(
@@ -724,8 +724,8 @@ def __init__(
724724

725725
def get_value_and_effective_sample_size(
726726
self,
727-
y_true: NDArray[int], # shape (n_samples,), values in {0, 1}
728-
y_pred: NDArray[int], # shape (n_samples,), values in {0, 1}
727+
y_true: NDArray[int], # shape (n_samples,), values in {0, 1}
728+
y_pred: NDArray[int], # shape (n_samples,), values in {0, 1}
729729
) -> Optional[Tuple[float, int]]:
730730
# float between 0 and 1, int between 0 and len(y_true)
731731
risk_occurrences = [
@@ -765,4 +765,10 @@ def get_value_and_effective_sample_size(
765765
risk_occurrence=lambda y_true, y_pred: int(y_pred == y_true),
766766
risk_condition=lambda y_true, y_pred: y_true == 1,
767767
higher_is_better=True,
768-
)
768+
)
769+
770+
_automatic_best_predict_param_choice = {
771+
precision: recall,
772+
recall: precision,
773+
accuracy: accuracy,
774+
}

0 commit comments

Comments
 (0)