|
11 | 11 | from sklearn.multioutput import MultiOutputClassifier
|
12 | 12 | from sklearn.pipeline import Pipeline
|
13 | 13 | from sklearn.utils import check_random_state
|
14 |
| -from sklearn.utils.validation import ( |
15 |
| - _check_y, _num_samples, check_is_fitted, |
16 |
| - indexable, |
17 |
| -) |
| 14 | +from sklearn.utils.validation import (_check_y, _num_samples, check_is_fitted, |
| 15 | + indexable) |
18 | 16 |
|
19 | 17 | from numpy.typing import ArrayLike, NDArray
|
20 | 18 | from .control_risk.crc_rcps import find_lambda_star, get_r_hat_plus
|
@@ -220,9 +218,8 @@ def _check_method(self) -> None:
|
220 | 218 | "Invalid method for metric: "
|
221 | 219 | + "You are controlling " + self.metric_control
|
222 | 220 | + " and you are using invalid method: " + self.method
|
223 |
| - + ". Use instead: " + "".join( |
224 |
| - self.valid_methods_by_metric_[ |
225 |
| - self.metric_control] |
| 221 | + + ". Use instead: " + "".join(self.valid_methods_by_metric_[ |
| 222 | + self.metric_control] |
226 | 223 | )
|
227 | 224 | )
|
228 | 225 |
|
@@ -368,10 +365,10 @@ def _check_estimator(
|
368 | 365 | LogisticRegression()
|
369 | 366 | )
|
370 | 367 | X_train, X_conf, y_train, y_conf = train_test_split(
|
371 |
| - X, |
372 |
| - y, |
373 |
| - test_size=self.conformalize_size, |
374 |
| - random_state=self.random_state, |
| 368 | + X, |
| 369 | + y, |
| 370 | + test_size=self.conformalize_size, |
| 371 | + random_state=self.random_state, |
375 | 372 | )
|
376 | 373 | estimator.fit(X_train, y_train)
|
377 | 374 | warnings.warn(
|
@@ -689,7 +686,7 @@ def predict(
|
689 | 686 | )
|
690 | 687 | self._check_valid_index(alpha_np)
|
691 | 688 | self.lambdas_star, self.r_star = find_lambda_control_star(
|
692 |
| - self.r_hat, self.valid_index, self.lambdas |
| 689 | + self.r_hat, self.valid_index, self.lambdas |
693 | 690 | )
|
694 | 691 | y_pred_proba_array = (
|
695 | 692 | y_pred_proba_array >
|
|
0 commit comments