Skip to content

Commit 662adad

Browse files
FIX: type checking from PR #566 (#567)
1 parent c39946f commit 662adad

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

mapie/regression/quantile_regression.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import warnings
4-
from typing import Iterable, Dict, List, Optional, Tuple, Union, cast
4+
from typing import Iterable, List, Optional, Tuple, Union, cast, Any
55

66
import numpy as np
77
from sklearn.base import RegressorMixin, clone
@@ -547,7 +547,7 @@ def fit(
547547
The model itself.
548548
"""
549549

550-
self.init_fit()
550+
self.initialize_fit()
551551

552552
if self.cv == "prefit":
553553
X_calib, y_calib = self.prefit_estimators(X, y)
@@ -570,8 +570,7 @@ def fit(
570570

571571
return self
572572

573-
def init_fit(self):
574-
573+
def initialize_fit(self) -> None:
575574
self.cv = self._check_cv(cast(str, self.cv))
576575
self.alpha_np = self._check_alpha(self.alpha)
577576
self.estimators_: List[RegressorMixin] = []
@@ -667,29 +666,31 @@ def fit_estimators(
667666

668667
def conformalize(
669668
self,
670-
X_conf: ArrayLike,
671-
y_conf: ArrayLike,
669+
X: ArrayLike,
670+
y: ArrayLike,
672671
sample_weight: Optional[ArrayLike] = None,
673-
predict_params: Dict = {},
674-
):
672+
# Parameter groups kept for compliance with superclass MapieRegressor
673+
groups: Optional[ArrayLike] = None,
674+
**kwargs: Any,
675+
) -> MapieRegressor:
675676

676-
self.n_calib_samples = _num_samples(y_conf)
677+
self.n_calib_samples = _num_samples(y)
677678

678679
y_calib_preds = np.full(
679680
shape=(3, self.n_calib_samples),
680681
fill_value=np.nan
681682
)
682683

683684
for i, est in enumerate(self.estimators_):
684-
y_calib_preds[i] = est.predict(X_conf, **predict_params).ravel()
685+
y_calib_preds[i] = est.predict(X, **kwargs).ravel()
685686

686687
self.conformity_scores_ = np.full(
687688
shape=(3, self.n_calib_samples),
688689
fill_value=np.nan
689690
)
690691

691-
self.conformity_scores_[0] = y_calib_preds[0] - y_conf
692-
self.conformity_scores_[1] = y_conf - y_calib_preds[1]
692+
self.conformity_scores_[0] = y_calib_preds[0] - y
693+
self.conformity_scores_[1] = y - y_calib_preds[1]
693694
self.conformity_scores_[2] = np.max(
694695
[
695696
self.conformity_scores_[0],

0 commit comments

Comments
 (0)