Skip to content

Commit 428f4f3

Browse files
committed
Merge branch 'master' into 491-add-predict-params-into-classification-files
2 parents 16d54d6 + 603b5da commit 428f4f3

File tree

7 files changed

+209
-34
lines changed

7 files changed

+209
-34
lines changed

mapie/estimator/classifier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ def predict(
454454
self,
455455
X: ArrayLike,
456456
agg_scores: Optional[str] = None,
457-
**predict_params
457+
**predict_params,
458458
) -> NDArray:
459459
"""
460460
Predict target from X. It also computes the prediction per train sample

mapie/estimator/regressor.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ def _predict_oof_estimator(
233233
estimator: RegressorMixin,
234234
X: ArrayLike,
235235
val_index: ArrayLike,
236+
**predict_params
236237
) -> Tuple[NDArray, ArrayLike]:
237238
"""
238239
Perform predictions on a single out-of-fold model on a validation set.
@@ -248,14 +249,17 @@ def _predict_oof_estimator(
248249
val_index: ArrayLike of shape (n_samples_val)
249250
Validation data indices.
250251
252+
**predict_params : dict
253+
Additional predict parameters.
254+
251255
Returns
252256
-------
253257
Tuple[NDArray, ArrayLike]
254258
Predictions of estimator from val_index of X.
255259
"""
256260
X_val = _safe_indexing(X, val_index)
257261
if _num_samples(X_val) > 0:
258-
y_pred = estimator.predict(X_val)
262+
y_pred = estimator.predict(X_val, **predict_params)
259263
else:
260264
y_pred = np.array([])
261265
return y_pred, val_index
@@ -306,7 +310,7 @@ def _aggregate_with_mask(
306310
else:
307311
raise ValueError("The value of self.agg_function is not correct")
308312

309-
def _pred_multi(self, X: ArrayLike) -> NDArray:
313+
def _pred_multi(self, X: ArrayLike, **predict_params) -> NDArray:
310314
"""
311315
Return a prediction per train sample for each test sample, by
312316
aggregation with matrix ``k_``.
@@ -316,12 +320,15 @@ def _pred_multi(self, X: ArrayLike) -> NDArray:
316320
X: ArrayLike of shape (n_samples_test, n_features)
317321
Input data
318322
323+
**predict_params : dict
324+
Additional predict parameters.
325+
319326
Returns
320327
-------
321328
NDArray of shape (n_samples_test, n_samples_train)
322329
"""
323330
y_pred_multi = np.column_stack(
324-
[e.predict(X) for e in self.estimators_]
331+
[e.predict(X, **predict_params) for e in self.estimators_]
325332
)
326333
# At this point, y_pred_multi is of shape
327334
# (n_samples_test, n_estimators_). The method
@@ -334,7 +341,8 @@ def predict_calib(
334341
self,
335342
X: ArrayLike,
336343
y: Optional[ArrayLike] = None,
337-
groups: Optional[ArrayLike] = None
344+
groups: Optional[ArrayLike] = None,
345+
**predict_params
338346
) -> NDArray:
339347
"""
340348
Perform predictions on X : the calibration set.
@@ -355,6 +363,9 @@ def predict_calib(
355363
356364
By default ``None``.
357365
366+
**predict_params : dict
367+
Additional predict parameters.
368+
358369
Returns
359370
-------
360371
NDArray of shape (n_samples_test, 1)
@@ -371,7 +382,7 @@ def predict_calib(
371382
cv = cast(BaseCrossValidator, self.cv)
372383
outputs = Parallel(n_jobs=self.n_jobs, verbose=self.verbose)(
373384
delayed(self._predict_oof_estimator)(
374-
estimator, X, calib_index,
385+
estimator, X, calib_index, **predict_params
375386
)
376387
for (_, calib_index), estimator in zip(
377388
cv.split(X, y, groups),
@@ -404,7 +415,7 @@ def fit(
404415
y: ArrayLike,
405416
sample_weight: Optional[ArrayLike] = None,
406417
groups: Optional[ArrayLike] = None,
407-
**fit_params,
418+
**fit_params
408419
) -> EnsembleRegressor:
409420
"""
410421
Fit the base estimator under the ``single_estimator_`` attribute.
@@ -526,6 +537,9 @@ def predict(
526537
predictions (3 arrays). If ``False`` the method return the
527538
simple predictions only.
528539
540+
**predict_params : dict
541+
Additional predict parameters.
542+
529543
Returns
530544
-------
531545
Tuple[NDArray, NDArray, NDArray]
@@ -535,15 +549,15 @@ def predict(
535549
"""
536550
check_is_fitted(self, self.fit_attributes)
537551

538-
y_pred = self.single_estimator_.predict(X)
552+
y_pred = self.single_estimator_.predict(X, **predict_params)
539553
if not return_multi_pred and not ensemble:
540554
return y_pred
541555

542556
if self.method in self.no_agg_methods_ or self.use_split_method_:
543557
y_pred_multi_low = y_pred[:, np.newaxis]
544558
y_pred_multi_up = y_pred[:, np.newaxis]
545559
else:
546-
y_pred_multi = self._pred_multi(X)
560+
y_pred_multi = self._pred_multi(X, **predict_params)
547561

548562
if self.method == "minmax":
549563
y_pred_multi_low = np.min(y_pred_multi, axis=1, keepdims=True)

mapie/regression/quantile_regression.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,7 @@ def predict(
649649
optimize_beta: bool = False,
650650
allow_infinite_bounds: bool = False,
651651
symmetry: Optional[bool] = True,
652+
**predict_params,
652653
) -> Union[NDArray, Tuple[NDArray, NDArray]]:
653654
"""
654655
Predict target on new samples with confidence intervals.
@@ -676,6 +677,9 @@ def predict(
676677
each residuals separatly or to use the maximum of the two
677678
combined.
678679
680+
predict_params : dict
681+
Additional predict parameters.
682+
679683
Returns
680684
-------
681685
Union[NDArray, Tuple[NDArray, NDArray]]
@@ -699,7 +703,7 @@ def predict(
699703
dtype=float,
700704
)
701705
for i, est in enumerate(self.estimators_):
702-
y_preds[i] = est.predict(X)
706+
y_preds[i] = est.predict(X, **predict_params)
703707
check_lower_upper_bounds(y_preds[0], y_preds[1], y_preds[2])
704708
if symmetry:
705709
quantile = np.full(

mapie/regression/regression.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import warnings
4-
from typing import Iterable, Optional, Tuple, Union, cast
4+
from typing import Any, Iterable, Optional, Tuple, Union, cast
55

66
import numpy as np
77
from sklearn.base import BaseEstimator, RegressorMixin
@@ -19,7 +19,8 @@
1919
from mapie.utils import (check_alpha, check_alpha_and_n_samples,
2020
check_cv, check_estimator_fit_predict,
2121
check_n_features_in, check_n_jobs, check_null_weight,
22-
check_verbose, get_effective_calibration_samples)
22+
check_verbose, get_effective_calibration_samples,
23+
check_predict_params)
2324

2425

2526
class MapieRegressor(BaseEstimator, RegressorMixin):
@@ -471,7 +472,7 @@ def fit(
471472
y: ArrayLike,
472473
sample_weight: Optional[ArrayLike] = None,
473474
groups: Optional[ArrayLike] = None,
474-
**fit_params,
475+
**kwargs: Any
475476
) -> MapieRegressor:
476477
"""
477478
Fit estimator and compute conformity scores used for
@@ -504,14 +505,21 @@ def fit(
504505
train/test set.
505506
By default ``None``.
506507
507-
**fit_params : dict
508-
Additional fit parameters.
508+
kwargs : dict
509+
Additional fit and predict parameters.
509510
510511
Returns
511512
-------
512513
MapieRegressor
513514
The model itself.
514515
"""
516+
fit_params = kwargs.pop('fit_params', {})
517+
predict_params = kwargs.pop('predict_params', {})
518+
if len(predict_params) > 0:
519+
self._predict_params = True
520+
else:
521+
self._predict_params = False
522+
515523
# Checks
516524
(estimator,
517525
self.conformity_score_function_,
@@ -538,7 +546,9 @@ def fit(
538546
)
539547

540548
# Predict on calibration data
541-
y_pred = self.estimator_.predict_calib(X, y=y, groups=groups)
549+
y_pred = self.estimator_.predict_calib(
550+
X, y=y, groups=groups, **predict_params
551+
)
542552

543553
# Compute the conformity scores (manage jk-ab case)
544554
self.conformity_scores_ = \
@@ -555,6 +565,7 @@ def predict(
555565
alpha: Optional[Union[float, Iterable[float]]] = None,
556566
optimize_beta: bool = False,
557567
allow_infinite_bounds: bool = False,
568+
**predict_params
558569
) -> Union[NDArray, Tuple[NDArray, NDArray]]:
559570
"""
560571
Predict target on new samples with confidence intervals.
@@ -604,6 +615,9 @@ def predict(
604615
605616
By default ``False``.
606617
618+
predict_params : dict
619+
Additional predict parameters.
620+
607621
Returns
608622
-------
609623
Union[NDArray, Tuple[NDArray, NDArray]]
@@ -614,14 +628,16 @@ def predict(
614628
- [:, 1, :]: Upper bound of the prediction interval.
615629
"""
616630
# Checks
631+
if hasattr(self, '_predict_params'):
632+
check_predict_params(self._predict_params, predict_params, self.cv)
617633
check_is_fitted(self, self.fit_attributes)
618634
self._check_ensemble(ensemble)
619635
alpha = cast(Optional[NDArray], check_alpha(alpha))
620636

621637
# If alpha is None, predict the target without confidence intervals
622638
if alpha is None:
623639
y_pred = self.estimator_.predict(
624-
X, ensemble, return_multi_pred=False
640+
X, ensemble, return_multi_pred=False, **predict_params
625641
)
626642
return np.array(y_pred)
627643

mapie/regression/time_series_regression.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,7 @@ def predict(
407407
alpha: Optional[Union[float, Iterable[float]]] = None,
408408
optimize_beta: bool = False,
409409
allow_infinite_bounds: bool = False,
410+
**predict_params
410411
) -> Union[NDArray, Tuple[NDArray, NDArray]]:
411412
"""
412413
Predict target on new samples with confidence intervals.
@@ -441,6 +442,9 @@ def predict(
441442
allow_infinite_bounds: bool
442443
Allow infinite prediction intervals to be produced.
443444
445+
predict_params : dict
446+
Additional predict parameters.
447+
444448
Returns
445449
-------
446450
Union[NDArray, Tuple[NDArray, NDArray]]
@@ -452,15 +456,16 @@ def predict(
452456
"""
453457
if alpha is None:
454458
super().predict(
455-
X, ensemble=ensemble, alpha=alpha, optimize_beta=optimize_beta
459+
X, ensemble=ensemble, alpha=alpha, optimize_beta=optimize_beta,
460+
**predict_params
456461
)
457462

458463
if self.method == "aci":
459464
alpha = self._get_alpha(alpha)
460465

461466
return super().predict(
462467
X, ensemble=ensemble, alpha=alpha, optimize_beta=optimize_beta,
463-
allow_infinite_bounds=allow_infinite_bounds
468+
allow_infinite_bounds=allow_infinite_bounds, **predict_params
464469
)
465470

466471
def _more_tags(self):

0 commit comments

Comments
 (0)