Skip to content

Commit 9bed02c

Browse files
author
Vianney Taquet
committed
Take second comments into account
1 parent 8e56bdb commit 9bed02c

File tree

1 file changed

+26
-23
lines changed

1 file changed

+26
-23
lines changed

mapie/estimators.py

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ def _check_cv(
301301

302302
def _check_alpha(
303303
self,
304-
alpha: Union[float, Iterable[float]]
304+
alpha: Optional[Union[float, Iterable[float]]] = None
305305
) -> np.ndarray:
306306
"""
307307
Check alpha and prepare it as a np.ndarray
@@ -325,28 +325,31 @@ def _check_alpha(
325325
ValueError
326326
If alpha is not a float or an Iterable of floats between 0 and 1.
327327
"""
328-
if isinstance(alpha, float):
329-
alpha_np = np.array([alpha])
330-
elif isinstance(alpha, Iterable):
331-
alpha_np = np.array(alpha)
328+
if alpha is None:
329+
return np.zeros(1)
332330
else:
333-
raise ValueError(
334-
"Invalid alpha. Allowed values are float or Iterable."
335-
)
336-
if len(alpha_np.shape) != 1:
337-
raise ValueError(
338-
"Invalid alpha. "
339-
"Please provide a one-dimensional list of values."
340-
)
341-
if alpha_np.dtype.type not in [np.float64, np.float32]:
342-
raise ValueError(
343-
"Invalid alpha. Allowed values are Iterable of floats."
344-
)
345-
if np.any((alpha_np <= 0) | (alpha_np >= 1)):
346-
raise ValueError(
347-
"Invalid alpha. Allowed values are between 0 and 1."
348-
)
349-
return alpha_np
331+
if isinstance(alpha, float):
332+
alpha_np = np.array([alpha])
333+
elif isinstance(alpha, Iterable):
334+
alpha_np = np.array(alpha)
335+
else:
336+
raise ValueError(
337+
"Invalid alpha. Allowed values are float or Iterable."
338+
)
339+
if len(alpha_np.shape) != 1:
340+
raise ValueError(
341+
"Invalid alpha. "
342+
"Please provide a one-dimensional list of values."
343+
)
344+
if alpha_np.dtype.type not in [np.float64, np.float32]:
345+
raise ValueError(
346+
"Invalid alpha. Allowed values are Iterable of floats."
347+
)
348+
if np.any((alpha_np <= 0) | (alpha_np >= 1)):
349+
raise ValueError(
350+
"Invalid alpha. Allowed values are between 0 and 1."
351+
)
352+
return alpha_np
350353

351354
def _check_n_features_in(
352355
self,
@@ -578,6 +581,7 @@ def predict(
578581
"residuals_"
579582
]
580583
)
584+
alpha_ = self._check_alpha(alpha)
581585
X = check_array(X, force_all_finite=False, dtype=["float64", "object"])
582586
y_pred = self.single_estimator_.predict(X)
583587

@@ -589,7 +593,6 @@ def predict(
589593
# (n_alpha,) : alpha
590594
# (n_samples_test, n_alpha) : y_pred_low, y_pred_up
591595
# (n_samples_test, n_samples_train) : y_pred_multi, low/up_bounds
592-
alpha_ = self._check_alpha(alpha)
593596
if self.method in ["naive", "base"] or self.cv == "prefit":
594597
quantile = np.quantile(
595598
self.residuals_, 1 - alpha_, interpolation="higher"

0 commit comments

Comments
 (0)