Skip to content

Commit 12d1e02

Browse files
Revert changes other than URL fix
1 parent 7c90adb commit 12d1e02

File tree

1 file changed

+65
-53
lines changed

1 file changed

+65
-53
lines changed

mapie/risk_control.py

Lines changed: 65 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,8 @@
1111
from sklearn.multioutput import MultiOutputClassifier
1212
from sklearn.pipeline import Pipeline
1313
from sklearn.utils import check_random_state
14-
from sklearn.utils.validation import (
15-
_check_y,
16-
_num_samples,
17-
check_is_fitted,
18-
indexable,
19-
)
14+
from sklearn.utils.validation import (_check_y, _num_samples, check_is_fitted,
15+
indexable)
2016

2117
from numpy.typing import ArrayLike, NDArray
2218
from .control_risk.crc_rcps import find_lambda_star, get_r_hat_plus
@@ -157,25 +153,30 @@ class PrecisionRecallController(BaseEstimator, ClassifierMixin):
157153
[False True False]
158154
[False True False]]
159155
"""
160-
161-
valid_methods_by_metric_ = {"precision": ["ltt"], "recall": ["rcps", "crc"]}
156+
valid_methods_by_metric_ = {
157+
"precision": ["ltt"],
158+
"recall": ["rcps", "crc"]
159+
}
162160
valid_methods = list(chain(*valid_methods_by_metric_.values()))
163161
valid_metric_ = list(valid_methods_by_metric_.keys())
164162
valid_bounds_ = ["hoeffding", "bernstein", "wsr", None]
165163
lambdas = np.arange(0, 1, 0.01)
166164
n_lambdas = len(lambdas)
167-
fit_attributes = ["single_estimator_", "risks"]
165+
fit_attributes = [
166+
"single_estimator_",
167+
"risks"
168+
]
168169
sigma_init = 0.25 # Value given in the paper [1]
169-
cal_size = 0.3
170+
cal_size = .3
170171

171172
def __init__(
172173
self,
173174
estimator: Optional[ClassifierMixin] = None,
174-
metric_control: Optional[str] = "recall",
175+
metric_control: Optional[str] = 'recall',
175176
method: Optional[str] = None,
176177
n_jobs: Optional[int] = None,
177178
random_state: Optional[Union[int, np.random.RandomState]] = None,
178-
verbose: int = 0,
179+
verbose: int = 0
179180
) -> None:
180181
self.estimator = estimator
181182
self.metric_control = metric_control
@@ -210,18 +211,16 @@ def _check_method(self) -> None:
210211
self.method = cast(str, self.method)
211212
self.metric_control = cast(str, self.metric_control)
212213

213-
if (
214-
self.method
215-
not in self.valid_methods_by_metric_[self.metric_control]
216-
):
214+
if self.method not in self.valid_methods_by_metric_[
215+
self.metric_control
216+
]:
217217
raise ValueError(
218218
"Invalid method for metric: "
219-
+ "You are controlling "
220-
+ self.metric_control
221-
+ " and you are using invalid method: "
222-
+ self.method
223-
+ ". Use instead: "
224-
+ "".join(self.valid_methods_by_metric_[self.metric_control])
219+
+ "You are controlling " + self.metric_control
220+
+ " and you are using invalid method: " + self.method
221+
+ ". Use instead: " + "".join(self.valid_methods_by_metric_[
222+
self.metric_control]
223+
)
225224
)
226225

227226
def _check_all_labelled(self, y: NDArray) -> None:
@@ -242,7 +241,9 @@ def _check_all_labelled(self, y: NDArray) -> None:
242241
"""
243242
if not (y.sum(axis=1) > 0).all():
244243
raise ValueError(
245-
"Invalid y. All observations should contain at least one label."
244+
"Invalid y. "
245+
"All observations should contain at "
246+
"least one label."
246247
)
247248

