Skip to content

Commit 065921c

Browse files
ENH: simplify BinaryClassificationRisk API
1 parent 95d1ce6 commit 065921c

File tree

1 file changed

+21
-11
lines changed

1 file changed

+21
-11
lines changed

mapie/risk_control.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -714,11 +714,12 @@ class BinaryClassificationRisk:
714714
# Take the example of precision in the docstring to explain how the class works.
715715
def __init__(
716716
self,
717-
occurrence: Callable[[int, int], Optional[int]],
718-
# (y_true, y_pred), output: int (0 or 1) or None if undefined
717+
risk_occurrence: Callable[[int, int], int],
718+
risk_condition: Callable[[int, int], bool],
719719
higher_is_better: bool,
720720
):
721-
self.occurrence = occurrence
721+
self.risk_occurrence = risk_occurrence
722+
self.risk_condition = risk_condition
722723
self.higher_is_better = higher_is_better
723724

724725
def get_value_and_effective_sample_size(
@@ -728,31 +729,40 @@ def get_value_and_effective_sample_size(
728729
) -> Optional[Tuple[float, int]]:
729730
# float between 0 and 1, int between 0 and len(y_true)
730731
risk_occurrences = [
731-
self.occurrence(y_true_i, y_pred_i)
732+
self.risk_occurrence(y_true_i, y_pred_i)
732733
for y_true_i, y_pred_i in zip(y_true, y_pred)
733734
]
734-
effective_sample_size = len(y_true) - risk_occurrences.count(None)
735+
risk_conditions = [
736+
self.risk_condition(y_true_i, y_pred_i)
737+
for y_true_i, y_pred_i in zip(y_true, y_pred)
738+
]
739+
effective_sample_size = len(y_true) - risk_conditions.count(False)
735740
if effective_sample_size != 0:
736-
risk_value = sum(
737-
occurrence for occurrence in risk_occurrences if occurrence is not None
738-
) / effective_sample_size
741+
risk_sum = sum(
742+
risk_occurrence for risk_occurrence, risk_condition
743+
in zip(risk_occurrences, risk_conditions)
744+
if risk_condition)
745+
risk_value = risk_sum / effective_sample_size
739746
if self.higher_is_better:
740747
risk_value = 1 - risk_value
741748
return risk_value, effective_sample_size
742749
return None
743750

744751

745752
precision = BinaryClassificationRisk(
746-
occurrence=lambda y_true, y_pred: None if y_pred == 0 else int(y_pred == y_true),
753+
risk_occurrence=lambda y_true, y_pred: int(y_pred == y_true),
754+
risk_condition=lambda y_true, y_pred: y_pred == 1,
747755
higher_is_better=True,
748756
)
749757

750758
accuracy = BinaryClassificationRisk(
751-
occurrence=lambda y_true, y_pred: int(y_pred == y_true),
759+
risk_occurrence=lambda y_true, y_pred: int(y_pred == y_true),
760+
risk_condition=lambda y_true, y_pred: True,
752761
higher_is_better=True,
753762
)
754763

755764
recall = BinaryClassificationRisk(
756-
occurrence=lambda y_true, y_pred: None if y_true == 0 else int(y_pred == y_true),
765+
risk_occurrence=lambda y_true, y_pred: int(y_pred == y_true),
766+
risk_condition=lambda y_true, y_pred: y_true == 1,
757767
higher_is_better=True,
758768
)

0 commit comments

Comments
 (0)