2929
3030class MondrianCP (BaseEstimator ):
3131 """Mondrian is a method for making conformal predictions
32- for disjoints groups of individuals.
32+ for partition of individuals.
3333
3434 The Mondrian method is implemented in the `MondrianCP` class. It takes as
3535 input a `MapieClassifier` or `MapieRegressor` estimator and fits a model
@@ -54,7 +54,7 @@ class MondrianCP(BaseEstimator):
5454
5555 Attributes
5656 ----------
57- unique_groups : NDArray
57+ partition_groups : NDArray
5858 The unique groups of individuals for which the estimator was fitted
5959
6060 mapie_estimators : Dict
@@ -74,11 +74,12 @@ class MondrianCP(BaseEstimator):
7474 >>> from mapie.classification import MapieClassifier
7575 >>> X_toy = np.arange(9).reshape(-1, 1)
7676 >>> y_toy = np.stack([0, 0, 1, 0, 1, 2, 1, 2, 2])
77- >>> groups_toy = [0, 0, 0, 0, 1, 1, 1, 1, 1]
77+ >>> partition_toy = [0, 0, 0, 0, 1, 1, 1, 1, 1]
7878 >>> clf = LogisticRegression(random_state=42).fit(X_toy, y_toy)
7979 >>> mapie = MondrianCP(MapieClassifier(estimator=clf, cv="prefit")).fit(
80- ... X_toy, y_toy, groups_toy)
81- >>> _, y_pi_mapie = mapie.predict(X_toy, alpha=0.4, groups=groups_toy)
80+ ... X_toy, y_toy, partition_toy)
81+ >>> _, y_pi_mapie = mapie.predict(
82+ ... X_toy, partition_toy, alpha=[0.1, 0.9])
8283 >>> print(y_pi_mapie[:, :, 0].astype(bool))
8384 [[ True False False]
8485 [ True False False]
@@ -108,7 +109,7 @@ class MondrianCP(BaseEstimator):
108109 AbsoluteConformityScore , GammaConformityScore
109110 )
110111 fit_attributes = [
111- "unique_groups " ,
112+ "partition_groups " ,
112113 "mapie_estimators"
113114 ]
114115
@@ -121,7 +122,7 @@ def __init__(
121122 def fit (
122123 self , X : ArrayLike ,
123124 y : ArrayLike ,
124- groups : ArrayLike ,
125+ partition : ArrayLike ,
125126 ** fit_params
126127 ) -> MondrianCP :
127128 """
@@ -135,7 +136,7 @@ def fit(
135136 y : ArrayLike of shape (n_samples,) or (n_samples, n_outputs)
136137 The target values
137138
138- groups : ArrayLike of shape (n_samples,)
139+ partition : ArrayLike of shape (n_samples,)
139140 The groups of individuals. Must be defined by integers. There must
140141 be at least 2 individuals per group.
141142
@@ -144,16 +145,16 @@ def fit(
144145 that may be specific to the Mapie estimator used
145146 """
146147
147- X , y , groups = self ._check_fit_parameters (X , y , groups )
148- self .unique_groups = np .unique (groups )
148+ X , y , partition = self ._check_fit_parameters (X , y , partition )
149+ self .partition_groups = np .unique (partition )
149150 self .mapie_estimators = {}
150151
151152 if isinstance (self .mapie_estimator , MapieClassifier ):
152153 self .n_classes = len (np .unique (y ))
153154
154- for group in self .unique_groups :
155+ for group in self .partition_groups :
155156 mapie_group_estimator = copy (self .mapie_estimator )
156- indices_groups = np .argwhere (groups == group )[:, 0 ]
157+ indices_groups = np .argwhere (partition == group )[:, 0 ]
157158 X_g = [X [index ] for index in indices_groups ]
158159 y_g = [y [index ] for index in indices_groups ]
159160 mapie_group_estimator .fit (X_g , y_g , ** fit_params )
@@ -164,7 +165,7 @@ def fit(
164165 def predict (
165166 self ,
166167 X : ArrayLike ,
167- groups : ArrayLike ,
168+ partition : ArrayLike ,
168169 alpha : Optional [Union [float , Iterable [float ]]] = None ,
169170 ** predict_params
170171 ) -> Union [NDArray , Tuple [NDArray , NDArray ]]:
@@ -176,7 +177,7 @@ def predict(
176177 X : ArrayLike of shape (n_samples, n_features)
177178 The input data
178179
179- groups : ArrayLike of shape (n_samples,), optional
180+ partition : ArrayLike of shape (n_samples,), optional
180181 The groups of individuals. Must be defined by integers.
181182
182183 By default None.
@@ -214,11 +215,11 @@ def predict(
214215 y_pred = np .empty ((len (X ),))
215216 y_pss = np .empty ((len (X ), 2 , len (alpha_np )))
216217
217- groups = self ._check_groups_predict (X , groups )
218- unique_groups = np .unique (groups )
218+ partition = self ._check_partition_predict (X , partition )
219+ partition_groups = np .unique (partition )
219220
220- for _ , group in enumerate (unique_groups ):
221- indices_groups = np .argwhere (groups == group )[:, 0 ]
221+ for _ , group in enumerate (partition_groups ):
222+ indices_groups = np .argwhere (partition == group )[:, 0 ]
222223 X_g = [X [index ] for index in indices_groups ]
223224 y_pred_g , y_pss_g = self .mapie_estimators [group ].predict (
224225 X_g , alpha = alpha_np , ** predict_params
@@ -243,7 +244,7 @@ def _check_cv(self):
243244 "estimator uses cv='prefit'."
244245 )
245246
246- def _check_groups_fit (self , X : NDArray , groups : NDArray ):
247+ def _check_partition_fit (self , X : NDArray , partition : NDArray ):
247248 """
248249 Check that each group is defined by an integer and check that there
249250 are at least 2 individuals per group
@@ -253,59 +254,63 @@ def _check_groups_fit(self, X: NDArray, groups: NDArray):
253254 X : NDArray of shape (n_samples, n_features)
254255 The input data
255256
256- groups : NDArray of shape (n_samples,)
257+ partition : NDArray of shape (n_samples,)
257258
258259 Raises
259260 ------
260261 ValueError
261- If the groups are not defined by integers
262+ If the partition is not defined by integers
262263 If there is less than 2 individuals per group
263- If the number of individuals in the groups is not equal to the
264+ If the number of individuals in the partition is not equal to the
264265 number of rows in X
265266 """
266- if not np .issubdtype (groups .dtype , np .integer ):
267- raise ValueError ("The groups must be defined by integers" )
267+ if not np .issubdtype (partition .dtype , np .integer ):
268+ raise ValueError ("The partition must be defined by integers" )
268269
269- _ , counts = np .unique (groups , return_counts = True )
270+ _ , counts = np .unique (partition , return_counts = True )
270271 if np .min (counts ) < 2 :
271272 raise ValueError ("There must be at least 2 individuals per group" )
272273
273- self ._check_group_length (X , groups )
274+ self ._check_partition_length (X , partition )
274275
275- def _check_groups_predict (self , X : NDArray , groups : ArrayLike ) -> NDArray :
276+ def _check_partition_predict (
277+ self ,
278+ X : NDArray ,
279+ partition : ArrayLike
280+ ) -> NDArray :
276281 """
277282 Check that there is no new group in the prediction and that
278- the number of individuals in the groups is equal to the number of
283+ the number of individuals in the partition is equal to the number of
279284 rows in X
280285
281286 Parameters
282287 ----------
283288 X : NDArray of shape (n_samples, n_features)
284289 The input data
285290
286- groups : ArrayLike of shape (n_samples,)
291+ partition : ArrayLike of shape (n_samples,)
287292 The groups of individuals. Must be defined by integers
288293
289294 Returns
290295 -------
291- groups : NDArray of shape (n_samples,)
292- Groups of individuals
296+ partition : NDArray of shape (n_samples,)
297+ Partition of the dataset
293298
294299 Raises
295300 ------
296301 ValueError
297302 If there is a new group in the prediction
298303 """
299- groups = cast (NDArray , np .array (groups ))
300- if not np .all (np .isin (groups , self .unique_groups )):
304+ partition = cast (NDArray , np .array (partition ))
305+ if not np .all (np .isin (partition , self .partition_groups )):
301306 raise ValueError (
302307 "There is at least one new group in the prediction."
303308 )
304- self ._check_group_length (X , groups )
309+ self ._check_partition_length (X , partition )
305310
306- return groups
311+ return partition
307312
308- def _check_group_length (self , X : NDArray , groups : NDArray ):
313+ def _check_partition_length (self , X : NDArray , partition : NDArray ):
309314 """
310315 Check that the number of rows in the groups array is equal to
311316 the number of rows in the attributes array.
@@ -315,18 +320,18 @@ def _check_group_length(self, X: NDArray, groups: NDArray):
315320 X : NDArray of shape (n_samples, n_features)
316321 The individual data.
317322
318- groups : NDArray of shape (n_samples,)
323+ partition : NDArray of shape (n_samples,)
319324 The groups of individuals. Must be defined by integers
320325
321326 Raises
322327 ------
323328 ValueError
324- If the number of individuals in the groups is not equal to the
329+ If the number of individuals in the partition is not equal to the
325330 number of rows in X
326331 """
327- if len (groups ) != len (X ):
332+ if len (partition ) != len (X ):
328333 raise ValueError (
329- "The number of individuals in the groups must "
334+ "The number of individuals in the partition must "
330335 "be equal to the number of rows in X"
331336 )
332337
@@ -385,10 +390,10 @@ def _check_confomity_score(self):
385390 )
386391
387392 def _check_fit_parameters (
388- self , X : ArrayLike , y : ArrayLike , groups : ArrayLike
393+ self , X : ArrayLike , y : ArrayLike , partition : ArrayLike
389394 ) -> Tuple [NDArray , NDArray , NDArray ]:
390395 """
391- Perform checks on the input data, groups and the estimator
396+ Perform checks on the input data, partition and the estimator
392397
393398 Parameters
394399 ----------
@@ -398,7 +403,7 @@ def _check_fit_parameters(
398403 y : ArrayLike of shape (n_samples,) or (n_samples, n_outputs)
399404 The target values
400405
401- groups : ArrayLike of shape (n_samples,)
406+ partition : ArrayLike of shape (n_samples,)
402407 The groups of individuals. Must be defined by integers
403408
404409 Returns
@@ -409,7 +414,7 @@ def _check_fit_parameters(
409414 y : NDArray of shape (n_samples,) or (n_samples, n_outputs)
410415 The target values
411416
412- groups : NDArray of shape (n_samples,)
417+ partition : NDArray of shape (n_samples,)
413418 The group values
414419 """
415420 self ._check_estimator ()
@@ -420,9 +425,9 @@ def _check_fit_parameters(
420425 y = _check_y (y )
421426 X = cast (NDArray , X )
422427 y = cast (NDArray , y )
423- groups = cast (NDArray , np .array (groups ))
428+ partition = cast (NDArray , np .array (partition ))
424429
425- self ._check_groups_fit (X , groups )
426- self ._check_group_length (X , groups )
430+ self ._check_partition_fit (X , partition )
431+ self ._check_partition_length (X , partition )
427432
428- return X , y , groups
433+ return X , y , partition
0 commit comments