Skip to content

Commit d6ed656

Browse files
FIX: move check function to avoid circular import
1 parent 09628f3 commit d6ed656

File tree

7 files changed

+49
-49
lines changed

7 files changed

+49
-49
lines changed

mapie/classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from mapie._machine_precision import EPSILON
1717
from mapie._typing import ArrayLike, NDArray
18-
from mapie.estimator import EnsembleClassifier
18+
from mapie.estimator.classifier import EnsembleClassifier
1919
from mapie.metrics import classification_mean_width_score
2020
from mapie.utils import (check_alpha, check_alpha_and_n_samples, check_cv,
2121
check_estimator_classification, check_n_features_in,

mapie/conformity_scores/check.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from typing import Optional
2+
3+
from .conformity_scores import ConformityScore
4+
from .residual_conformity_scores import AbsoluteConformityScore
5+
6+
7+
def check_conformity_score(
8+
conformity_score: Optional[ConformityScore],
9+
sym: bool = True,
10+
) -> ConformityScore:
11+
"""
12+
Check parameter ``conformity_score``.
13+
14+
Raises
15+
------
16+
ValueError
17+
If parameter is not valid.
18+
19+
Examples
20+
--------
21+
>>> from mapie.utils import check_conformity_score
22+
>>> try:
23+
... check_conformity_score(1)
24+
... except Exception as exception:
25+
... print(exception)
26+
...
27+
Invalid conformity_score argument.
28+
Must be None or a ConformityScore instance.
29+
"""
30+
if conformity_score is None:
31+
return AbsoluteConformityScore(sym=sym)
32+
elif isinstance(conformity_score, ConformityScore):
33+
return conformity_score
34+
else:
35+
raise ValueError(
36+
"Invalid conformity_score argument.\n"
37+
"Must be None or a ConformityScore instance."
38+
)

mapie/conformity_scores/conformity_scores.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from mapie._compatibility import np_nanquantile
77
from mapie._typing import ArrayLike, NDArray
8-
from mapie.estimator import EnsembleRegressor
8+
from mapie.estimator.regressor import EnsembleRegressor
99

1010

1111
class ConformityScore(metaclass=ABCMeta):
@@ -326,14 +326,14 @@ def get_bounds(
326326
) -> Tuple[NDArray, NDArray, NDArray]:
327327
"""
328328
Compute bounds of the prediction intervals from the observed values,
329-
the estimator of type ``EnsembleEstimator`` and the conformity scores.
329+
the estimator of type ``EnsembleRegressor`` and the conformity scores.
330330
331331
Parameters
332332
----------
333333
X: ArrayLike of shape (n_samples, n_features)
334334
Observed feature values.
335335
336-
estimator: EnsembleEstimator
336+
estimator: EnsembleRegressor
337337
Estimator that is fitted to predict y from X.
338338
339339
conformity_scores: ArrayLike of shape (n_samples,)

mapie/estimator/classifier.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,7 @@
1111

1212
from mapie._typing import ArrayLike, NDArray
1313
from mapie.estimator.interface import EnsembleEstimator
14-
from mapie.utils import (
15-
check_no_agg_cv,
16-
fit_estimator,
17-
fix_number_of_classes,
18-
)
14+
from mapie.utils import check_no_agg_cv, fit_estimator, fix_number_of_classes
1915

2016

2117
class EnsembleClassifier(EnsembleEstimator):

mapie/regression/regression.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@
1313

1414
from mapie._typing import ArrayLike, NDArray
1515
from mapie.conformity_scores import ConformityScore, ResidualNormalisedScore
16-
from mapie.estimator import EnsembleRegressor
16+
from mapie.estimator.regressor import EnsembleRegressor
1717
from mapie.utils import (check_alpha, check_alpha_and_n_samples,
18-
check_conformity_score, check_cv,
19-
check_estimator_fit_predict, check_n_features_in,
20-
check_n_jobs, check_null_weight, check_verbose)
18+
check_cv, check_estimator_fit_predict,
19+
check_n_features_in, check_n_jobs, check_null_weight,
20+
check_verbose)
21+
from mapie.conformity_scores.check import check_conformity_score
2122

2223

2324
class MapieRegressor(BaseEstimator, RegressorMixin):

mapie/tests/test_regression.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from mapie.conformity_scores import (AbsoluteConformityScore, ConformityScore,
2626
GammaConformityScore,
2727
ResidualNormalisedScore)
28-
from mapie.estimator import EnsembleRegressor
28+
from mapie.estimator.regressor import EnsembleRegressor
2929
from mapie.metrics import regression_coverage_score
3030
from mapie.regression import MapieRegressor
3131
from mapie.subsample import Subsample

mapie/utils.py

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
from ._compatibility import np_quantile
1818
from ._typing import ArrayLike, NDArray
19-
from .conformity_scores import AbsoluteConformityScore, ConformityScore
2019

2120
SPLIT_STRATEGIES = ["uniform", "quantile", "array split"]
2221

@@ -600,40 +599,6 @@ def check_lower_upper_bounds(
600599
)
601600

602601

603-
def check_conformity_score(
604-
conformity_score: Optional[ConformityScore],
605-
sym: bool = True,
606-
) -> ConformityScore:
607-
"""
608-
Check parameter ``conformity_score``.
609-
610-
Raises
611-
------
612-
ValueError
613-
If parameter is not valid.
614-
615-
Examples
616-
--------
617-
>>> from mapie.utils import check_conformity_score
618-
>>> try:
619-
... check_conformity_score(1)
620-
... except Exception as exception:
621-
... print(exception)
622-
...
623-
Invalid conformity_score argument.
624-
Must be None or a ConformityScore instance.
625-
"""
626-
if conformity_score is None:
627-
return AbsoluteConformityScore(sym=sym)
628-
elif isinstance(conformity_score, ConformityScore):
629-
return conformity_score
630-
else:
631-
raise ValueError(
632-
"Invalid conformity_score argument.\n"
633-
"Must be None or a ConformityScore instance."
634-
)
635-
636-
637602
def check_defined_variables_predict_cqr(
638603
ensemble: bool,
639604
alpha: Union[float, Iterable[float], None],

0 commit comments

Comments
 (0)