Skip to content

Commit d6aa546

Browse files
committed
ENH: rename groups into partition
1 parent 5979dcb commit d6aa546

File tree

1 file changed

+53
-48
lines changed

1 file changed

+53
-48
lines changed

mapie/mondrian.py

Lines changed: 53 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
class MondrianCP(BaseEstimator):
3131
"""Mondrian is a method for making conformal predictions
32-
for disjoints groups of individuals.
32+
for a partition of individuals into disjoint groups.
3333
3434
The Mondrian method is implemented in the `MondrianCP` class. It takes as
3535
input a `MapieClassifier` or `MapieRegressor` estimator and fits a model
@@ -54,7 +54,7 @@ class MondrianCP(BaseEstimator):
5454
5555
Attributes
5656
----------
57-
unique_groups : NDArray
57+
partition_groups : NDArray
5858
The unique groups of individuals for which the estimator was fitted
5959
6060
mapie_estimators : Dict
@@ -74,11 +74,12 @@ class MondrianCP(BaseEstimator):
7474
>>> from mapie.classification import MapieClassifier
7575
>>> X_toy = np.arange(9).reshape(-1, 1)
7676
>>> y_toy = np.stack([0, 0, 1, 0, 1, 2, 1, 2, 2])
77-
>>> groups_toy = [0, 0, 0, 0, 1, 1, 1, 1, 1]
77+
>>> partition_toy = [0, 0, 0, 0, 1, 1, 1, 1, 1]
7878
>>> clf = LogisticRegression(random_state=42).fit(X_toy, y_toy)
7979
>>> mapie = MondrianCP(MapieClassifier(estimator=clf, cv="prefit")).fit(
80-
... X_toy, y_toy, groups_toy)
81-
>>> _, y_pi_mapie = mapie.predict(X_toy, alpha=0.4, groups=groups_toy)
80+
... X_toy, y_toy, partition_toy)
81+
>>> _, y_pi_mapie = mapie.predict(
82+
... X_toy, partition_toy, alpha=[0.1, 0.9])
8283
>>> print(y_pi_mapie[:, :, 0].astype(bool))
8384
[[ True False False]
8485
[ True False False]
@@ -108,7 +109,7 @@ class MondrianCP(BaseEstimator):
108109
AbsoluteConformityScore, GammaConformityScore
109110
)
110111
fit_attributes = [
111-
"unique_groups",
112+
"partition_groups",
112113
"mapie_estimators"
113114
]
114115

@@ -121,7 +122,7 @@ def __init__(
121122
def fit(
122123
self, X: ArrayLike,
123124
y: ArrayLike,
124-
groups: ArrayLike,
125+
partition: ArrayLike,
125126
**fit_params
126127
) -> MondrianCP:
127128
"""
@@ -135,7 +136,7 @@ def fit(
135136
y : ArrayLike of shape (n_samples,) or (n_samples, n_outputs)
136137
The target values
137138
138-
groups : ArrayLike of shape (n_samples,)
139+
partition : ArrayLike of shape (n_samples,)
139140
The groups of individuals. Must be defined by integers. There must
140141
be at least 2 individuals per group.
141142
@@ -144,16 +145,16 @@ def fit(
144145
that may be specific to the Mapie estimator used
145146
"""
146147

147-
X, y, groups = self._check_fit_parameters(X, y, groups)
148-
self.unique_groups = np.unique(groups)
148+
X, y, partition = self._check_fit_parameters(X, y, partition)
149+
self.partition_groups = np.unique(partition)
149150
self.mapie_estimators = {}
150151

151152
if isinstance(self.mapie_estimator, MapieClassifier):
152153
self.n_classes = len(np.unique(y))
153154

154-
for group in self.unique_groups:
155+
for group in self.partition_groups:
155156
mapie_group_estimator = copy(self.mapie_estimator)
156-
indices_groups = np.argwhere(groups == group)[:, 0]
157+
indices_groups = np.argwhere(partition == group)[:, 0]
157158
X_g = [X[index] for index in indices_groups]
158159
y_g = [y[index] for index in indices_groups]
159160
mapie_group_estimator.fit(X_g, y_g, **fit_params)
@@ -164,7 +165,7 @@ def fit(
164165
def predict(
165166
self,
166167
X: ArrayLike,
167-
groups: ArrayLike,
168+
partition: ArrayLike,
168169
alpha: Optional[Union[float, Iterable[float]]] = None,
169170
**predict_params
170171
) -> Union[NDArray, Tuple[NDArray, NDArray]]:
@@ -176,7 +177,7 @@ def predict(
176177
X : ArrayLike of shape (n_samples, n_features)
177178
The input data
178179
179-
groups : ArrayLike of shape (n_samples,), optional
180+
partition : ArrayLike of shape (n_samples,), optional
180181
The groups of individuals. Must be defined by integers.
181182
182183
By default None.
@@ -214,11 +215,11 @@ def predict(
214215
y_pred = np.empty((len(X),))
215216
y_pss = np.empty((len(X), 2, len(alpha_np)))
216217

217-
groups = self._check_groups_predict(X, groups)
218-
unique_groups = np.unique(groups)
218+
partition = self._check_partition_predict(X, partition)
219+
partition_groups = np.unique(partition)
219220

220-
for _, group in enumerate(unique_groups):
221-
indices_groups = np.argwhere(groups == group)[:, 0]
221+
for _, group in enumerate(partition_groups):
222+
indices_groups = np.argwhere(partition == group)[:, 0]
222223
X_g = [X[index] for index in indices_groups]
223224
y_pred_g, y_pss_g = self.mapie_estimators[group].predict(
224225
X_g, alpha=alpha_np, **predict_params
@@ -243,7 +244,7 @@ def _check_cv(self):
243244
"estimator uses cv='prefit'."
244245
)
245246

246-
def _check_groups_fit(self, X: NDArray, groups: NDArray):
247+
def _check_partition_fit(self, X: NDArray, partition: NDArray):
247248
"""
248249
Check that each group is defined by an integer and check that there
249250
are at least 2 individuals per group
@@ -253,59 +254,63 @@ def _check_groups_fit(self, X: NDArray, groups: NDArray):
253254
X : NDArray of shape (n_samples, n_features)
254255
The input data
255256
256-
groups : NDArray of shape (n_samples,)
257+
partition : NDArray of shape (n_samples,)
257258
258259
Raises
259260
------
260261
ValueError
261-
If the groups are not defined by integers
262+
If the partition is not defined by integers
262263
If there are fewer than 2 individuals per group
263-
If the number of individuals in the groups is not equal to the
264+
If the number of individuals in the partition is not equal to the
264265
number of rows in X
265266
"""
266-
if not np.issubdtype(groups.dtype, np.integer):
267-
raise ValueError("The groups must be defined by integers")
267+
if not np.issubdtype(partition.dtype, np.integer):
268+
raise ValueError("The partition must be defined by integers")
268269

269-
_, counts = np.unique(groups, return_counts=True)
270+
_, counts = np.unique(partition, return_counts=True)
270271
if np.min(counts) < 2:
271272
raise ValueError("There must be at least 2 individuals per group")
272273

273-
self._check_group_length(X, groups)
274+
self._check_partition_length(X, partition)
274275

275-
def _check_groups_predict(self, X: NDArray, groups: ArrayLike) -> NDArray:
276+
def _check_partition_predict(
277+
self,
278+
X: NDArray,
279+
partition: ArrayLike
280+
) -> NDArray:
276281
"""
277282
Check that there is no new group in the prediction and that
278-
the number of individuals in the groups is equal to the number of
283+
the number of individuals in the partition is equal to the number of
279284
rows in X
280285
281286
Parameters
282287
----------
283288
X : NDArray of shape (n_samples, n_features)
284289
The input data
285290
286-
groups : ArrayLike of shape (n_samples,)
291+
partition : ArrayLike of shape (n_samples,)
287292
The groups of individuals. Must be defined by integers
288293
289294
Returns
290295
-------
291-
groups : NDArray of shape (n_samples,)
292-
Groups of individuals
296+
partition : NDArray of shape (n_samples,)
297+
Partition of the dataset
293298
294299
Raises
295300
------
296301
ValueError
297302
If there is a new group in the prediction
298303
"""
299-
groups = cast(NDArray, np.array(groups))
300-
if not np.all(np.isin(groups, self.unique_groups)):
304+
partition = cast(NDArray, np.array(partition))
305+
if not np.all(np.isin(partition, self.partition_groups)):
301306
raise ValueError(
302307
"There is at least one new group in the prediction."
303308
)
304-
self._check_group_length(X, groups)
309+
self._check_partition_length(X, partition)
305310

306-
return groups
311+
return partition
307312

308-
def _check_group_length(self, X: NDArray, groups: NDArray):
313+
def _check_partition_length(self, X: NDArray, partition: NDArray):
309314
"""
310315
Check that the number of rows in the partition array is equal to
311316
the number of rows in the attributes array.
@@ -315,18 +320,18 @@ def _check_group_length(self, X: NDArray, groups: NDArray):
315320
X : NDArray of shape (n_samples, n_features)
316321
The individual data.
317322
318-
groups : NDArray of shape (n_samples,)
323+
partition : NDArray of shape (n_samples,)
319324
The groups of individuals. Must be defined by integers
320325
321326
Raises
322327
------
323328
ValueError
324-
If the number of individuals in the groups is not equal to the
329+
If the number of individuals in the partition is not equal to the
325330
number of rows in X
326331
"""
327-
if len(groups) != len(X):
332+
if len(partition) != len(X):
328333
raise ValueError(
329-
"The number of individuals in the groups must "
334+
"The number of individuals in the partition must "
330335
"be equal to the number of rows in X"
331336
)
332337

@@ -385,10 +390,10 @@ def _check_confomity_score(self):
385390
)
386391

387392
def _check_fit_parameters(
388-
self, X: ArrayLike, y: ArrayLike, groups: ArrayLike
393+
self, X: ArrayLike, y: ArrayLike, partition: ArrayLike
389394
) -> Tuple[NDArray, NDArray, NDArray]:
390395
"""
391-
Perform checks on the input data, groups and the estimator
396+
Perform checks on the input data, partition and the estimator
392397
393398
Parameters
394399
----------
@@ -398,7 +403,7 @@ def _check_fit_parameters(
398403
y : ArrayLike of shape (n_samples,) or (n_samples, n_outputs)
399404
The target values
400405
401-
groups : ArrayLike of shape (n_samples,)
406+
partition : ArrayLike of shape (n_samples,)
402407
The groups of individuals. Must be defined by integers
403408
404409
Returns
@@ -409,7 +414,7 @@ def _check_fit_parameters(
409414
y : NDArray of shape (n_samples,) or (n_samples, n_outputs)
410415
The target values
411416
412-
groups : NDArray of shape (n_samples,)
417+
partition : NDArray of shape (n_samples,)
413418
The group values
414419
"""
415420
self._check_estimator()
@@ -420,9 +425,9 @@ def _check_fit_parameters(
420425
y = _check_y(y)
421426
X = cast(NDArray, X)
422427
y = cast(NDArray, y)
423-
groups = cast(NDArray, np.array(groups))
428+
partition = cast(NDArray, np.array(partition))
424429

425-
self._check_groups_fit(X, groups)
426-
self._check_group_length(X, groups)
430+
self._check_partition_fit(X, partition)
431+
self._check_partition_length(X, partition)
427432

428-
return X, y, groups
433+
return X, y, partition

0 commit comments

Comments
 (0)