Skip to content

Commit 2b6c211

Browse files
MTN: move mapie_v1/ conformity scores related code in mapie/, remove mapie_v1/_version.py (#676)
1 parent 6fc8ebf commit 2b6c211

File tree

8 files changed

+83
-52
lines changed

8 files changed

+83
-52
lines changed

mapie/classification.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,13 @@
2222
from mapie.conformity_scores.sets.raps import RAPSConformityScore
2323
from mapie.conformity_scores.utils import (
2424
check_depreciated_size_raps, check_classification_conformity_score,
25-
check_target
25+
check_target, check_and_select_conformity_score,
2626
)
2727
from mapie.estimator.classifier import EnsembleClassifier
2828
from mapie.utils import (_check_alpha, _check_alpha_and_n_samples, _check_cv,
2929
_check_estimator_classification, _check_n_features_in,
3030
_check_n_jobs, _check_null_weight, _check_predict_params,
3131
_check_verbose)
32-
from mapie_v1.conformity_scores._utils import check_and_select_conformity_score
3332
from mapie.utils import (
3433
_transform_confidence_level_to_alpha_list,
3534
_raise_error_if_fit_called_in_prefit_mode,

mapie/conformity_scores/utils.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,53 @@
1-
from typing import Optional
1+
from typing import Optional, no_type_check
22
import warnings
33

4-
from sklearn.utils.multiclass import (check_classification_targets,
5-
type_of_target)
4+
from sklearn.utils.multiclass import (
5+
check_classification_targets,
6+
type_of_target,
7+
)
68

79
from .regression import BaseRegressionScore
810
from .classification import BaseClassificationScore
9-
from .bounds import AbsoluteConformityScore
10-
from .sets import LACConformityScore
11+
from .bounds import (
12+
AbsoluteConformityScore,
13+
GammaConformityScore,
14+
ResidualNormalisedScore,
15+
)
16+
from .sets import (
17+
LACConformityScore,
18+
TopKConformityScore,
19+
APSConformityScore,
20+
RAPSConformityScore,
21+
)
1122

1223
from numpy.typing import ArrayLike
1324

1425

26+
CONFORMITY_SCORES_STRING_MAP = {
27+
BaseRegressionScore: {
28+
"absolute": AbsoluteConformityScore,
29+
"gamma": GammaConformityScore,
30+
"residual_normalized": ResidualNormalisedScore,
31+
},
32+
BaseClassificationScore: {
33+
"lac": LACConformityScore,
34+
"top_k": TopKConformityScore,
35+
"aps": APSConformityScore,
36+
"raps": RAPSConformityScore,
37+
},
38+
}
39+
40+
41+
@no_type_check # Cumbersome to type
42+
def check_and_select_conformity_score(conformity_score, conformity_score_type):
43+
if isinstance(conformity_score, conformity_score_type):
44+
return conformity_score
45+
elif conformity_score in CONFORMITY_SCORES_STRING_MAP[conformity_score_type]:
46+
return CONFORMITY_SCORES_STRING_MAP[conformity_score_type][conformity_score]()
47+
else:
48+
raise ValueError("Invalid conformity_score parameter")
49+
50+
1551
def check_regression_conformity_score(
1652
conformity_score: Optional[BaseRegressionScore],
1753
sym: bool = True,

mapie/regression/regression.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,17 @@
1414
from numpy.typing import ArrayLike, NDArray
1515
from mapie.conformity_scores import (BaseRegressionScore,
1616
ResidualNormalisedScore)
17-
from mapie.conformity_scores.utils import check_regression_conformity_score
17+
from mapie.conformity_scores.utils import (
18+
check_regression_conformity_score,
19+
check_and_select_conformity_score,
20+
)
1821
from mapie.estimator.regressor import EnsembleRegressor
1922
from mapie.subsample import Subsample
2023
from mapie.utils import (_check_alpha, _check_alpha_and_n_samples,
2124
_check_cv, _check_estimator_fit_predict,
2225
_check_n_features_in, _check_n_jobs, _check_null_weight,
2326
_check_verbose, _get_effective_calibration_samples,
2427
_check_predict_params)
25-
from mapie_v1.conformity_scores._utils import check_and_select_conformity_score
2628
from mapie.utils import (
2729
_transform_confidence_level_to_alpha_list,
2830
_check_if_param_in_allowed_values,

mapie_v1/__init__.py

Lines changed: 0 additions & 5 deletions
This file was deleted.

mapie_v1/_version.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

mapie_v1/conformity_scores/__init__.py

Lines changed: 0 additions & 25 deletions
This file was deleted.

mapie_v1/conformity_scores/_utils.py

Lines changed: 0 additions & 12 deletions
This file was deleted.
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import pytest
2+
3+
from mapie.conformity_scores.utils import (
4+
check_and_select_conformity_score,
5+
)
6+
from mapie.conformity_scores.regression import BaseRegressionScore
7+
from mapie.conformity_scores.classification import BaseClassificationScore
8+
from mapie.conformity_scores.bounds import (
9+
AbsoluteConformityScore,
10+
GammaConformityScore,
11+
)
12+
from mapie.conformity_scores.sets import (
13+
LACConformityScore,
14+
TopKConformityScore,
15+
)
16+
17+
18+
class TestCheckAndSelectConformityScore:
19+
20+
@pytest.mark.parametrize(
21+
"score, score_type, expected_class", [
22+
(AbsoluteConformityScore(), BaseRegressionScore, AbsoluteConformityScore),
23+
("gamma", BaseRegressionScore, GammaConformityScore),
24+
(LACConformityScore(), BaseClassificationScore, LACConformityScore),
25+
("top_k", BaseClassificationScore, TopKConformityScore),
26+
]
27+
)
28+
def test_with_valid_inputs(self, score, score_type, expected_class):
29+
result = check_and_select_conformity_score(score, score_type)
30+
assert isinstance(result, expected_class)
31+
32+
@pytest.mark.parametrize(
33+
"score_type", [BaseRegressionScore, BaseClassificationScore]
34+
)
35+
def test_with_invalid_input(self, score_type):
36+
with pytest.raises(ValueError):
37+
check_and_select_conformity_score("I'm not a valid input :(", score_type)

0 commit comments

Comments
 (0)