Skip to content

Commit 95d1ce6

Browse files
ENH: implement BinaryClassificationRisk and related instances
1 parent 37c79dc commit 95d1ce6

File tree

1 file changed

+51
-1
lines changed

1 file changed

+51
-1
lines changed

mapie/risk_control.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import warnings
44
from itertools import chain
5-
from typing import Iterable, Optional, Sequence, Tuple, Union, cast
5+
from typing import Iterable, Optional, Sequence, Tuple, Union, cast, Callable
66

77
import numpy as np
88
from sklearn.base import BaseEstimator, ClassifierMixin
@@ -706,3 +706,53 @@ def predict(
706706
self.lambdas_star[np.newaxis, np.newaxis, :]
707707
)
708708
return y_pred, y_pred_proba_array
709+
710+
711+
class BinaryClassificationRisk:
    """
    Risk of a binary classifier, defined through a per-sample occurrence.

    Any risk that can be defined in the following way will work using the
    binary Hoeffding-Bentkus p-values used in MAPIE: for each sample, an
    ``occurrence`` function maps ``(y_true_i, y_pred_i)`` to 0, 1, or ``None``
    when the quantity is undefined for that sample. The risk is the mean of
    the defined occurrences, replaced by ``1 - mean`` when the underlying
    metric is one where higher is better.

    Example: precision. A sample only contributes when the prediction is
    positive, so ``occurrence`` returns ``None`` if ``y_pred == 0`` and
    ``int(y_pred == y_true)`` otherwise. The mean of the defined occurrences
    is the precision; with ``higher_is_better=True`` the risk is
    ``1 - precision``.

    Parameters
    ----------
    occurrence : Callable[[int, int], Optional[int]]
        Called as ``occurrence(y_true_i, y_pred_i)``. Must return 0 or 1,
        or ``None`` if undefined for that sample.
    higher_is_better : bool
        If ``True``, the mean occurrence is a score to maximise and the
        reported risk value is ``1 - mean``. If ``False``, the mean
        occurrence is reported as is.
    """

    def __init__(
        self,
        occurrence: Callable[[int, int], Optional[int]],
        higher_is_better: bool,
    ):
        self.occurrence = occurrence
        self.higher_is_better = higher_is_better

    def get_value_and_effective_sample_size(
        self,
        y_true: NDArray[int],
        y_pred: NDArray[int],
    ) -> Optional[Tuple[float, int]]:
        """
        Compute the risk value and the number of samples it is based on.

        Parameters
        ----------
        y_true : NDArray[int] of shape (n_samples,)
            True labels, values in {0, 1}.
        y_pred : NDArray[int] of shape (n_samples,)
            Predicted labels, values in {0, 1}.

        Returns
        -------
        Optional[Tuple[float, int]]
            ``(risk_value, effective_sample_size)`` where ``risk_value`` is
            between 0 and 1 and ``effective_sample_size`` (between 0 and
            ``len(y_true)``) is the number of samples for which
            ``occurrence`` returned a value other than ``None``.
            ``None`` is returned when no sample has a defined occurrence.

        Raises
        ------
        ValueError
            If ``y_true`` and ``y_pred`` do not have the same length.
        """
        n_true, n_pred = len(y_true), len(y_pred)
        if n_true != n_pred:
            # Pairing would otherwise stop at the shorter input and the
            # sample size would be computed from the wrong length.
            raise ValueError(
                "y_true and y_pred must have the same length, "
                f"got {n_true} and {n_pred}."
            )

        # Keep only the samples for which the occurrence is defined.
        defined_occurrences = [
            value
            for value in map(self.occurrence, y_true, y_pred)
            if value is not None
        ]
        effective_sample_size = len(defined_occurrences)
        if effective_sample_size == 0:
            return None

        risk_value = sum(defined_occurrences) / effective_sample_size
        if self.higher_is_better:
            # Turn a "higher is better" score into a risk.
            risk_value = 1 - risk_value
        return risk_value, effective_sample_size
743+
744+
745+
def _precision_occurrence(y_true: int, y_pred: int) -> Optional[int]:
    """Return whether a positive prediction is correct; None if not positive."""
    if y_pred == 0:
        return None
    return int(y_pred == y_true)


def _accuracy_occurrence(y_true: int, y_pred: int) -> Optional[int]:
    """Return 1 when the prediction equals the true label, else 0."""
    return int(y_pred == y_true)


def _recall_occurrence(y_true: int, y_pred: int) -> Optional[int]:
    """Return whether a true positive sample is predicted correctly; None if not positive."""
    if y_true == 0:
        return None
    return int(y_pred == y_true)


# Only samples predicted as positive contribute.
precision = BinaryClassificationRisk(
    occurrence=_precision_occurrence,
    higher_is_better=True,
)

# Every sample contributes.
accuracy = BinaryClassificationRisk(
    occurrence=_accuracy_occurrence,
    higher_is_better=True,
)

# Only samples whose true label is positive contribute.
recall = BinaryClassificationRisk(
    occurrence=_recall_occurrence,
    higher_is_better=True,
)

0 commit comments

Comments
 (0)