@@ -714,11 +714,12 @@ class BinaryClassificationRisk:
714
714
# Take the example of precision in the docstring to explain how the class works.
715
715
def __init__ (
716
716
self ,
717
- occurrence : Callable [[int , int ], Optional [ int ] ],
718
- # (y_true, y_pred), output: int (0 or 1) or None if undefined
717
+ risk_occurrence : Callable [[int , int ], int ],
718
+ risk_condition : Callable [[ int , int ], bool ],
719
719
higher_is_better : bool ,
720
720
):
721
- self .occurrence = occurrence
721
+ self .risk_occurrence = risk_occurrence
722
+ self .risk_condition = risk_condition
722
723
self .higher_is_better = higher_is_better
723
724
724
725
def get_value_and_effective_sample_size (
@@ -728,31 +729,40 @@ def get_value_and_effective_sample_size(
728
729
) -> Optional [Tuple [float , int ]]:
729
730
# float between 0 and 1, int between 0 and len(y_true)
730
731
risk_occurrences = [
731
- self .occurrence (y_true_i , y_pred_i )
732
+ self .risk_occurrence (y_true_i , y_pred_i )
732
733
for y_true_i , y_pred_i in zip (y_true , y_pred )
733
734
]
734
- effective_sample_size = len (y_true ) - risk_occurrences .count (None )
735
+ risk_conditions = [
736
+ self .risk_condition (y_true_i , y_pred_i )
737
+ for y_true_i , y_pred_i in zip (y_true , y_pred )
738
+ ]
739
+ effective_sample_size = len (y_true ) - risk_conditions .count (False )
735
740
if effective_sample_size != 0 :
736
- risk_value = sum (
737
- occurrence for occurrence in risk_occurrences if occurrence is not None
738
- ) / effective_sample_size
741
+ risk_sum = sum (
742
+ risk_occurrence for risk_occurrence , risk_condition
743
+ in zip (risk_occurrences , risk_conditions )
744
+ if risk_condition )
745
+ risk_value = risk_sum / effective_sample_size
739
746
if self .higher_is_better :
740
747
risk_value = 1 - risk_value
741
748
return risk_value , effective_sample_size
742
749
return None
743
750
744
751
745
752
precision = BinaryClassificationRisk (
746
- occurrence = lambda y_true , y_pred : None if y_pred == 0 else int (y_pred == y_true ),
753
+ risk_occurrence = lambda y_true , y_pred : int (y_pred == y_true ),
754
+ risk_condition = lambda y_true , y_pred : y_pred == 1 ,
747
755
higher_is_better = True ,
748
756
)
749
757
750
758
accuracy = BinaryClassificationRisk (
751
- occurrence = lambda y_true , y_pred : int (y_pred == y_true ),
759
+ risk_occurrence = lambda y_true , y_pred : int (y_pred == y_true ),
760
+ risk_condition = lambda y_true , y_pred : True ,
752
761
higher_is_better = True ,
753
762
)
754
763
755
764
recall = BinaryClassificationRisk (
756
- occurrence = lambda y_true , y_pred : None if y_true == 0 else int (y_pred == y_true ),
765
+ risk_occurrence = lambda y_true , y_pred : int (y_pred == y_true ),
766
+ risk_condition = lambda y_true , y_pred : y_true == 1 ,
757
767
higher_is_better = True ,
758
768
)
0 commit comments