Skip to content

Commit 8e56bdb

Browse files
author
Vianney Taquet
committed
Take first comments into account
1 parent cfbe8f5 commit 8e56bdb

File tree

1 file changed

+22
-25
lines changed

1 file changed

+22
-25
lines changed

mapie/estimators.py

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ def _check_cv(
301301

302302
def _check_alpha(
303303
self,
304-
alpha: Optional[Union[float, Iterable[float]]] = None
304+
alpha: Union[float, Iterable[float]]
305305
) -> np.ndarray:
306306
"""
307307
Check alpha and prepare it as a np.ndarray
@@ -325,30 +325,27 @@ def _check_alpha(
325325
ValueError
326326
If alpha is not a float or an Iterable of floats between 0 and 1.
327327
"""
328-
if alpha is None:
329-
alpha_np = None
328+
if isinstance(alpha, float):
329+
alpha_np = np.array([alpha])
330+
elif isinstance(alpha, Iterable):
331+
alpha_np = np.array(alpha)
330332
else:
331-
if isinstance(alpha, float):
332-
alpha_np = np.array([alpha])
333-
elif isinstance(alpha, Iterable):
334-
alpha_np = np.array(alpha)
335-
else:
336-
raise ValueError(
337-
"Invalid alpha. Allowed values are float or Iterable."
338-
)
339-
if len(alpha_np.shape) != 1:
340-
raise ValueError(
341-
"Invalid alpha. "
342-
"Please provide a one-dimensional list of values."
343-
)
344-
if alpha_np.dtype.type not in [np.float64, np.float32]:
345-
raise ValueError(
346-
"Invalid alpha. Allowed values are Iterable of floats."
347-
)
348-
if np.any((alpha_np <= 0) | (alpha_np >= 1)):
349-
raise ValueError(
350-
"Invalid alpha. Allowed values are between 0 and 1."
351-
)
333+
raise ValueError(
334+
"Invalid alpha. Allowed values are float or Iterable."
335+
)
336+
if len(alpha_np.shape) != 1:
337+
raise ValueError(
338+
"Invalid alpha. "
339+
"Please provide a one-dimensional list of values."
340+
)
341+
if alpha_np.dtype.type not in [np.float64, np.float32]:
342+
raise ValueError(
343+
"Invalid alpha. Allowed values are Iterable of floats."
344+
)
345+
if np.any((alpha_np <= 0) | (alpha_np >= 1)):
346+
raise ValueError(
347+
"Invalid alpha. Allowed values are between 0 and 1."
348+
)
352349
return alpha_np
353350

354351
def _check_n_features_in(
@@ -582,7 +579,6 @@ def predict(
582579
]
583580
)
584581
X = check_array(X, force_all_finite=False, dtype=["float64", "object"])
585-
alpha_ = self._check_alpha(alpha)
586582
y_pred = self.single_estimator_.predict(X)
587583

588584
if alpha is None:
@@ -593,6 +589,7 @@ def predict(
593589
# (n_alpha,) : alpha
594590
# (n_samples_test, n_alpha) : y_pred_low, y_pred_up
595591
# (n_samples_test, n_samples_train) : y_pred_multi, low/up_bounds
592+
alpha_ = self._check_alpha(alpha)
596593
if self.method in ["naive", "base"] or self.cv == "prefit":
597594
quantile = np.quantile(
598595
self.residuals_, 1 - alpha_, interpolation="higher"

0 commit comments

Comments
 (0)