Skip to content

Commit 56c1a25

Browse files
UPD: docstring + improved split management + extension of tests to larger dataset (raps limitation)
1 parent 1d978ce commit 56c1a25

File tree

6 files changed

+334
-247
lines changed

6 files changed

+334
-247
lines changed

mapie/classification.py

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
import numpy as np
77
from sklearn.base import BaseEstimator, ClassifierMixin
8-
from sklearn.model_selection import BaseCrossValidator, ShuffleSplit
8+
from sklearn.model_selection import (BaseCrossValidator, BaseShuffleSplit,
9+
StratifiedShuffleSplit)
910
from sklearn.preprocessing import LabelEncoder, label_binarize
1011
from sklearn.utils import _safe_indexing, check_random_state
1112
from sklearn.utils.multiclass import (check_classification_targets,
@@ -301,9 +302,9 @@ def _check_raps(self):
301302
ValueError
302303
If ``method`` is ``"raps"`` and ``cv`` is not ``"prefit"``.
303304
"""
304-
if (self.method == "raps") and (
305-
(self.cv not in self.raps_valid_cv_)
306-
or isinstance(self.cv, ShuffleSplit)
305+
if (self.method == "raps") and not (
306+
(self.cv in self.raps_valid_cv_)
307+
or isinstance(self.cv, BaseShuffleSplit)
307308
):
308309
raise ValueError(
309310
"RAPS method can only be used "
@@ -926,7 +927,7 @@ def _check_fit_parameter(
926927
y: ArrayLike
927928
Target values.
928929
929-
sample_weight: Optional[NDArray] of shape (n_samples,)
930+
sample_weight: Optional[ArrayLike] of shape (n_samples,)
930931
Non-null sample weights.
931932
932933
groups: Optional[ArrayLike] of shape (n_samples,)
@@ -940,8 +941,8 @@ def _check_fit_parameter(
940941
Optional[Union[int, str, BaseCrossValidator]],
941942
ArrayLike, NDArray, NDArray, Optional[NDArray],
942943
Optional[NDArray], ArrayLike]
943-
944944
Parameters checked
945+
945946
Raises
946947
------
947948
ValueError
@@ -952,7 +953,6 @@ def _check_fit_parameter(
952953
If ``cv`` is ``"prefit"`` or ``"split"`` and ``method`` is not
953954
``"base"``.
954955
"""
955-
956956
self._check_parameters()
957957
cv = check_cv(
958958
self.cv, test_size=self.test_size, random_state=self.random_state
@@ -979,15 +979,15 @@ def _check_fit_parameter(
979979
self.label_encoder_ = enc
980980
self._check_target(y)
981981

982-
return (estimator, cv, X, y, y_enc, sample_weight, groups, n_samples)
982+
return estimator, cv, X, y, y_enc, sample_weight, groups, n_samples
983983

984984
def _split_data(
985985
self,
986-
X,
987-
y_enc,
988-
sample_weight,
989-
groups,
990-
size_raps
986+
X: ArrayLike,
987+
y_enc: ArrayLike,
988+
sample_weight: Optional[ArrayLike] = None,
989+
groups: Optional[ArrayLike] = None,
990+
size_raps: Optional[float] = None,
991991
):
992992
"""Split data for raps method
993993
@@ -999,7 +999,7 @@ def _split_data(
999999
y_enc: ArrayLike
10001000
Target values as normalized encodings.
10011001
1002-
sample_weight: Optional[NDArray] of shape (n_samples,)
1002+
sample_weight: Optional[ArrayLike] of shape (n_samples,)
10031003
Non-null sample weights.
10041004
10051005
groups: Optional[ArrayLike] of shape (n_samples,)
@@ -1015,34 +1015,38 @@ def _split_data(
10151015
-------
10161016
Tuple[NDArray, NDArray, NDArray, NDArray, Optional[NDArray],
10171017
Optional[NDArray]]
1018-
10191018
- NDArray of shape (n_samples, n_features)
10201019
- NDArray of shape (n_samples,)
10211020
- NDArray of shape (n_samples,)
10221021
- NDArray of shape (n_samples,)
10231022
- NDArray of shape (n_samples,)
10241023
- NDArray of shape (n_samples,)
10251024
"""
1026-
raps_split = ShuffleSplit(
1027-
1, test_size=size_raps, random_state=self.random_state
1025+
# Split data for raps method
1026+
raps_split = StratifiedShuffleSplit(
1027+
n_splits=1, test_size=size_raps, random_state=self.random_state
10281028
)
1029-
train_raps_index, val_raps_index = next(raps_split.split(X))
1029+
train_raps_index, val_raps_index = next(raps_split.split(X, y_enc))
10301030
X, self.X_raps, y_enc, self.y_raps = (
10311031
_safe_indexing(X, train_raps_index),
10321032
_safe_indexing(X, val_raps_index),
10331033
_safe_indexing(y_enc, train_raps_index),
10341034
_safe_indexing(y_enc, val_raps_index),
10351035
)
1036+
1037+
# Decode y_raps for use in the RAPS method
10361038
self.y_raps_no_enc = self.label_encoder_.inverse_transform(self.y_raps)
10371039
y = self.label_encoder_.inverse_transform(y_enc)
1040+
1041+
# Cast to NDArray for type checking
10381042
y_enc = cast(NDArray, y_enc)
10391043
n_samples = _num_samples(y_enc)
10401044
if sample_weight is not None:
1041-
sample_weight = sample_weight[train_raps_index]
10421045
sample_weight = cast(NDArray, sample_weight)
1046+
sample_weight = sample_weight[train_raps_index]
10431047
if groups is not None:
1044-
groups = groups[train_raps_index]
10451048
groups = cast(NDArray, groups)
1049+
groups = groups[train_raps_index]
10461050

10471051
return X, y_enc, y, n_samples, sample_weight, groups
10481052

@@ -1126,12 +1130,13 @@ def fit(
11261130
self.test_size,
11271131
self.verbose,
11281132
)
1129-
1133+
# Fit the prediction function
11301134
self.estimator_ = self.estimator_.fit(
11311135
X, y, y_enc=y_enc, sample_weight=sample_weight, groups=groups,
11321136
**fit_params
11331137
)
11341138

1139+
# Predict on calibration data
11351140
y_pred_proba, y, y_enc = self.estimator_.predict_proba_calib(
11361141
X, y, y_enc, groups
11371142
)
@@ -1176,10 +1181,6 @@ def fit(
11761181
"Invalid method. " f"Allowed values are {self.valid_methods_}."
11771182
)
11781183

1179-
# In split-CP, we keep only the model fitted on train dataset
1180-
if isinstance(cv, ShuffleSplit):
1181-
self.estimator_.single_estimator_ = self.estimator_.estimators_[0]
1182-
11831184
return self
11841185

11851186
def predict(
@@ -1278,9 +1279,12 @@ def predict(
12781279
alpha_np = cast(NDArray, alpha)
12791280
check_alpha_and_n_samples(alpha_np, n)
12801281

1281-
y_pred_proba = self.estimator_.predict(X, alpha_np, agg_scores)
1282-
# Check that sum of probas is equal to 1
1282+
y_pred_proba = self.estimator_.predict(X, agg_scores)
12831283
y_pred_proba = self._check_proba_normalized(y_pred_proba, axis=1)
1284+
if agg_scores != "crossval":
1285+
y_pred_proba = np.repeat(
1286+
y_pred_proba[:, :, np.newaxis], len(alpha_np), axis=2
1287+
)
12841288

12851289
# Choice of the quantile
12861290
if self.method == "naive":

0 commit comments

Comments
 (0)