Skip to content

Commit 5ba97d6

Browse files
author
Thibault Cordier
committed
UPD: use copy model to prefit
1 parent 2acb98d commit 5ba97d6

File tree

1 file changed

+42
-33
lines changed

1 file changed

+42
-33
lines changed

mapie/mondrian.py

Lines changed: 42 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from copy import deepcopy
3+
from copy import copy
44
from typing import Iterable, Optional, Tuple, Union, cast
55

66
import numpy as np
@@ -148,21 +148,24 @@ def fit(
148148
self._check_group_length(X, groups)
149149
self.unique_groups = np.unique(groups)
150150
self.mapie_estimators = {}
151+
151152
if isinstance(self.mapie_estimator, MapieClassifier):
152153
self.n_classes = len(np.unique(y))
153154

154155
for group in self.unique_groups:
155-
mapie_group_estimator = deepcopy(self.mapie_estimator)
156+
mapie_group_estimator = copy(self.mapie_estimator)
156157
indices_groups = np.argwhere(groups == group)[:, 0]
157-
X_g, y_g = X[indices_groups], y[indices_groups]
158+
X_g = [X[index] for index in indices_groups]
159+
y_g = [y[index] for index in indices_groups]
158160
mapie_group_estimator.fit(X_g, y_g, **fit_params)
159161
self.mapie_estimators[group] = mapie_group_estimator
162+
160163
return self
161164

162165
def predict(
163166
self,
164167
X: ArrayLike,
165-
groups: ArrayLike,
168+
groups: Optional[ArrayLike] = None,
166169
alpha: Optional[Union[float, Iterable[float]]] = None,
167170
**predict_params
168171
) -> Union[NDArray, Tuple[NDArray, NDArray]]:
@@ -174,9 +177,11 @@ def predict(
174177
X : ArrayLike of shape (n_samples, n_features)
175178
The input data
176179
177-
groups : ArrayLike of shape (n_samples,)
180+
groups : ArrayLike of shape (n_samples,), optional
178181
The groups of individuals. Must be defined by integers.
179182
183+
By default None.
184+
180185
alpha : float or Iterable[float], optional
181186
The desired coverage level(s) for each group.
182187
@@ -194,37 +199,35 @@ def predict(
194199
y_pss : NDArray of shape (n_samples, n_outputs, n_alpha)
195200
The predicted sets for the desired levels of coverage
196201
"""
197-
198202
check_is_fitted(self, self.fit_attributes)
199203
X = cast(NDArray, X)
200-
groups = self._check_groups_predict(X, groups)
201204
alpha_np = cast(NDArray, check_alpha(alpha))
205+
202206
if alpha_np is None and self.mapie_estimator.estimator is not None:
203207
return self.mapie_estimator.estimator.predict(
204208
X, **predict_params
205209
)
210+
211+
if isinstance(self.mapie_estimator, MapieClassifier):
212+
y_pred = np.empty((len(X), ))
213+
y_pss = np.empty((len(X), self.n_classes, len(alpha_np)))
206214
else:
207-
if isinstance(self.mapie_estimator, MapieClassifier):
208-
y_pred = np.empty(
209-
(X.shape[0], )
210-
)
211-
y_pss = np.empty(
212-
(X.shape[0], self.n_classes, len(alpha_np))
213-
)
214-
else:
215-
y_pred = np.empty((X.shape[0],))
216-
y_pss = np.empty((X.shape[0], 2, len(alpha_np)))
217-
unique_groups = np.unique(groups)
218-
for i, group in enumerate(unique_groups):
219-
indices_groups = np.argwhere(groups == group)[:, 0]
220-
X_g = X[indices_groups]
221-
y_pred_g, y_pss_g = self.mapie_estimators[group].predict(
222-
X_g, alpha=alpha_np, **predict_params
223-
)
224-
y_pred[indices_groups] = y_pred_g
225-
y_pss[indices_groups] = y_pss_g
226-
227-
return y_pred, y_pss
215+
y_pred = np.empty((len(X),))
216+
y_pss = np.empty((len(X), 2, len(alpha_np)))
217+
218+
groups = self._check_groups_predict(X, groups)
219+
unique_groups = np.unique(groups)
220+
221+
for _, group in enumerate(unique_groups):
222+
indices_groups = np.argwhere(groups == group)[:, 0]
223+
X_g = [X[index] for index in indices_groups]
224+
y_pred_g, y_pss_g = self.mapie_estimators[group].predict(
225+
X_g, alpha=alpha_np, **predict_params
226+
)
227+
y_pred[indices_groups] = y_pred_g
228+
y_pss[indices_groups] = y_pss_g
229+
230+
return y_pred, y_pss
228231

229232
def _check_cv(self):
230233
"""
@@ -263,9 +266,11 @@ def _check_groups_fit(self, X: NDArray, groups: NDArray):
263266
"""
264267
if not np.issubdtype(groups.dtype, np.integer):
265268
raise ValueError("The groups must be defined by integers")
269+
266270
_, counts = np.unique(groups, return_counts=True)
267271
if np.min(counts) < 2:
268272
raise ValueError("There must be at least 2 individuals per group")
273+
269274
self._check_group_length(X, groups)
270275

271276
def _check_groups_predict(self, X: NDArray, groups: ArrayLike) -> NDArray:
@@ -295,9 +300,10 @@ def _check_groups_predict(self, X: NDArray, groups: ArrayLike) -> NDArray:
295300
groups = cast(NDArray, np.array(groups))
296301
if not np.all(np.isin(groups, self.unique_groups)):
297302
raise ValueError(
298-
"There is at least one new group in the prediction"
303+
"There is at least one new group in the prediction."
299304
)
300305
self._check_group_length(X, groups)
306+
301307
return groups
302308

303309
def _check_group_length(self, X: NDArray, groups: NDArray):
@@ -319,9 +325,11 @@ def _check_group_length(self, X: NDArray, groups: NDArray):
319325
If the number of individuals in the groups is not equal to the
320326
number of rows in X
321327
"""
322-
if len(groups) != X.shape[0]:
323-
raise ValueError("The number of individuals in the groups must " +
324-
"be equal to the number of rows in X")
328+
if len(groups) != len(X):
329+
raise ValueError(
330+
"The number of individuals in the groups must "
331+
"be equal to the number of rows in X"
332+
)
325333

326334
def _check_estimator(self):
327335
"""
@@ -405,15 +413,16 @@ def _check_fit_parameters(
405413
groups : NDArray of shape (n_samples,)
406414
The group values
407415
"""
408-
409416
self._check_estimator()
410417
self._check_cv()
411418
self._check_confomity_score()
419+
412420
X, y = indexable(X, y)
413421
y = _check_y(y)
414422
X = cast(NDArray, X)
415423
y = cast(NDArray, y)
416424
groups = cast(NDArray, np.array(groups))
425+
417426
self._check_groups_fit(X, groups)
418427

419428
return X, y, groups

0 commit comments

Comments
 (0)