2
2
3
3
import warnings
4
4
from itertools import chain
5
- from typing import Iterable , Optional , Sequence , Tuple , Union , cast
5
+ from typing import Iterable , Optional , Sequence , Tuple , Union , cast , Callable
6
6
7
7
import numpy as np
8
8
from sklearn .base import BaseEstimator , ClassifierMixin
@@ -706,3 +706,53 @@ def predict(
706
706
self .lambdas_star [np .newaxis , np .newaxis , :]
707
707
)
708
708
return y_pred , y_pred_proba_array
709
+
710
+
711
+ class BinaryClassificationRisk :
712
+ # Any risk that can be defined in the following way will work using the binary
713
+ # Hoeffding-Bentkus p-values used in MAPIE
714
+ # Take the example of precision in the docstring to explain how the class works.
715
+ def __init__ (
716
+ self ,
717
+ occurrence : Callable [[int , int ], Optional [int ]],
718
+ # (y_true, y_pred), output: int (0 or 1) or None if undefined
719
+ higher_is_better : bool ,
720
+ ):
721
+ self .occurrence = occurrence
722
+ self .higher_is_better = higher_is_better
723
+
724
+ def get_value_and_effective_sample_size (
725
+ self ,
726
+ y_true : NDArray [int ], # shape (n_samples,), values in {0, 1}
727
+ y_pred : NDArray [int ], # shape (n_samples,), values in {0, 1}
728
+ ) -> Optional [Tuple [float , int ]]:
729
+ # float between 0 and 1, int between 0 and len(y_true)
730
+ risk_occurrences = [
731
+ self .occurrence (y_true_i , y_pred_i )
732
+ for y_true_i , y_pred_i in zip (y_true , y_pred )
733
+ ]
734
+ effective_sample_size = len (y_true ) - risk_occurrences .count (None )
735
+ if effective_sample_size != 0 :
736
+ risk_value = sum (
737
+ occurrence for occurrence in risk_occurrences if occurrence is not None
738
+ ) / effective_sample_size
739
+ if self .higher_is_better :
740
+ risk_value = 1 - risk_value
741
+ return risk_value , effective_sample_size
742
+ return None
743
+
744
+
745
+ precision = BinaryClassificationRisk (
746
+ occurrence = lambda y_true , y_pred : None if y_pred == 0 else int (y_pred == y_true ),
747
+ higher_is_better = True ,
748
+ )
749
+
750
+ accuracy = BinaryClassificationRisk (
751
+ occurrence = lambda y_true , y_pred : int (y_pred == y_true ),
752
+ higher_is_better = True ,
753
+ )
754
+
755
+ recall = BinaryClassificationRisk (
756
+ occurrence = lambda y_true , y_pred : None if y_true == 0 else int (y_pred == y_true ),
757
+ higher_is_better = True ,
758
+ )
0 commit comments