
Commit 7c90adb

gmartinonQM and FaustinPulveric authored and committed
FIX RCPS URL
1 parent 44144de commit 7c90adb

1 file changed

mapie/risk_control.py

Lines changed: 54 additions & 66 deletions
@@ -11,8 +11,12 @@
 from sklearn.multioutput import MultiOutputClassifier
 from sklearn.pipeline import Pipeline
 from sklearn.utils import check_random_state
-from sklearn.utils.validation import (_check_y, _num_samples, check_is_fitted,
-                                      indexable)
+from sklearn.utils.validation import (
+    _check_y,
+    _num_samples,
+    check_is_fitted,
+    indexable,
+)
 
 from numpy.typing import ArrayLike, NDArray
 from .control_risk.crc_rcps import find_lambda_star, get_r_hat_plus
@@ -127,7 +131,7 @@ class PrecisionRecallController(BaseEstimator, ClassifierMixin):
     [1] Lihua Lei Jitendra Malik Stephen Bates, Anastasios Angelopoulos
     and Michael I. Jordan. Distribution-free, risk-controlling prediction
     sets. CoRR, abs/2101.02703, 2021.
-    URL https://arxiv.org/abs/2101.02703.39
+    URL https://arxiv.org/abs/2101.02703
 
     [2] Angelopoulos, Anastasios N., Stephen, Bates, Adam, Fisch, Lihua,
     Lei, and Tal, Schuster. "Conformal Risk Control." (2022).
@@ -153,30 +157,25 @@ class PrecisionRecallController(BaseEstimator, ClassifierMixin):
     [False True False]
     [False True False]]
    """
-    valid_methods_by_metric_ = {
-        "precision": ["ltt"],
-        "recall": ["rcps", "crc"]
-    }
+
+    valid_methods_by_metric_ = {"precision": ["ltt"], "recall": ["rcps", "crc"]}
     valid_methods = list(chain(*valid_methods_by_metric_.values()))
     valid_metric_ = list(valid_methods_by_metric_.keys())
     valid_bounds_ = ["hoeffding", "bernstein", "wsr", None]
     lambdas = np.arange(0, 1, 0.01)
     n_lambdas = len(lambdas)
-    fit_attributes = [
-        "single_estimator_",
-        "risks"
-    ]
+    fit_attributes = ["single_estimator_", "risks"]
     sigma_init = 0.25  # Value given in the paper [1]
-    cal_size = .3
+    cal_size = 0.3
 
     def __init__(
         self,
         estimator: Optional[ClassifierMixin] = None,
-        metric_control: Optional[str] = 'recall',
+        metric_control: Optional[str] = "recall",
         method: Optional[str] = None,
         n_jobs: Optional[int] = None,
         random_state: Optional[Union[int, np.random.RandomState]] = None,
-        verbose: int = 0
+        verbose: int = 0,
     ) -> None:
         self.estimator = estimator
         self.metric_control = metric_control
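The hunk above only reformats the class attributes and constructor; the public API is unchanged. For readers skimming the diff, here is a minimal usage sketch, assuming the estimator-style interface shown in this file (fit(X, y) and predict(X, alpha=...)); the synthetic dataset, sizes, and seed are illustrative and not taken from the commit.

from sklearn.datasets import make_multilabel_classification
from mapie.risk_control import PrecisionRecallController

# Illustrative multilabel data; every sample keeps at least one label.
X, y = make_multilabel_classification(
    n_samples=200, n_classes=3, allow_unlabeled=False, random_state=0
)

# "recall" pairs with "rcps"/"crc" and "precision" with "ltt",
# per valid_methods_by_metric_ above; crc needs no delta at predict time.
controller = PrecisionRecallController(
    metric_control="recall", method="crc", random_state=0
)

# With estimator=None, a MultiOutputClassifier(LogisticRegression()) is
# fitted internally on a train/conformalize split (conformalize_size=0.3 by default).
controller.fit(X, y)

y_pred, y_pred_sets = controller.predict(X[:2], alpha=0.1)
print(y_pred_sets[:, :, 0])  # boolean prediction sets, as in the docstring example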
@@ -211,16 +210,18 @@ def _check_method(self) -> None:
         self.method = cast(str, self.method)
         self.metric_control = cast(str, self.metric_control)
 
-        if self.method not in self.valid_methods_by_metric_[
-            self.metric_control
-        ]:
+        if (
+            self.method
+            not in self.valid_methods_by_metric_[self.metric_control]
+        ):
             raise ValueError(
                 "Invalid method for metric: "
-                + "You are controlling " + self.metric_control
-                + " and you are using invalid method: " + self.method
-                + ". Use instead: " + "".join(self.valid_methods_by_metric_[
-                    self.metric_control]
-                )
+                + "You are controlling "
+                + self.metric_control
+                + " and you are using invalid method: "
+                + self.method
+                + ". Use instead: "
+                + "".join(self.valid_methods_by_metric_[self.metric_control])
             )
 
     def _check_all_labelled(self, y: NDArray) -> None:
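The reshuffled condition above still enforces the valid_methods_by_metric_ mapping. As a hypothetical standalone version of that rule (check_method below is a local helper written for this example, not a library function):

# Hypothetical standalone version of the rule enforced by _check_method;
# the mapping is copied from valid_methods_by_metric_ above.
valid_methods_by_metric = {"precision": ["ltt"], "recall": ["rcps", "crc"]}


def check_method(metric_control: str, method: str) -> None:
    if method not in valid_methods_by_metric[metric_control]:
        raise ValueError(
            f"You are controlling {metric_control} with an invalid method: {method}. "
            f"Use instead: {valid_methods_by_metric[metric_control]}"
        )


check_method("recall", "crc")        # passes silently
# check_method("precision", "crc")   # would raise ValueError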
@@ -241,9 +242,7 @@ def _check_all_labelled(self, y: NDArray) -> None:
        """
        if not (y.sum(axis=1) > 0).all():
            raise ValueError(
-                "Invalid y. "
-                "All observations should contain at "
-                "least one label."
+                "Invalid y. All observations should contain at least one label."
            )
 
    def _check_delta(self, delta: Optional[float]):
@@ -268,8 +267,7 @@ def _check_delta(self, delta: Optional[float]):
        """
        if (not isinstance(delta, float)) and (delta is not None):
            raise ValueError(
-                "Invalid delta. "
-                f"delta must be a float, not a {type(delta)}"
+                f"Invalid delta. delta must be a float, not a {type(delta)}"
            )
        if (self.method == "rcps") or (self.method == "ltt"):
            if delta is None:
@@ -278,11 +276,8 @@ def _check_delta(self, delta: Optional[float]):
                    "delta cannot be ``None`` when controlling "
                    "Recall with RCPS or Precision with LTT"
                )
-            elif ((delta <= 0) or (delta >= 1)):
-                raise ValueError(
-                    "Invalid delta. "
-                    "delta must be in ]0, 1["
-                )
+            elif (delta <= 0) or (delta >= 1):
+                raise ValueError("Invalid delta. delta must be in ]0, 1[")
        if (self.method == "crc") and (delta is not None):
            warnings.warn(
                "WARNING: you are using crc method, hence "
@@ -302,7 +297,8 @@ def _check_valid_index(self, alpha: NDArray):
            if self.valid_index[i] == []:
                warnings.warn(
                    "Warning: LTT method has returned an empty sequence"
-                    + " for alpha=" + str(alpha[i])
+                    + " for alpha="
+                    + str(alpha[i])
                )
 
    def _check_estimator(
@@ -361,14 +357,12 @@ def _check_estimator(
                "use partial_fit."
            )
        if (estimator is None) and (_refit):
-            estimator = MultiOutputClassifier(
-                LogisticRegression()
-            )
+            estimator = MultiOutputClassifier(LogisticRegression())
            X_train, X_conf, y_train, y_conf = train_test_split(
-                X,
-                y,
-                test_size=self.conformalize_size,
-                random_state=self.random_state,
+                X,
+                y,
+                test_size=self.conformalize_size,
+                random_state=self.random_state,
            )
            estimator.fit(X_train, y_train)
            warnings.warn(
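When no estimator is given and refitting is requested, _check_estimator falls back to MultiOutputClassifier(LogisticRegression()) and holds out a conformalization set with train_test_split, as shown above. A standalone sketch of that fallback path on synthetic data (sizes, seed, and max_iter are illustrative choices, not values from the library):

from sklearn.datasets import make_multilabel_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.multioutput import MultiOutputClassifier

X, y = make_multilabel_classification(
    n_samples=200, n_classes=3, allow_unlabeled=False, random_state=0
)

# Default fallback estimator: one LogisticRegression per label.
estimator = MultiOutputClassifier(LogisticRegression(max_iter=1000))

# Hold out a conformalization set, mirroring test_size=self.conformalize_size (0.3 by default).
X_train, X_conf, y_train, y_conf = train_test_split(X, y, test_size=0.3, random_state=0)
estimator.fit(X_train, y_train)

# predict_proba returns a list with one (n_samples, 2) array per label.
proba = estimator.predict_proba(X_conf)
print(len(proba), proba[0].shape)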
@@ -460,8 +454,7 @@ def _check_metric_control(self):
                self.method = "ltt"
 
    def _transform_pred_proba(
-        self,
-        y_pred_proba: Union[Sequence[NDArray], NDArray]
+        self, y_pred_proba: Union[Sequence[NDArray], NDArray]
    ) -> NDArray:
        """If the output of the predict_proba is a list of arrays (output of
        the ``predict_proba`` of ``MultiOutputClassifier``) we transform it
@@ -483,7 +476,7 @@ def _transform_pred_proba(
        else:
            y_pred_proba_stacked = np.stack(
                y_pred_proba,  # type: ignore
-                axis=0
+                axis=0,
            )[:, :, 1]
            y_pred_proba_array = np.moveaxis(y_pred_proba_stacked, 0, -1)
 
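The reformatting above does not change what _transform_pred_proba computes: the list of per-label (n_samples, 2) arrays from MultiOutputClassifier.predict_proba is stacked, the positive-class column is kept, and the axes are moved so samples come first. A numpy-only sketch of that transformation; the trailing axis added at the end is an assumption made so the array can later be repeated along an alpha axis as in predict:

import numpy as np

n_samples, n_labels = 4, 3
rng = np.random.default_rng(0)

# Simulated MultiOutputClassifier.predict_proba output: one (n_samples, 2) array per label.
y_pred_proba = [rng.dirichlet([1, 1], size=n_samples) for _ in range(n_labels)]

# Same stacking as in the diff: keep P(label=1) and put samples on the first axis.
stacked = np.stack(y_pred_proba, axis=0)[:, :, 1]        # (n_labels, n_samples)
y_pred_proba_array = np.moveaxis(stacked, 0, -1)         # (n_samples, n_labels)

# Assumed extra axis so the array can be broadcast against several alpha values later.
y_pred_proba_array = y_pred_proba_array[:, :, np.newaxis]  # (n_samples, n_labels, 1)
print(y_pred_proba_array.shape)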
@@ -526,10 +519,7 @@ def partial_fit(
 
        X, y = indexable(X, y)
        _check_y(y, multi_output=True)
-        estimator, X, y = self._check_estimator(
-            X, y, self.estimator,
-            _refit
-        )
+        estimator, X, y = self._check_estimator(X, y, self.estimator, _refit)
 
        y = cast(NDArray, y)
        X = cast(NDArray, X)
@@ -561,15 +551,11 @@ def partial_fit(
            y_pred_proba_array = self._transform_pred_proba(y_pred_proba)
        if self.metric_control == "recall":
            partial_risk = compute_risk_recall(
-                self.lambdas,
-                y_pred_proba_array,
-                y
+                self.lambdas, y_pred_proba_array, y
            )
        else:  # self.metric_control == "precision"
            partial_risk = compute_risk_precision(
-                self.lambdas,
-                y_pred_proba_array,
-                y
+                self.lambdas, y_pred_proba_array, y
            )
        self.risks = np.concatenate([self.risks, partial_risk], axis=0)
 
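compute_risk_recall and compute_risk_precision come from MAPIE's control_risk module and are not shown in this diff, so their exact implementation is not reproduced here. As a rough illustration of the quantity accumulated in self.risks, here is a hedged numpy sketch of a per-sample recall risk (one minus the recall of the thresholded prediction set) evaluated on the lambda grid, in the spirit of references [1] and [2]; all values are made up:

import numpy as np

rng = np.random.default_rng(0)
n_samples, n_labels = 5, 3
lambdas = np.arange(0, 1, 0.01)                        # same grid as the class attribute

y = rng.integers(0, 2, size=(n_samples, n_labels))
y[y.sum(axis=1) == 0, 0] = 1                           # every sample keeps at least one label
proba = rng.random((n_samples, n_labels))              # stand-in for predicted probabilities

# Prediction sets for every threshold: shape (n_samples, n_labels, n_lambdas).
pred_sets = proba[:, :, np.newaxis] > lambdas[np.newaxis, np.newaxis, :]

# Illustrative per-sample recall risk: fraction of true labels missed by the set.
recovered = (pred_sets & (y[:, :, np.newaxis] == 1)).sum(axis=1)
recall_risk = 1 - recovered / y.sum(axis=1)[:, np.newaxis]  # (n_samples, n_lambdas)
print(recall_risk.shape)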
@@ -579,7 +565,7 @@ def fit(
        self,
        X: ArrayLike,
        y: ArrayLike,
-        conformalize_size: Optional[float] = .3
+        conformalize_size: Optional[float] = 0.3,
    ) -> PrecisionRecallController:
        """
        Fit the base estimator or use the fitted base estimator.
@@ -611,7 +597,7 @@ def predict(
        X: ArrayLike,
        alpha: Optional[Union[float, Iterable[float]]] = None,
        delta: Optional[float] = None,
-        bound: Optional[Union[str, None]] = None
+        bound: Optional[Union[str, None]] = None,
    ) -> Union[NDArray, Tuple[NDArray, NDArray]]:
        """
        Prediction sets on new samples based on target confidence
@@ -674,35 +660,37 @@ def predict(
 
        y_pred_proba_array = self._transform_pred_proba(y_pred_proba)
        y_pred_proba_array = np.repeat(
-            y_pred_proba_array,
-            len(alpha_np),
-            axis=2
+            y_pred_proba_array, len(alpha_np), axis=2
        )
-        if self.metric_control == 'precision':
+        if self.metric_control == "precision":
            self.n_obs = len(self.risks)
            self.r_hat = self.risks.mean(axis=0)
            self.valid_index, self.p_values = ltt_procedure(
                self.r_hat, alpha_np, delta, self.n_obs
            )
            self._check_valid_index(alpha_np)
            self.lambdas_star, self.r_star = find_lambda_control_star(
-                self.r_hat, self.valid_index, self.lambdas
+                self.r_hat, self.valid_index, self.lambdas
            )
            y_pred_proba_array = (
-                y_pred_proba_array >
-                np.array(self.lambdas_star)[np.newaxis, np.newaxis, :]
+                y_pred_proba_array
+                > np.array(self.lambdas_star)[np.newaxis, np.newaxis, :]
            )
 
        else:
            self.r_hat, self.r_hat_plus = get_r_hat_plus(
-                self.risks, self.lambdas, self.method,
-                bound, delta, self.sigma_init
+                self.risks,
+                self.lambdas,
+                self.method,
+                bound,
+                delta,
+                self.sigma_init,
            )
            self.lambdas_star = find_lambda_star(
                self.lambdas, self.r_hat_plus, alpha_np
            )
            y_pred_proba_array = (
-                y_pred_proba_array >
-                self.lambdas_star[np.newaxis, np.newaxis, :]
+                y_pred_proba_array
+                > self.lambdas_star[np.newaxis, np.newaxis, :]
            )
        return y_pred, y_pred_proba_array
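In both branches above, the final step turns the calibrated probabilities into boolean prediction sets by comparing an (n_samples, n_labels, n_alpha) array against one optimal threshold per alpha, broadcast via [np.newaxis, np.newaxis, :]. A minimal numpy sketch of that broadcast; shapes and threshold values are illustrative:

import numpy as np

rng = np.random.default_rng(0)
n_samples, n_labels, n_alpha = 4, 3, 2

# Probabilities repeated along the alpha axis, as done with np.repeat(..., axis=2) above.
proba = rng.random((n_samples, n_labels, 1))
proba = np.repeat(proba, n_alpha, axis=2)              # (n_samples, n_labels, n_alpha)

# One lambda* per alpha (values made up for the example).
lambdas_star = np.array([0.3, 0.6])

# Broadcast comparison: boolean prediction sets, one slice per alpha.
pred_sets = proba > lambdas_star[np.newaxis, np.newaxis, :]
print(pred_sets.shape)                                  # (4, 3, 2)
print(pred_sets[:, :, 0])                               # sets at the looser threshold 0.3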
