@@ -301,7 +301,7 @@ def _check_cv(
301301
302302 def _check_alpha (
303303 self ,
304- alpha : Union [float , Iterable [float ]]
304+ alpha : Optional [ Union [float , Iterable [float ]]] = None
305305 ) -> np .ndarray :
306306 """
307307 Check alpha and prepare it as a np.ndarray
@@ -325,28 +325,31 @@ def _check_alpha(
325325 ValueError
326326 If alpha is not a float or an Iterable of floats between 0 and 1.
327327 """
328- if isinstance (alpha , float ):
329- alpha_np = np .array ([alpha ])
330- elif isinstance (alpha , Iterable ):
331- alpha_np = np .array (alpha )
328+ if alpha is None :
329+ return np .zeros (1 )
332330 else :
333- raise ValueError (
334- "Invalid alpha. Allowed values are float or Iterable."
335- )
336- if len (alpha_np .shape ) != 1 :
337- raise ValueError (
338- "Invalid alpha. "
339- "Please provide a one-dimensional list of values."
340- )
341- if alpha_np .dtype .type not in [np .float64 , np .float32 ]:
342- raise ValueError (
343- "Invalid alpha. Allowed values are Iterable of floats."
344- )
345- if np .any ((alpha_np <= 0 ) | (alpha_np >= 1 )):
346- raise ValueError (
347- "Invalid alpha. Allowed values are between 0 and 1."
348- )
349- return alpha_np
331+ if isinstance (alpha , float ):
332+ alpha_np = np .array ([alpha ])
333+ elif isinstance (alpha , Iterable ):
334+ alpha_np = np .array (alpha )
335+ else :
336+ raise ValueError (
337+ "Invalid alpha. Allowed values are float or Iterable."
338+ )
339+ if len (alpha_np .shape ) != 1 :
340+ raise ValueError (
341+ "Invalid alpha. "
342+ "Please provide a one-dimensional list of values."
343+ )
344+ if alpha_np .dtype .type not in [np .float64 , np .float32 ]:
345+ raise ValueError (
346+ "Invalid alpha. Allowed values are Iterable of floats."
347+ )
348+ if np .any ((alpha_np <= 0 ) | (alpha_np >= 1 )):
349+ raise ValueError (
350+ "Invalid alpha. Allowed values are between 0 and 1."
351+ )
352+ return alpha_np
350353
351354 def _check_n_features_in (
352355 self ,
@@ -578,6 +581,7 @@ def predict(
578581 "residuals_"
579582 ]
580583 )
584+ alpha_ = self ._check_alpha (alpha )
581585 X = check_array (X , force_all_finite = False , dtype = ["float64" , "object" ])
582586 y_pred = self .single_estimator_ .predict (X )
583587
@@ -589,7 +593,6 @@ def predict(
589593 # (n_alpha,) : alpha
590594 # (n_samples_test, n_alpha) : y_pred_low, y_pred_up
591595 # (n_samples_test, n_samples_train) : y_pred_multi, low/up_bounds
592- alpha_ = self ._check_alpha (alpha )
593596 if self .method in ["naive" , "base" ] or self .cv == "prefit" :
594597 quantile = np .quantile (
595598 self .residuals_ , 1 - alpha_ , interpolation = "higher"
0 commit comments