Skip to content

Commit 77df567

Browse files
REFACTOR: Break down the fit method in MapieQuantileRegressor into mu… (#578)
* REFACTOR: Break down the fit method in MapieQuantileRegressor into multiple sub-methods.
1 parent dab4d47 commit 77df567

File tree

2 files changed

+110
-121
lines changed

2 files changed

+110
-121
lines changed

mapie/regression/quantile_regression.py

Lines changed: 104 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -346,75 +346,45 @@ def _check_cv(
346346
"Invalid cv method, only valid method is ``split``."
347347
)
348348

349-
def _check_calib_set(
349+
def _train_calib_split(
350350
self,
351351
X: ArrayLike,
352352
y: ArrayLike,
353353
sample_weight: Optional[ArrayLike] = None,
354-
X_calib: Optional[ArrayLike] = None,
355-
y_calib: Optional[ArrayLike] = None,
356354
calib_size: Optional[float] = 0.3,
357355
random_state: Optional[Union[int, np.random.RandomState, None]] = None,
358356
shuffle: Optional[bool] = True,
359357
stratify: Optional[ArrayLike] = None,
360358
) -> Tuple[
361359
ArrayLike, ArrayLike, ArrayLike, ArrayLike, Optional[ArrayLike]
362360
]:
363-
"""
364-
Check if a calibration set has already been defined, if not, then
365-
we define one using the ``train_test_split`` method.
366-
367-
Parameters
368-
----------
369-
Same definition of parameters as for the ``fit`` method.
370-
371-
Returns
372-
-------
373-
Tuple[ArrayLike, ArrayLike, ArrayLike, ArrayLike, ArrayLike]
374-
- [0]: ArrayLike of shape (n_samples_*(1-calib_size), n_features)
375-
X_train
376-
- [1]: ArrayLike of shape (n_samples_*(1-calib_size),)
377-
y_train
378-
- [2]: ArrayLike of shape (n_samples_*calib_size, n_features)
379-
X_calib
380-
- [3]: ArrayLike of shape (n_samples_*calib_size,)
381-
y_calib
382-
- [4]: ArrayLike of shape (n_samples_,)
383-
sample_weight_train
384-
"""
385-
if X_calib is None or y_calib is None:
386-
if sample_weight is None:
387-
X_train, X_calib, y_train, y_calib = train_test_split(
388-
X,
389-
y,
390-
test_size=calib_size,
391-
random_state=random_state,
392-
shuffle=shuffle,
393-
stratify=stratify
394-
)
395-
sample_weight_train = sample_weight
396-
else:
397-
(
398-
X_train,
399-
X_calib,
400-
y_train,
401-
y_calib,
402-
sample_weight_train,
403-
_,
404-
) = train_test_split(
405-
X,
406-
y,
407-
sample_weight,
408-
test_size=calib_size,
409-
random_state=random_state,
410-
shuffle=shuffle,
411-
stratify=stratify
412-
)
361+
if sample_weight is None:
362+
X_train, X_calib, y_train, y_calib = train_test_split(
363+
X,
364+
y,
365+
test_size=calib_size,
366+
random_state=random_state,
367+
shuffle=shuffle,
368+
stratify=stratify
369+
)
370+
sample_weight_train = sample_weight
413371
else:
414-
X_train, y_train, sample_weight_train = X, y, sample_weight
415-
X_train, X_calib = cast(ArrayLike, X_train), cast(ArrayLike, X_calib)
416-
y_train, y_calib = cast(ArrayLike, y_train), cast(ArrayLike, y_calib)
417-
sample_weight_train = cast(ArrayLike, sample_weight_train)
372+
(
373+
X_train,
374+
X_calib,
375+
y_train,
376+
y_calib,
377+
sample_weight_train,
378+
_,
379+
) = train_test_split(
380+
X,
381+
y,
382+
sample_weight,
383+
test_size=calib_size,
384+
random_state=random_state,
385+
shuffle=shuffle,
386+
stratify=stratify
387+
)
418388
return X_train, y_train, X_calib, y_calib, sample_weight_train
419389

420390
def _check_prefit_params(
@@ -546,13 +516,12 @@ def fit(
546516
MapieQuantileRegressor
547517
The model itself.
548518
"""
549-
550-
self.initialize_fit()
519+
self._initialize_fit_conformalize()
551520

552521
if self.cv == "prefit":
553-
X_calib, y_calib = self.prefit_estimators(X, y)
522+
X_calib, y_calib = X, y
554523
else:
555-
X_calib, y_calib = self.fit_estimators(
524+
result = self._prepare_train_calib(
556525
X=X,
557526
y=y,
558527
sample_weight=sample_weight,
@@ -563,33 +532,31 @@ def fit(
563532
random_state=random_state,
564533
shuffle=shuffle,
565534
stratify=stratify,
566-
**fit_params,
535+
)
536+
X_train, y_train, X_calib, y_calib, sample_weight = result
537+
self._fit_estimators(
538+
X=X_train,
539+
y=y_train,
540+
sample_weight=sample_weight,
541+
**fit_params
567542
)
568543

569544
self.conformalize(X_calib, y_calib)
570545

571546
return self
572547

573-
def initialize_fit(self) -> None:
548+
def _initialize_fit_conformalize(self) -> None:
574549
self.cv = self._check_cv(cast(str, self.cv))
575550
self.alpha_np = self._check_alpha(self.alpha)
576551
self.estimators_: List[RegressorMixin] = []
577552

578-
def prefit_estimators(
579-
self,
580-
X: ArrayLike,
581-
y: ArrayLike
582-
) -> Tuple[ArrayLike, ArrayLike]:
583-
553+
def _initialize_and_check_prefit_estimators(self) -> None:
584554
estimator = cast(List, self.estimator)
585555
self._check_prefit_params(estimator)
586556
self.estimators_ = list(estimator)
587557
self.single_estimator_ = self.estimators_[2]
588558

589-
X_calib, y_calib = indexable(X, y)
590-
return X_calib, y_calib
591-
592-
def fit_estimators(
559+
def _prepare_train_calib(
593560
self,
594561
X: ArrayLike,
595562
y: ArrayLike,
@@ -601,68 +568,81 @@ def fit_estimators(
601568
random_state: Optional[Union[int, np.random.RandomState]] = None,
602569
shuffle: Optional[bool] = True,
603570
stratify: Optional[ArrayLike] = None,
604-
**fit_params,
605-
) -> Tuple[ArrayLike, ArrayLike]:
606-
571+
) -> Tuple[
572+
ArrayLike, ArrayLike, ArrayLike, ArrayLike, Optional[ArrayLike]
573+
]:
574+
"""
575+
Handles the preparation of training and calibration datasets,
576+
including validation and splitting.
577+
Returns: X_train, y_train, X_calib, y_calib, sample_weight_train
578+
"""
607579
self._check_parameters()
608-
checked_estimator = self._check_estimator(self.estimator)
609580
random_state = check_random_state(random_state)
610581
X, y = indexable(X, y)
611582

612-
results = self._check_calib_set(
613-
X,
614-
y,
615-
sample_weight,
616-
X_calib,
617-
y_calib,
618-
calib_size,
619-
random_state,
620-
shuffle,
621-
stratify,
622-
)
583+
if X_calib is None or y_calib is None:
584+
return self._train_calib_split(
585+
X,
586+
y,
587+
sample_weight,
588+
calib_size,
589+
random_state,
590+
shuffle,
591+
stratify
592+
)
593+
else:
594+
return X, y, X_calib, y_calib, sample_weight
623595

624-
X_train, y_train, X_calib, y_calib, sample_weight_train = results
625-
X_train, y_train = indexable(X_train, y_train)
626-
X_calib, y_calib = indexable(X_calib, y_calib)
627-
y_train, y_calib = _check_y(y_train), _check_y(y_calib)
628-
self.n_calib_samples = _num_samples(y_calib)
629-
check_alpha_and_n_samples(self.alpha, self.n_calib_samples)
630-
sample_weight_train, X_train, y_train = check_null_weight(
631-
sample_weight_train,
632-
X_train,
633-
y_train
596+
# Second function: Handles estimator fitting
597+
def _fit_estimators(
598+
self,
599+
X: ArrayLike,
600+
y: ArrayLike,
601+
sample_weight: Optional[ArrayLike] = None,
602+
**fit_params
603+
) -> None:
604+
"""
605+
Fits the estimators with provided training data
606+
and stores them in self.estimators_.
607+
"""
608+
checked_estimator = self._check_estimator(self.estimator)
609+
610+
X, y = indexable(X, y)
611+
y = _check_y(y)
612+
613+
sample_weight, X, y = check_null_weight(
614+
sample_weight, X, y
634615
)
635-
y_train = cast(NDArray, y_train)
636616

637617
if isinstance(checked_estimator, Pipeline):
638618
estimator = checked_estimator[-1]
639619
else:
640620
estimator = checked_estimator
621+
641622
name_estimator = estimator.__class__.__name__
642-
alpha_name = self.quantile_estimator_params[
643-
name_estimator
644-
]["alpha_name"]
623+
alpha_name = self.quantile_estimator_params[name_estimator][
624+
"alpha_name"
625+
]
626+
645627
for i, alpha_ in enumerate(self.alpha_np):
646628
cloned_estimator_ = clone(checked_estimator)
647629
params = {alpha_name: alpha_}
648630
if isinstance(checked_estimator, Pipeline):
649631
cloned_estimator_[-1].set_params(**params)
650632
else:
651633
cloned_estimator_.set_params(**params)
652-
self.estimators_.append(fit_estimator(
653-
cloned_estimator_,
654-
X_train,
655-
y_train,
656-
sample_weight_train,
657-
**fit_params,
634+
635+
self.estimators_.append(
636+
fit_estimator(
637+
cloned_estimator_,
638+
X,
639+
y,
640+
sample_weight,
641+
**fit_params,
658642
)
659643
)
660-
self.single_estimator_ = self.estimators_[2]
661-
662-
X_calib = cast(ArrayLike, X_calib)
663-
y_calib = cast(ArrayLike, y_calib)
664644

665-
return X_calib, y_calib
645+
self.single_estimator_ = self.estimators_[2]
666646

667647
def conformalize(
668648
self,
@@ -673,24 +653,31 @@ def conformalize(
673653
groups: Optional[ArrayLike] = None,
674654
**kwargs: Any,
675655
) -> MapieRegressor:
656+
if self.cv == "prefit":
657+
self._initialize_and_check_prefit_estimators()
676658

677-
self.n_calib_samples = _num_samples(y)
659+
X_calib, y_calib = cast(ArrayLike, X), cast(ArrayLike, y)
660+
X_calib, y_calib = indexable(X_calib, y_calib)
661+
y_calib = _check_y(y_calib)
662+
663+
self.n_calib_samples = _num_samples(y_calib)
664+
check_alpha_and_n_samples(self.alpha, self.n_calib_samples)
678665

679666
y_calib_preds = np.full(
680667
shape=(3, self.n_calib_samples),
681668
fill_value=np.nan
682669
)
683670

684671
for i, est in enumerate(self.estimators_):
685-
y_calib_preds[i] = est.predict(X, **kwargs).ravel()
672+
y_calib_preds[i] = est.predict(X_calib, **kwargs).ravel()
686673

687674
self.conformity_scores_ = np.full(
688675
shape=(3, self.n_calib_samples),
689676
fill_value=np.nan
690677
)
691678

692-
self.conformity_scores_[0] = y_calib_preds[0] - y
693-
self.conformity_scores_[1] = y - y_calib_preds[1]
679+
self.conformity_scores_[0] = y_calib_preds[0] - y_calib
680+
self.conformity_scores_[1] = y_calib - y_calib_preds[1]
694681
self.conformity_scores_[2] = np.max(
695682
[
696683
self.conformity_scores_[0],

mapie/tests/test_quantile_regression.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -470,11 +470,13 @@ def test_for_small_dataset() -> None:
470470
estimator=qt,
471471
alpha=0.1
472472
)
473+
X_calib_toy_small = X_calib_toy[:2]
474+
y_calib_toy_small = y_calib_toy[:2]
473475
mapie_reg.fit(
474-
np.array([1, 2, 3]),
475-
np.array([2, 2, 3]),
476-
X_calib=np.array([3, 5]),
477-
y_calib=np.array([2, 3])
476+
X_train_toy,
477+
y_train_toy,
478+
X_calib=X_calib_toy_small,
479+
y_calib=y_calib_toy_small
478480
)
479481

480482

0 commit comments

Comments
 (0)