Merged
26 commits
91daf3e
ENH: implement BinaryClassificationRisk and related instances
Valentin-Laurent Jul 23, 2025
ded36fe
ENH: simplify BinaryClassificationRisk API
Valentin-Laurent Jul 25, 2025
f404ffe
ENH & MTN
Valentin-Laurent Jul 29, 2025
e8dae57
ENH & MTN & FIX
Valentin-Laurent Jul 30, 2025
a580aa4
TEST - hoeffdding_bentkus_p_value with n_obs as an array
Valentin-Laurent Jul 30, 2025
9e8b092
FIX - linting
Valentin-Laurent Jul 30, 2025
5fbc940
ENH - Performance, warning and docstring improvements
Valentin-Laurent Jul 30, 2025
cc88354
FIX - Fix local typing issue, investigate CI typing issues
Valentin-Laurent Jul 30, 2025
feb075d
FIX - Continue investigating CI typing issues
Valentin-Laurent Jul 30, 2025
bf28de0
MTN - remove relative import
Valentin-Laurent Jul 31, 2025
09e8751
ENH & TEST - Handle the case of undefined risk (ex: precision with no…
Valentin-Laurent Jul 31, 2025
1f18795
MTN - Revert formatting to avoid changes unrelated to current PR
Valentin-Laurent Aug 1, 2025
e661adb
MTN - Clarify code
Valentin-Laurent Aug 1, 2025
f232d5d
TEST - Fix test following handling of undefined risk
Valentin-Laurent Aug 27, 2025
df343ca
FIX - Fix typing issues in Python 3.9, revert CI back to normal
Valentin-Laurent Aug 27, 2025
8bf31fa
WIP - try to fix typing (can't reproduce locally)
Valentin-Laurent Aug 27, 2025
c819bcd
WIP - try to fix typing (can't reproduce locally)
Valentin-Laurent Aug 27, 2025
6b4fff5
WIP - try to fix typing (can't reproduce locally)
Valentin-Laurent Aug 27, 2025
e428c89
WIP - try to fix typing (can't reproduce locally)
Valentin-Laurent Aug 27, 2025
78017be
WIP - try to fix typing (can't reproduce locally)
Valentin-Laurent Aug 27, 2025
54aac1e
WIP - try to fix typing (can't reproduce locally)
Valentin-Laurent Aug 27, 2025
d8e615f
WIP - try to fix typing (can't reproduce locally)
Valentin-Laurent Aug 28, 2025
dda38d5
ENH - Add theoretical validity notebook to documentation
Valentin-Laurent Aug 28, 2025
42f48b2
FIX - Fix theoretical validity notebook
Valentin-Laurent Aug 28, 2025
a3a1ec2
FIX - Fix implementation error in BinaryClassificationController, imp…
Valentin-Laurent Aug 29, 2025
b96d0a1
ENH - Final tweaks to theoretical_validity_tests.ipynb
Valentin-Laurent Aug 29, 2025
2 changes: 0 additions & 2 deletions mapie/__init__.py
@@ -4,7 +4,6 @@
regression,
utils,
risk_control,
risk_control_draft,
calibration,
subsample,
)
@@ -14,7 +13,6 @@
"regression",
"classification",
"risk_control",
"risk_control_draft",
"calibration",
"metrics",
"utils",
51 changes: 24 additions & 27 deletions mapie/control_risk/ltt.py
@@ -1,37 +1,34 @@
import warnings
from typing import Any, List, Optional, Tuple
from typing import Any, List, Tuple, Union

import numpy as np

from numpy.typing import ArrayLike, NDArray

from .p_values import compute_hoeffdding_bentkus_p_value
from mapie.control_risk.p_values import compute_hoeffdding_bentkus_p_value


def ltt_procedure(
r_hat: NDArray[np.float32],
alpha_np: NDArray[np.float32],
delta: Optional[float],
n_obs: int,
binary: bool = False, # TODO: maybe should pass p_values fonction instead
) -> Tuple[List[List[Any]], NDArray[np.float32]]:
r_hat: NDArray,
alpha_np: NDArray,
delta: float,
n_obs: Union[int, NDArray],
binary: bool = False,
) -> List[List[Any]]:
"""
Apply the Learn-Then-Test procedure for risk control.
Note that we will do a multiple test for ``r_hat`` that are
less than level ``alpha_np``.
The procedure follows the instructions in [1]:
- Calculate p-values for each lambdas descretized
- Apply a family wise error rate algorithm,
here Bonferonni correction
- Return the index lambdas that give you the control
at alpha level
- Calculate p-values for each lambdas discretized
- Apply a family wise error rate algorithm, here Bonferonni correction
- Return the index lambdas that give you the control at alpha level

Parameters
----------
r_hat: NDArray of shape (n_lambdas, ).
Empirical risk with respect
to the lambdas.
Here lambdas are thresholds that impact decision making,
Empirical risk with respect to the lambdas.
Here lambdas are thresholds that impact decision-making,
therefore empirical risk.

alpha_np: NDArray of shape (n_alpha, ).
@@ -44,34 +41,34 @@ def ltt_procedure(
Correspond to proportion of failure we don't
want to exceed.

n_obs: Union[int, NDArray]
Correspond to the number of observations used to compute the risk.
In the case of a conditional loss, n_obs must be the
number of effective observations used to compute the empirical risk
for each lambda, hence of shape (n_lambdas, ).

binary: bool, default=False
Must be True if the loss associated to the risk is binary.

Returns
-------
valid_index: List[List[Any]].
Contain the valid index that satisfy fwer control
Contain the valid index that satisfy FWER control
for each alpha (length aren't the same for each alpha).

p_values: NDArray of shape (n_lambda, n_alpha).
Contains the values of p_value for different alpha.

References
----------
[1] Angelopoulos, A. N., Bates, S., Candès, E. J., Jordan,
M. I., & Lei, L. (2021). Learn then test:
"Calibrating predictive algorithms to achieve risk control".
"""
if delta is None:
raise ValueError(
"Invalid delta: delta cannot be None while"
+ " controlling precision with LTT. "
)
p_values = compute_hoeffdding_bentkus_p_value(r_hat, n_obs, alpha_np, binary)
N = len(p_values)
valid_index = []
for i in range(len(alpha_np)):
l_index = np.where(p_values[:, i] <= delta/N)[0].tolist()
valid_index.append(l_index)
return valid_index, p_values # TODO : p_values is not used, we could remove it
# Or return corrected p_values
return valid_index


def find_lambda_control_star(
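For illustration, a minimal sketch of how the updated ltt_procedure signature could be called (values are made up, not taken from the MAPIE test suite):

import numpy as np
from mapie.control_risk.ltt import ltt_procedure

r_hat = np.array([0.30, 0.20, 0.12, 0.08])  # empirical risk for each candidate lambda
alpha_np = np.array([0.1, 0.2])             # target risk levels to control
n_obs = 500                                 # scalar, or shape (n_lambdas,) for a conditional loss
valid_index = ltt_procedure(r_hat, alpha_np, delta=0.1, n_obs=n_obs)
# valid_index[i] lists the lambda indices controlled at level alpha_np[i]
# with family-wise error rate at most delta (Bonferroni correction over the p-values).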
53 changes: 31 additions & 22 deletions mapie/control_risk/p_values.py
@@ -8,11 +8,11 @@


def compute_hoeffdding_bentkus_p_value(
r_hat: NDArray[np.float32],
n_obs: int,
alpha: Union[float, NDArray[np.float32]],
r_hat: NDArray,
n_obs: Union[int, NDArray],
alpha: Union[float, NDArray],
binary: bool = False,
) -> NDArray[np.float32]:
) -> NDArray:
"""
The method computes the p_values according to
the Hoeffding_Bentkus inequality for each
@@ -30,16 +30,23 @@ def compute_hoeffdding_bentkus_p_value(
Here lambdas are thresholds that impact decision
making and therefore empirical risk.

n_obs: int.
Correspond to the number of observations in
dataset.
n_obs: Union[int, NDArray]
Correspond to the number of observations used to compute the risk.
In the case of a conditional loss, n_obs must be the
number of effective observations used to compute the empirical risk
for each lambda, hence of shape (n_lambdas, ).

alpha: Union[float, Iterable[float]].
Contains the different alphas control level.
The empirical risk must be less than alpha.
If it is a iterable, it is a NDArray of shape
(n_alpha, ).

binary: bool, default=False
Must be True if the loss associated to the risk is binary.
If True, we use a tighter version of the Bentkus p-value, valid when the
loss associated to the risk is binary. See section 3.2 of [1].

Returns
-------
hb_p_values: NDArray of shape (n_lambda, n_alpha).
@@ -62,9 +69,17 @@
len(r_hat),
axis=0
)
if isinstance(n_obs, int):
n_obs = np.full_like(r_hat, n_obs, dtype=float)
n_obs_repeat = np.repeat(
np.expand_dims(n_obs, axis=1),
len(alpha_np),
axis=1
)

hoeffding_p_value = np.exp(
-n_obs * _h1(
np.where( # TODO : shouldn't we use np.minimum ?
-n_obs_repeat * _h1(
np.where(
r_hat_repeat > alpha_repeat,
alpha_repeat,
r_hat_repeat
@@ -74,9 +89,9 @@
)
factor = 1 if binary else np.e
bentkus_p_value = factor * binom.cdf(
np.ceil(n_obs * r_hat_repeat), n_obs, alpha_repeat
np.ceil(n_obs_repeat * r_hat_repeat), n_obs_repeat, alpha_repeat
)
hb_p_value = np.where( # TODO : shouldn't we use np.minimum ?
hb_p_value = np.where(
bentkus_p_value > hoeffding_p_value,
hoeffding_p_value,
bentkus_p_value
@@ -85,14 +100,11 @@


def _h1(
r_hats: NDArray[np.float32], alphas: NDArray[np.float32]
) -> NDArray[np.float32]:
r_hats: NDArray, alphas: NDArray
) -> NDArray:
"""
This function allow us to compute
the tighter version of hoeffding inequality.
This function is then used in the
hoeffding_bentkus_p_value function for the
computation of p-values.
This function allow us to compute the tighter version of hoeffding inequality.
When r_hat = 0, the log is undefined, but the limit is 0, so we set the result to 0.

Parameters
----------
@@ -113,12 +125,9 @@

Returns
-------
NDArray of shape a(n_lambdas, n_alpha).
NDArray of shape (n_lambdas, n_alpha).
"""
elt1 = np.zeros_like(r_hats, dtype=float)

# Compute only where r_hats != 0 to avoid log(0)
# TODO: check Angelopoulos implementation
mask = r_hats != 0
elt1[mask] = r_hats[mask] * np.log(r_hats[mask] / alphas[mask])
elt2 = (1 - r_hats) * np.log((1 - r_hats) / (1 - alphas))
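A short sketch of the n_obs-as-array behaviour exercised by the added test (illustrative numbers; binary=True assumes the loss associated to the risk is binary, as documented above):

import numpy as np
from mapie.control_risk.p_values import compute_hoeffdding_bentkus_p_value

r_hat = np.array([0.25, 0.15, 0.05])  # empirical risk per lambda
n_obs = np.array([400, 380, 120])     # effective sample size per lambda (conditional loss)
alpha = np.array([0.1, 0.2])          # risk levels to test against
p_values = compute_hoeffdding_bentkus_p_value(r_hat, n_obs, alpha, binary=True)
# p_values has shape (n_lambdas, n_alpha): one Hoeffding-Bentkus p-value per
# (lambda, alpha) pair, using the tighter Bentkus term (factor 1 instead of e)
# because binary=True.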
73 changes: 70 additions & 3 deletions mapie/risk_control.py
@@ -2,7 +2,7 @@

import warnings
from itertools import chain
from typing import Iterable, Optional, Sequence, Tuple, Union, cast
from typing import Iterable, Optional, Sequence, Tuple, Union, cast, Callable

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
@@ -681,8 +681,8 @@ def predict(
if self.metric_control == 'precision':
self.n_obs = len(self.risks)
self.r_hat = self.risks.mean(axis=0)
self.valid_index, self.p_values = ltt_procedure(
self.r_hat, alpha_np, delta, self.n_obs
self.valid_index = ltt_procedure(
self.r_hat, alpha_np, cast(float, delta), self.n_obs
)
self._check_valid_index(alpha_np)
self.lambdas_star, self.r_star = find_lambda_control_star(
@@ -706,3 +706,70 @@ def predict(
self.lambdas_star[np.newaxis, np.newaxis, :]
)
return y_pred, y_pred_proba_array


class BinaryClassificationRisk:
# Any risk that can be defined in the following way will work using the binary
# Hoeffding-Bentkus p-values used in MAPIE
# Take the example of precision in the docstring to explain how the class works.
def __init__(
self,
risk_occurrence: Callable[[int, int], int],
risk_condition: Callable[[int, int], bool],
higher_is_better: bool,
):
self.risk_occurrence = risk_occurrence
self.risk_condition = risk_condition
self.higher_is_better = higher_is_better

def get_value_and_effective_sample_size(
self,
y_true: NDArray, # shape (n_samples,), values in {0, 1}
y_pred: NDArray, # shape (n_samples,), values in {0, 1}
) -> Tuple[float, int]:
# float between 0 and 1, int between 0 and len(y_true)
# returns (1, -1) when the risk is not defined (condition never met)
# In this case, the corresponding lambda shouldn't be considered valid.
# In the current LTT implementation, providing n_obs=-1 will result
in an infinite p_value, effectively invalidating the lambda
risk_occurrences = np.array([
self.risk_occurrence(y_true_i, y_pred_i)
for y_true_i, y_pred_i in zip(y_true, y_pred)
])
risk_conditions = np.array([
self.risk_condition(y_true_i, y_pred_i)
for y_true_i, y_pred_i in zip(y_true, y_pred)
])
effective_sample_size = len(y_true) - np.sum(~risk_conditions)
# Casting needed for MyPy with Python 3.9
effective_sample_size_int = cast(int, effective_sample_size)
if effective_sample_size_int != 0:
risk_sum: int = np.sum(risk_occurrences[risk_conditions])
risk_value = risk_sum / effective_sample_size_int
return risk_value, effective_sample_size_int
return 1, -1


precision = BinaryClassificationRisk(
risk_occurrence=lambda y_true, y_pred: int(y_pred == y_true),
risk_condition=lambda y_true, y_pred: y_pred == 1,
higher_is_better=True,
)

accuracy = BinaryClassificationRisk(
risk_occurrence=lambda y_true, y_pred: int(y_pred == y_true),
risk_condition=lambda y_true, y_pred: True,
higher_is_better=True,
)

recall = BinaryClassificationRisk(
risk_occurrence=lambda y_true, y_pred: int(y_pred == y_true),
risk_condition=lambda y_true, y_pred: y_true == 1,
higher_is_better=True,
)

_automatic_best_predict_param_choice = {
precision: recall,
recall: precision,
accuracy: accuracy,
}
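To make the new class concrete, a hypothetical direct use of one of these instances (the BinaryClassificationController mentioned in the commit history presumably calls this method internally; the data below is made up):

import numpy as np
from mapie.risk_control import precision  # BinaryClassificationRisk instance defined above

y_true = np.array([1, 0, 1, 1, 0])
y_pred = np.array([1, 1, 0, 1, 0])
value, n_eff = precision.get_value_and_effective_sample_size(y_true, y_pred)
# Only the three samples with y_pred == 1 satisfy risk_condition; two of them
# are correct predictions, so value == 2/3 and n_eff == 3.
# With no positive prediction at all, the method returns the sentinel (1, -1),
# which makes the LTT p-value infinite and the corresponding lambda invalid.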