Skip to content

Commit 044ae69

Browse files
Merge pull request #445 from BaptisteCalot/master
Building of Ensemble Classifier from MapieClassifier
2 parents a8f80a6 + 6bbb59c commit 044ae69

File tree

14 files changed

+854
-469
lines changed

14 files changed

+854
-469
lines changed

mapie/classification.py

Lines changed: 219 additions & 284 deletions
Large diffs are not rendered by default.

mapie/conformity_scores/checks.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from typing import Optional
2+
3+
from .conformity_scores import ConformityScore
4+
from .residual_conformity_scores import AbsoluteConformityScore
5+
6+
7+
def check_conformity_score(
8+
conformity_score: Optional[ConformityScore],
9+
sym: bool = True,
10+
) -> ConformityScore:
11+
"""
12+
Check parameter ``conformity_score``.
13+
14+
Raises
15+
------
16+
ValueError
17+
If parameter is not valid.
18+
19+
Examples
20+
--------
21+
>>> from mapie.conformity_scores.checks import check_conformity_score
22+
>>> try:
23+
... check_conformity_score(1)
24+
... except Exception as exception:
25+
... print(exception)
26+
...
27+
Invalid conformity_score argument.
28+
Must be None or a ConformityScore instance.
29+
"""
30+
if conformity_score is None:
31+
return AbsoluteConformityScore(sym=sym)
32+
elif isinstance(conformity_score, ConformityScore):
33+
return conformity_score
34+
else:
35+
raise ValueError(
36+
"Invalid conformity_score argument.\n"
37+
"Must be None or a ConformityScore instance."
38+
)

mapie/conformity_scores/conformity_scores.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from mapie._compatibility import np_nanquantile
77
from mapie._typing import ArrayLike, NDArray
8-
from mapie.estimator.interface import EnsembleEstimator
8+
from mapie.estimator.regressor import EnsembleRegressor
99

1010

1111
class ConformityScore(metaclass=ABCMeta):
@@ -338,7 +338,7 @@ def _beta_optimize(
338338
def get_bounds(
339339
self,
340340
X: ArrayLike,
341-
estimator: EnsembleEstimator,
341+
estimator: EnsembleRegressor,
342342
conformity_scores: NDArray,
343343
alpha_np: NDArray,
344344
ensemble: bool = False,
@@ -348,14 +348,14 @@ def get_bounds(
348348
) -> Tuple[NDArray, NDArray, NDArray]:
349349
"""
350350
Compute bounds of the prediction intervals from the observed values,
351-
the estimator of type ``EnsembleEstimator`` and the conformity scores.
351+
the estimator of type ``EnsembleRegressor`` and the conformity scores.
352352
353353
Parameters
354354
----------
355355
X: ArrayLike of shape (n_samples, n_features)
356356
Observed feature values.
357357
358-
estimator: EnsembleEstimator
358+
estimator: EnsembleRegressor
359359
Estimator that is fitted to predict y from X.
360360
361361
conformity_scores: ArrayLike of shape (n_samples,)

mapie/estimator/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from .interface import EnsembleEstimator
2+
from .regressor import EnsembleRegressor
3+
from .classifier import EnsembleClassifier
4+
5+
__all__ = [
6+
"EnsembleEstimator",
7+
"EnsembleRegressor",
8+
"EnsembleClassifier",
9+
]

0 commit comments

Comments
 (0)