@@ -346,75 +346,45 @@ def _check_cv(
346346 "Invalid cv method, only valid method is ``split``."
347347 )
348348
349- def _check_calib_set (
349+ def _train_calib_split (
350350 self ,
351351 X : ArrayLike ,
352352 y : ArrayLike ,
353353 sample_weight : Optional [ArrayLike ] = None ,
354- X_calib : Optional [ArrayLike ] = None ,
355- y_calib : Optional [ArrayLike ] = None ,
356354 calib_size : Optional [float ] = 0.3 ,
357355 random_state : Optional [Union [int , np .random .RandomState , None ]] = None ,
358356 shuffle : Optional [bool ] = True ,
359357 stratify : Optional [ArrayLike ] = None ,
360358 ) -> Tuple [
361359 ArrayLike , ArrayLike , ArrayLike , ArrayLike , Optional [ArrayLike ]
362360 ]:
363- """
364- Check if a calibration set has already been defined, if not, then
365- we define one using the ``train_test_split`` method.
366-
367- Parameters
368- ----------
369- Same definition of parameters as for the ``fit`` method.
370-
371- Returns
372- -------
373- Tuple[ArrayLike, ArrayLike, ArrayLike, ArrayLike, ArrayLike]
374- - [0]: ArrayLike of shape (n_samples_*(1-calib_size), n_features)
375- X_train
376- - [1]: ArrayLike of shape (n_samples_*(1-calib_size),)
377- y_train
378- - [2]: ArrayLike of shape (n_samples_*calib_size, n_features)
379- X_calib
380- - [3]: ArrayLike of shape (n_samples_*calib_size,)
381- y_calib
382- - [4]: ArrayLike of shape (n_samples_,)
383- sample_weight_train
384- """
385- if X_calib is None or y_calib is None :
386- if sample_weight is None :
387- X_train , X_calib , y_train , y_calib = train_test_split (
388- X ,
389- y ,
390- test_size = calib_size ,
391- random_state = random_state ,
392- shuffle = shuffle ,
393- stratify = stratify
394- )
395- sample_weight_train = sample_weight
396- else :
397- (
398- X_train ,
399- X_calib ,
400- y_train ,
401- y_calib ,
402- sample_weight_train ,
403- _ ,
404- ) = train_test_split (
405- X ,
406- y ,
407- sample_weight ,
408- test_size = calib_size ,
409- random_state = random_state ,
410- shuffle = shuffle ,
411- stratify = stratify
412- )
361+ if sample_weight is None :
362+ X_train , X_calib , y_train , y_calib = train_test_split (
363+ X ,
364+ y ,
365+ test_size = calib_size ,
366+ random_state = random_state ,
367+ shuffle = shuffle ,
368+ stratify = stratify
369+ )
370+ sample_weight_train = sample_weight
413371 else :
414- X_train , y_train , sample_weight_train = X , y , sample_weight
415- X_train , X_calib = cast (ArrayLike , X_train ), cast (ArrayLike , X_calib )
416- y_train , y_calib = cast (ArrayLike , y_train ), cast (ArrayLike , y_calib )
417- sample_weight_train = cast (ArrayLike , sample_weight_train )
372+ (
373+ X_train ,
374+ X_calib ,
375+ y_train ,
376+ y_calib ,
377+ sample_weight_train ,
378+ _ ,
379+ ) = train_test_split (
380+ X ,
381+ y ,
382+ sample_weight ,
383+ test_size = calib_size ,
384+ random_state = random_state ,
385+ shuffle = shuffle ,
386+ stratify = stratify
387+ )
418388 return X_train , y_train , X_calib , y_calib , sample_weight_train
419389
420390 def _check_prefit_params (
@@ -546,13 +516,12 @@ def fit(
546516 MapieQuantileRegressor
547517 The model itself.
548518 """
549-
550- self .initialize_fit ()
519+ self ._initialize_fit_conformalize ()
551520
552521 if self .cv == "prefit" :
553- X_calib , y_calib = self . prefit_estimators ( X , y )
522+ X_calib , y_calib = X , y
554523 else :
555- X_calib , y_calib = self .fit_estimators (
524+ result = self ._prepare_train_calib (
556525 X = X ,
557526 y = y ,
558527 sample_weight = sample_weight ,
@@ -563,33 +532,31 @@ def fit(
563532 random_state = random_state ,
564533 shuffle = shuffle ,
565534 stratify = stratify ,
566- ** fit_params ,
535+ )
536+ X_train , y_train , X_calib , y_calib , sample_weight = result
537+ self ._fit_estimators (
538+ X = X_train ,
539+ y = y_train ,
540+ sample_weight = sample_weight ,
541+ ** fit_params
567542 )
568543
569544 self .conformalize (X_calib , y_calib )
570545
571546 return self
572547
573- def initialize_fit (self ) -> None :
548+ def _initialize_fit_conformalize (self ) -> None :
574549 self .cv = self ._check_cv (cast (str , self .cv ))
575550 self .alpha_np = self ._check_alpha (self .alpha )
576551 self .estimators_ : List [RegressorMixin ] = []
577552
578- def prefit_estimators (
579- self ,
580- X : ArrayLike ,
581- y : ArrayLike
582- ) -> Tuple [ArrayLike , ArrayLike ]:
583-
553+ def _initialize_and_check_prefit_estimators (self ) -> None :
584554 estimator = cast (List , self .estimator )
585555 self ._check_prefit_params (estimator )
586556 self .estimators_ = list (estimator )
587557 self .single_estimator_ = self .estimators_ [2 ]
588558
589- X_calib , y_calib = indexable (X , y )
590- return X_calib , y_calib
591-
592- def fit_estimators (
559+ def _prepare_train_calib (
593560 self ,
594561 X : ArrayLike ,
595562 y : ArrayLike ,
@@ -601,68 +568,81 @@ def fit_estimators(
601568 random_state : Optional [Union [int , np .random .RandomState ]] = None ,
602569 shuffle : Optional [bool ] = True ,
603570 stratify : Optional [ArrayLike ] = None ,
604- ** fit_params ,
605- ) -> Tuple [ArrayLike , ArrayLike ]:
606-
571+ ) -> Tuple [
572+ ArrayLike , ArrayLike , ArrayLike , ArrayLike , Optional [ArrayLike ]
573+ ]:
574+ """
575+ Handles the preparation of training and calibration datasets,
576+ including validation and splitting.
577+ Returns: X_train, y_train, X_calib, y_calib, sample_weight_train
578+ """
607579 self ._check_parameters ()
608- checked_estimator = self ._check_estimator (self .estimator )
609580 random_state = check_random_state (random_state )
610581 X , y = indexable (X , y )
611582
612- results = self ._check_calib_set (
613- X ,
614- y ,
615- sample_weight ,
616- X_calib ,
617- y_calib ,
618- calib_size ,
619- random_state ,
620- shuffle ,
621- stratify ,
622- )
583+ if X_calib is None or y_calib is None :
584+ return self ._train_calib_split (
585+ X ,
586+ y ,
587+ sample_weight ,
588+ calib_size ,
589+ random_state ,
590+ shuffle ,
591+ stratify
592+ )
593+ else :
594+ return X , y , X_calib , y_calib , sample_weight
623595
624- X_train , y_train , X_calib , y_calib , sample_weight_train = results
625- X_train , y_train = indexable (X_train , y_train )
626- X_calib , y_calib = indexable (X_calib , y_calib )
627- y_train , y_calib = _check_y (y_train ), _check_y (y_calib )
628- self .n_calib_samples = _num_samples (y_calib )
629- check_alpha_and_n_samples (self .alpha , self .n_calib_samples )
630- sample_weight_train , X_train , y_train = check_null_weight (
631- sample_weight_train ,
632- X_train ,
633- y_train
596+ # Second function: Handles estimator fitting
597+ def _fit_estimators (
598+ self ,
599+ X : ArrayLike ,
600+ y : ArrayLike ,
601+ sample_weight : Optional [ArrayLike ] = None ,
602+ ** fit_params
603+ ) -> None :
604+ """
605+ Fits the estimators with provided training data
606+ and stores them in self.estimators_.
607+ """
608+ checked_estimator = self ._check_estimator (self .estimator )
609+
610+ X , y = indexable (X , y )
611+ y = _check_y (y )
612+
613+ sample_weight , X , y = check_null_weight (
614+ sample_weight , X , y
634615 )
635- y_train = cast (NDArray , y_train )
636616
637617 if isinstance (checked_estimator , Pipeline ):
638618 estimator = checked_estimator [- 1 ]
639619 else :
640620 estimator = checked_estimator
621+
641622 name_estimator = estimator .__class__ .__name__
642- alpha_name = self .quantile_estimator_params [
643- name_estimator
644- ]["alpha_name" ]
623+ alpha_name = self .quantile_estimator_params [name_estimator ][
624+ "alpha_name"
625+ ]
626+
645627 for i , alpha_ in enumerate (self .alpha_np ):
646628 cloned_estimator_ = clone (checked_estimator )
647629 params = {alpha_name : alpha_ }
648630 if isinstance (checked_estimator , Pipeline ):
649631 cloned_estimator_ [- 1 ].set_params (** params )
650632 else :
651633 cloned_estimator_ .set_params (** params )
652- self .estimators_ .append (fit_estimator (
653- cloned_estimator_ ,
654- X_train ,
655- y_train ,
656- sample_weight_train ,
657- ** fit_params ,
634+
635+ self .estimators_ .append (
636+ fit_estimator (
637+ cloned_estimator_ ,
638+ X ,
639+ y ,
640+ sample_weight ,
641+ ** fit_params ,
658642 )
659643 )
660- self .single_estimator_ = self .estimators_ [2 ]
661-
662- X_calib = cast (ArrayLike , X_calib )
663- y_calib = cast (ArrayLike , y_calib )
664644
665- return X_calib , y_calib
645+ self . single_estimator_ = self . estimators_ [ 2 ]
666646
667647 def conformalize (
668648 self ,
@@ -673,24 +653,31 @@ def conformalize(
673653 groups : Optional [ArrayLike ] = None ,
674654 ** kwargs : Any ,
675655 ) -> MapieRegressor :
656+ if self .cv == "prefit" :
657+ self ._initialize_and_check_prefit_estimators ()
676658
677- self .n_calib_samples = _num_samples (y )
659+ X_calib , y_calib = cast (ArrayLike , X ), cast (ArrayLike , y )
660+ X_calib , y_calib = indexable (X_calib , y_calib )
661+ y_calib = _check_y (y_calib )
662+
663+ self .n_calib_samples = _num_samples (y_calib )
664+ check_alpha_and_n_samples (self .alpha , self .n_calib_samples )
678665
679666 y_calib_preds = np .full (
680667 shape = (3 , self .n_calib_samples ),
681668 fill_value = np .nan
682669 )
683670
684671 for i , est in enumerate (self .estimators_ ):
685- y_calib_preds [i ] = est .predict (X , ** kwargs ).ravel ()
672+ y_calib_preds [i ] = est .predict (X_calib , ** kwargs ).ravel ()
686673
687674 self .conformity_scores_ = np .full (
688675 shape = (3 , self .n_calib_samples ),
689676 fill_value = np .nan
690677 )
691678
692- self .conformity_scores_ [0 ] = y_calib_preds [0 ] - y
693- self .conformity_scores_ [1 ] = y - y_calib_preds [1 ]
679+ self .conformity_scores_ [0 ] = y_calib_preds [0 ] - y_calib
680+ self .conformity_scores_ [1 ] = y_calib - y_calib_preds [1 ]
694681 self .conformity_scores_ [2 ] = np .max (
695682 [
696683 self .conformity_scores_ [0 ],
0 commit comments