11from __future__ import annotations
22
3- from copy import deepcopy
3+ from copy import copy
44from typing import Iterable , Optional , Tuple , Union , cast
55
66import numpy as np
@@ -148,21 +148,24 @@ def fit(
148148 self ._check_group_length (X , groups )
149149 self .unique_groups = np .unique (groups )
150150 self .mapie_estimators = {}
151+
151152 if isinstance (self .mapie_estimator , MapieClassifier ):
152153 self .n_classes = len (np .unique (y ))
153154
154155 for group in self .unique_groups :
155- mapie_group_estimator = deepcopy (self .mapie_estimator )
156+ mapie_group_estimator = copy (self .mapie_estimator )
156157 indices_groups = np .argwhere (groups == group )[:, 0 ]
157- X_g , y_g = X [indices_groups ], y [indices_groups ]
158+ X_g = [X [index ] for index in indices_groups ]
159+ y_g = [y [index ] for index in indices_groups ]
158160 mapie_group_estimator .fit (X_g , y_g , ** fit_params )
159161 self .mapie_estimators [group ] = mapie_group_estimator
162+
160163 return self
161164
162165 def predict (
163166 self ,
164167 X : ArrayLike ,
165- groups : ArrayLike ,
168+ groups : Optional [ ArrayLike ] = None ,
166169 alpha : Optional [Union [float , Iterable [float ]]] = None ,
167170 ** predict_params
168171 ) -> Union [NDArray , Tuple [NDArray , NDArray ]]:
@@ -174,9 +177,11 @@ def predict(
174177 X : ArrayLike of shape (n_samples, n_features)
175178 The input data
176179
177- groups : ArrayLike of shape (n_samples,)
180+ groups : ArrayLike of shape (n_samples,), optional
178181 The groups of individuals. Must be defined by integers.
179182
183+ By default None.
184+
180185 alpha : float or Iterable[float], optional
181186 The desired coverage level(s) for each group.
182187
@@ -194,37 +199,35 @@ def predict(
194199 y_pss : NDArray of shape (n_samples, n_outputs, n_alpha)
195200 The predicted sets for the desired levels of coverage
196201 """
197-
198202 check_is_fitted (self , self .fit_attributes )
199203 X = cast (NDArray , X )
200- groups = self ._check_groups_predict (X , groups )
201204 alpha_np = cast (NDArray , check_alpha (alpha ))
205+
202206 if alpha_np is None and self .mapie_estimator .estimator is not None :
203207 return self .mapie_estimator .estimator .predict (
204208 X , ** predict_params
205209 )
210+
211+ if isinstance (self .mapie_estimator , MapieClassifier ):
212+ y_pred = np .empty ((len (X ), ))
213+ y_pss = np .empty ((len (X ), self .n_classes , len (alpha_np )))
206214 else :
207- if isinstance (self .mapie_estimator , MapieClassifier ):
208- y_pred = np .empty (
209- (X .shape [0 ], )
210- )
211- y_pss = np .empty (
212- (X .shape [0 ], self .n_classes , len (alpha_np ))
213- )
214- else :
215- y_pred = np .empty ((X .shape [0 ],))
216- y_pss = np .empty ((X .shape [0 ], 2 , len (alpha_np )))
217- unique_groups = np .unique (groups )
218- for i , group in enumerate (unique_groups ):
219- indices_groups = np .argwhere (groups == group )[:, 0 ]
220- X_g = X [indices_groups ]
221- y_pred_g , y_pss_g = self .mapie_estimators [group ].predict (
222- X_g , alpha = alpha_np , ** predict_params
223- )
224- y_pred [indices_groups ] = y_pred_g
225- y_pss [indices_groups ] = y_pss_g
226-
227- return y_pred , y_pss
215+ y_pred = np .empty ((len (X ),))
216+ y_pss = np .empty ((len (X ), 2 , len (alpha_np )))
217+
218+ groups = self ._check_groups_predict (X , groups )
219+ unique_groups = np .unique (groups )
220+
221+ for _ , group in enumerate (unique_groups ):
222+ indices_groups = np .argwhere (groups == group )[:, 0 ]
223+ X_g = [X [index ] for index in indices_groups ]
224+ y_pred_g , y_pss_g = self .mapie_estimators [group ].predict (
225+ X_g , alpha = alpha_np , ** predict_params
226+ )
227+ y_pred [indices_groups ] = y_pred_g
228+ y_pss [indices_groups ] = y_pss_g
229+
230+ return y_pred , y_pss
228231
229232 def _check_cv (self ):
230233 """
@@ -263,9 +266,11 @@ def _check_groups_fit(self, X: NDArray, groups: NDArray):
263266 """
264267 if not np .issubdtype (groups .dtype , np .integer ):
265268 raise ValueError ("The groups must be defined by integers" )
269+
266270 _ , counts = np .unique (groups , return_counts = True )
267271 if np .min (counts ) < 2 :
268272 raise ValueError ("There must be at least 2 individuals per group" )
273+
269274 self ._check_group_length (X , groups )
270275
271276 def _check_groups_predict (self , X : NDArray , groups : ArrayLike ) -> NDArray :
@@ -295,9 +300,10 @@ def _check_groups_predict(self, X: NDArray, groups: ArrayLike) -> NDArray:
295300 groups = cast (NDArray , np .array (groups ))
296301 if not np .all (np .isin (groups , self .unique_groups )):
297302 raise ValueError (
298- "There is at least one new group in the prediction"
303+ "There is at least one new group in the prediction. "
299304 )
300305 self ._check_group_length (X , groups )
306+
301307 return groups
302308
303309 def _check_group_length (self , X : NDArray , groups : NDArray ):
@@ -319,9 +325,11 @@ def _check_group_length(self, X: NDArray, groups: NDArray):
319325 If the number of individuals in the groups is not equal to the
320326 number of rows in X
321327 """
322- if len (groups ) != X .shape [0 ]:
323- raise ValueError ("The number of individuals in the groups must " +
324- "be equal to the number of rows in X" )
328+ if len (groups ) != len (X ):
329+ raise ValueError (
330+ "The number of individuals in the groups must "
331+ "be equal to the number of rows in X"
332+ )
325333
326334 def _check_estimator (self ):
327335 """
@@ -405,15 +413,16 @@ def _check_fit_parameters(
405413 groups : NDArray of shape (n_samples,)
406414 The group values
407415 """
408-
409416 self ._check_estimator ()
410417 self ._check_cv ()
411418 self ._check_confomity_score ()
419+
412420 X , y = indexable (X , y )
413421 y = _check_y (y )
414422 X = cast (NDArray , X )
415423 y = cast (NDArray , y )
416424 groups = cast (NDArray , np .array (groups ))
425+
417426 self ._check_groups_fit (X , groups )
418427
419428 return X , y , groups
0 commit comments