66from joblib import Parallel , delayed
77from sklearn .base import RegressorMixin , clone
88from sklearn .model_selection import BaseCrossValidator
9- from sklearn .utils import _safe_indexing
9+ from sklearn .utils import _safe_indexing , deprecated
1010from sklearn .utils .validation import _num_samples , check_is_fitted
1111
1212from mapie ._typing import ArrayLike , NDArray
1313from mapie .aggregation_functions import aggregate_all , phi2D
14- from mapie .estimator .interface import EnsembleEstimator
1514from mapie .utils import (check_nan_in_aposteriori_prediction , check_no_agg_cv ,
1615 fit_estimator )
1716
1817
19- class EnsembleRegressor ( EnsembleEstimator ) :
18+ class EnsembleRegressor :
2019 """
2120 This class implements methods to handle the training and usage of the
2221 estimator. This estimator can be unique or composed by cross validated
@@ -409,6 +408,11 @@ def predict_calib(
409408
410409 return y_pred
411410
411+ @deprecated (
412+ "WARNING: EnsembleRegressor.fit is deprecated."
413+ "Instead use EnsembleRegressor.fit_single_estimator"
414+ "then EnsembleRegressor.fit_multi_estimators"
415+ )
412416 def fit (
413417 self ,
414418 X : ArrayLike ,
@@ -451,42 +455,60 @@ def fit(
451455 EnsembleRegressor
452456 The estimator fitted.
453457 """
454- # Initialization
455- single_estimator_ : RegressorMixin
456- estimators_ : List [RegressorMixin ] = []
457- full_indexes = np .arange (_num_samples (X ))
458- cv = self .cv
459- self .use_split_method_ = check_no_agg_cv (X , self .cv , self .no_agg_cv_ )
460- estimator = self .estimator
458+ self .fit_single_estimator (
459+ X ,
460+ y ,
461+ sample_weight ,
462+ groups ,
463+ ** fit_params
464+ )
465+
466+ self .fit_multi_estimators (
467+ X ,
468+ y ,
469+ sample_weight ,
470+ groups ,
471+ ** fit_params
472+ )
473+
474+ return self
475+
476+ def fit_multi_estimators (
477+ self ,
478+ X : ArrayLike ,
479+ y : ArrayLike ,
480+ sample_weight : Optional [ArrayLike ] = None ,
481+ groups : Optional [ArrayLike ] = None ,
482+ ** fit_params
483+ ) -> EnsembleRegressor :
484+
461485 n_samples = _num_samples (y )
486+ estimators : List [RegressorMixin ] = []
462487
463- # Computation
464- if cv == "prefit" :
465- single_estimator_ = estimator
488+ if self .cv == "prefit" :
489+
490+ # Create a placeholder attribute 'k_' filled with NaN values
491+ # This attribute is defined for consistency but
492+ # is not used in prefit mode
466493 self .k_ = np .full (
467494 shape = (n_samples , 1 ), fill_value = np .nan , dtype = float
468495 )
496+
469497 else :
470- single_estimator_ = self ._fit_oof_estimator (
471- clone (estimator ),
472- X ,
473- y ,
474- full_indexes ,
475- sample_weight ,
476- ** fit_params
477- )
478- cv = cast (BaseCrossValidator , cv )
498+ cv = cast (BaseCrossValidator , self .cv )
479499 self .k_ = np .full (
480500 shape = (n_samples , cv .get_n_splits (X , y , groups )),
481501 fill_value = np .nan ,
482502 dtype = float ,
483503 )
484- if self .method == "naive" :
485- estimators_ = [single_estimator_ ]
486- else :
487- estimators_ = Parallel (self .n_jobs , verbose = self .verbose )(
504+
505+ if self .method != "naive" :
506+ estimators = Parallel (
507+ self .n_jobs ,
508+ verbose = self .verbose
509+ )(
488510 delayed (self ._fit_oof_estimator )(
489- clone (estimator ),
511+ clone (self . estimator ),
490512 X ,
491513 y ,
492514 train_index ,
@@ -495,13 +517,47 @@ def fit(
495517 )
496518 for train_index , _ in cv .split (X , y , groups )
497519 )
498- # In split-CP, we keep only the model fitted on train dataset
499- if self .use_split_method_ :
500- single_estimator_ = estimators_ [0 ]
501520
502- self .single_estimator_ = single_estimator_
503- self .estimators_ = estimators_
521+ self .estimators_ = estimators
522+
523+ return self
524+
525+ def fit_single_estimator (
526+ self ,
527+ X : ArrayLike ,
528+ y : ArrayLike ,
529+ sample_weight : Optional [ArrayLike ] = None ,
530+ groups : Optional [ArrayLike ] = None ,
531+ ** fit_params
532+ ) -> EnsembleRegressor :
533+
534+ self .use_split_method_ = check_no_agg_cv (X , self .cv , self .no_agg_cv_ )
535+ single_estimator_ : RegressorMixin
536+
537+ if self .cv == "prefit" :
538+ single_estimator_ = self .estimator
539+ else :
540+ cv = cast (BaseCrossValidator , self .cv )
541+ if self .use_split_method_ :
542+ train_indexes = [
543+ train_index for train_index , test_index in cv .split (
544+ X , y , groups )
545+ ][0 ]
546+ indexes = train_indexes
547+ else :
548+ full_indexes = np .arange (_num_samples (X ))
549+ indexes = full_indexes
550+
551+ single_estimator_ = self ._fit_oof_estimator (
552+ clone (self .estimator ),
553+ X ,
554+ y ,
555+ indexes ,
556+ sample_weight ,
557+ ** fit_params
558+ )
504559
560+ self .single_estimator_ = single_estimator_
505561 return self
506562
507563 def predict (
0 commit comments