248249
def _check_delta(self, delta: Optional[float]):
@@ -267,7 +268,8 @@ def _check_delta(self, delta: Optional[float]):
267268
"""
268269
if (not isinstance(delta, float)) and (delta is not None):
269270
raise ValueError(
270-
f"Invalid delta. delta must be a float, not a {type(delta)}"
271+
"Invalid delta. "
272+
"delta must be a float, not a {type(delta)}"
271273
)
272274
if (self.method == "rcps") or (self.method == "ltt"):
273275
if delta is None:
@@ -276,8 +278,11 @@ def _check_delta(self, delta: Optional[float]):
276278
"delta cannot be ``None`` when controlling "
277279
"Recall with RCPS or Precision with LTT"
278280
)
279-
elif (delta <= 0) or (delta >= 1):
280-
raise ValueError("Invalid delta. delta must be in ]0, 1[")
281+
elif ((delta <= 0) or (delta >= 1)):
282+
raise ValueError(
283+
"Invalid delta. "
284+
"delta must be in ]0, 1["
285+
)
281286
if (self.method == "crc") and (delta is not None):
282287
warnings.warn(
283288
"WARNING: you are using crc method, hence "
@@ -297,8 +302,7 @@ def _check_valid_index(self, alpha: NDArray):
297302
if self.valid_index[i] == []:
298303
warnings.warn(
299304
"Warning: LTT method has returned an empty sequence"
300-
+ " for alpha="
301-
+ str(alpha[i])
305+
+ " for alpha=" + str(alpha[i])
302306
)
303307

304308
def _check_estimator(
@@ -357,12 +361,14 @@ def _check_estimator(
357361
"use partial_fit."
358362
)
359363
if (estimator is None) and (_refit):
360-
estimator = MultiOutputClassifier(LogisticRegression())
364+
estimator = MultiOutputClassifier(
365+
LogisticRegression()
366+
)
361367
X_train, X_conf, y_train, y_conf = train_test_split(
362-
X,
363-
y,
364-
test_size=self.conformalize_size,
365-
random_state=self.random_state,
368+
X,
369+
y,
370+
test_size=self.conformalize_size,
371+
random_state=self.random_state,
366372
)
367373
estimator.fit(X_train, y_train)
368374
warnings.warn(
@@ -454,7 +460,8 @@ def _check_metric_control(self):
454460
self.method = "ltt"
455461

456462
def _transform_pred_proba(
457-
self, y_pred_proba: Union[Sequence[NDArray], NDArray]
463+
self,
464+
y_pred_proba: Union[Sequence[NDArray], NDArray]
458465
) -> NDArray:
459466
"""If the output of the predict_proba is a list of arrays (output of
460467
the ``predict_proba`` of ``MultiOutputClassifier``) we transform it
@@ -476,7 +483,7 @@ def _transform_pred_proba(
476483
else:
477484
y_pred_proba_stacked = np.stack(
478485
y_pred_proba, # type: ignore
479-
axis=0,
486+
axis=0
480487
)[:, :, 1]
481488
y_pred_proba_array = np.moveaxis(y_pred_proba_stacked, 0, -1)
482489

@@ -519,7 +526,10 @@ def partial_fit(
519526

520527
X, y = indexable(X, y)
521528
_check_y(y, multi_output=True)
522-
estimator, X, y = self._check_estimator(X, y, self.estimator, _refit)
529+
estimator, X, y = self._check_estimator(
530+
X, y, self.estimator,
531+
_refit
532+
)
523533

524534
y = cast(NDArray, y)
525535
X = cast(NDArray, X)
@@ -551,11 +561,15 @@ def partial_fit(
551561
y_pred_proba_array = self._transform_pred_proba(y_pred_proba)
552562
if self.metric_control == "recall":
553563
partial_risk = compute_risk_recall(
554-
self.lambdas, y_pred_proba_array, y
564+
self.lambdas,
565+
y_pred_proba_array,
566+
y
555567
)
556568
else: # self.metric_control == "precision"
557569
partial_risk = compute_risk_precision(
558-
self.lambdas, y_pred_proba_array, y
570+
self.lambdas,
571+
y_pred_proba_array,
572+
y
559573
)
560574
self.risks = np.concatenate([self.risks, partial_risk], axis=0)
561575

@@ -565,7 +579,7 @@ def fit(
565579
self,
566580
X: ArrayLike,
567581
y: ArrayLike,
568-
conformalize_size: Optional[float] = 0.3,
582+
conformalize_size: Optional[float] = .3,
569583
) -> PrecisionRecallController:
570584
"""
571585
Fit the base estimator or use the fitted base estimator.
@@ -597,7 +611,7 @@ def predict(
597611
X: ArrayLike,
598612
alpha: Optional[Union[float, Iterable[float]]] = None,
599613
delta: Optional[float] = None,
600-
bound: Optional[Union[str, None]] = None,
614+
bound: Optional[Union[str, None]] = None
601615
) -> Union[NDArray, Tuple[NDArray, NDArray]]:
602616
"""
603617
Prediction sets on new samples based on target confidence
@@ -660,37 +674,35 @@ def predict(
660674

661675
y_pred_proba_array = self._transform_pred_proba(y_pred_proba)
662676
y_pred_proba_array = np.repeat(
663-
y_pred_proba_array, len(alpha_np), axis=2
677+
y_pred_proba_array,
678+
len(alpha_np),
679+
axis=2
664680
)
665-
if self.metric_control == "precision":
681+
if self.metric_control == 'precision':
666682
self.n_obs = len(self.risks)
667683
self.r_hat = self.risks.mean(axis=0)
668684
self.valid_index, self.p_values = ltt_procedure(
669685
self.r_hat, alpha_np, delta, self.n_obs
670686
)
671687
self._check_valid_index(alpha_np)
672688
self.lambdas_star, self.r_star = find_lambda_control_star(
673-
self.r_hat, self.valid_index, self.lambdas
689+
self.r_hat, self.valid_index, self.lambdas
674690
)
675691
y_pred_proba_array = (
676-
y_pred_proba_array
677-
> np.array(self.lambdas_star)[np.newaxis, np.newaxis, :]
692+
y_pred_proba_array >
693+
np.array(self.lambdas_star)[np.newaxis, np.newaxis, :]
678694
)
679695

680696
else:
681697
self.r_hat, self.r_hat_plus = get_r_hat_plus(
682-
self.risks,
683-
self.lambdas,
684-
self.method,
685-
bound,
686-
delta,
687-
self.sigma_init,
698+
self.risks, self.lambdas, self.method,
699+
bound, delta, self.sigma_init
688700
)
689701
self.lambdas_star = find_lambda_star(
690702
self.lambdas, self.r_hat_plus, alpha_np
691703
)
692704
y_pred_proba_array = (
693-
y_pred_proba_array
694-
> self.lambdas_star[np.newaxis, np.newaxis, :]
705+
y_pred_proba_array >
706+
self.lambdas_star[np.newaxis, np.newaxis, :]
695707
)
696708
return y_pred, y_pred_proba_array

0 commit comments

Comments
 (0)