Skip to content

Commit 7e1889e

Browse files
authored
chore: handle expected warnings in unit tests (#838)
* filter RuntimeWarning: invalid value encountered in divide * filter expected UserWarning * ignore VennAbers's class warning * filter sklearn warning * remove test_va_inductive_missing_size_parameters_raises_error
1 parent d0f1106 commit 7e1889e

File tree

5 files changed

+41
-6
lines changed

5 files changed

+41
-6
lines changed

mapie/_venn_abers.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import numpy as np
2+
from sklearn.base import clone
23
from sklearn.model_selection import StratifiedKFold, train_test_split
34
from sklearn.multiclass import OneVsOneClassifier
4-
from sklearn.base import clone
5-
65

76
"""
87
Private module containing core Venn-ABERS implementation classes.
@@ -284,6 +283,8 @@ def predict_proba_prefitted_va(
284283
285284
Examples
286285
--------
286+
>>> import warnings
287+
>>> warnings.filterwarnings("ignore")
287288
>>> import numpy as np
288289
>>> # Calibration data
289290
>>> p_cal = np.array([[0.7, 0.2, 0.1], [0.3, 0.6, 0.1], [0.1, 0.1, 0.8]])
@@ -415,6 +416,8 @@ class VennAbers:
415416
416417
Examples
417418
--------
419+
>>> import warnings
420+
>>> warnings.filterwarnings("ignore")
418421
>>> import numpy as np
419422
>>> from sklearn.datasets import make_classification
420423
>>> from sklearn.model_selection import train_test_split

mapie/calibration.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,21 @@
11
from __future__ import annotations
22

33
import warnings
4-
from typing import Dict, Optional, Tuple, Union, cast
54
from inspect import signature
5+
from typing import Dict, Optional, Tuple, Union, cast
6+
67
import numpy as np
78
from numpy.typing import ArrayLike, NDArray
89
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin, clone
910
from sklearn.calibration import _SigmoidCalibration
10-
from sklearn.isotonic import IsotonicRegression
1111
from sklearn.exceptions import NotFittedError
12+
from sklearn.isotonic import IsotonicRegression
1213
from sklearn.pipeline import Pipeline
1314
from sklearn.utils import check_random_state
1415
from sklearn.utils.multiclass import type_of_target
1516
from sklearn.utils.validation import _check_y, _num_samples, indexable
1617

18+
from ._venn_abers import VennAbers, VennAbersMultiClass, predict_proba_prefitted_va
1719
from .utils import (
1820
_check_estimator_classification,
1921
_check_estimator_fit_predict,
@@ -24,8 +26,6 @@
2426
check_is_fitted,
2527
)
2628

27-
from ._venn_abers import predict_proba_prefitted_va, VennAbers, VennAbersMultiClass
28-
2929

3030
class TopLabelCalibrator(BaseEstimator, ClassifierMixin):
3131
"""
@@ -674,6 +674,8 @@ class VennAbersCalibrator(BaseEstimator, ClassifierMixin):
674674
675675
Examples
676676
--------
677+
>>> import warnings
678+
>>> warnings.filterwarnings("ignore")
677679
>>> import numpy as np
678680
>>> from sklearn.datasets import make_classification
679681
>>> from sklearn.model_selection import train_test_split

mapie/tests/test_calibration.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ def test_not_seen_calibrator() -> None:
214214

215215
@pytest.mark.parametrize("calibrator", CALIBRATORS)
216216
@pytest.mark.parametrize("estimator", ESTIMATORS)
217+
@pytest.mark.filterwarnings("ignore:.*predicted label.*not been seen.*:UserWarning")
217218
def test_shape_of_output(
218219
calibrator: Union[str, RegressorMixin], estimator: ClassifierMixin
219220
) -> None:
@@ -452,6 +453,7 @@ def early_stopping_monitor(i, est, locals):
452453

453454

454455
@pytest.mark.parametrize("cv", ["prefit", None])
456+
@pytest.mark.filterwarnings("ignore:: RuntimeWarning")
455457
def test_va_valid_cv_argument(cv: Optional[str]) -> None:
456458
"""Test valid cv methods."""
457459
if cv == "prefit":
@@ -545,6 +547,7 @@ def __getitem__(self, ind):
545547
),
546548
],
547549
)
550+
@pytest.mark.filterwarnings("ignore:: RuntimeWarning")
548551
def test_va_mode_functionality(
549552
mode, mode_params, X_train, y_train, X_test, n_classes
550553
) -> None:
@@ -579,6 +582,7 @@ def test_va_mode_functionality(
579582
(X_multi_proper, y_multi_proper, X_multi_cal, y_multi_cal, X_multi_test, 3),
580583
],
581584
)
585+
@pytest.mark.filterwarnings("ignore:: RuntimeWarning")
582586
def test_va_prefit_mode(X_proper, y_proper, X_cal, y_cal, X_test, n_classes) -> None:
583587
"""Test prefit mode for binary and multiclass."""
584588
clf = GaussianNB().fit(X_proper, y_proper)
@@ -601,6 +605,7 @@ def test_va_cross_validation_requires_n_splits() -> None:
601605

602606

603607
@pytest.mark.parametrize("estimator", VA_ESTIMATORS)
608+
@pytest.mark.filterwarnings("ignore:: RuntimeWarning")
604609
def test_va_different_estimators(estimator) -> None:
605610
"""Test VennAbersCalibrator with different base estimators."""
606611
va_cal = VennAbersCalibrator(
@@ -619,6 +624,7 @@ def test_va_estimator_none_raises_error() -> None:
619624
va_cal.fit(X_binary_train, y_binary_train)
620625

621626

627+
@pytest.mark.filterwarnings("ignore:: RuntimeWarning")
622628
def test_va_sample_weights_constant() -> None:
623629
"""Test that constant sample weights give same results as None."""
624630
sklearn.set_config(enable_metadata_routing=True)
@@ -640,6 +646,7 @@ def test_va_sample_weights_constant() -> None:
640646
np.testing.assert_allclose(probs_none, probs_ones, rtol=1e-2, atol=1e-2)
641647

642648

649+
@pytest.mark.filterwarnings("ignore:: RuntimeWarning")
643650
def test_va_sample_weights_variable() -> None:
644651
"""Test that variable sample weights affect the results."""
645652
sklearn.set_config(enable_metadata_routing=True)
@@ -669,6 +676,7 @@ def test_va_sample_weights_variable() -> None:
669676
assert not np.allclose(probs_uniform, probs_weighted)
670677

671678

679+
@pytest.mark.filterwarnings("ignore:: RuntimeWarning")
672680
def test_va_venn_abers_cv_with_sample_weight() -> None:
673681
"""Test VennAbersCV with sample weights in cross-validation mode."""
674682
sklearn.set_config(enable_metadata_routing=True)
@@ -689,6 +697,7 @@ def test_va_venn_abers_cv_with_sample_weight() -> None:
689697
assert np.allclose(probs.sum(axis=1), 1.0)
690698

691699

700+
@pytest.mark.filterwarnings("ignore:: RuntimeWarning")
692701
def test_va_random_state_reproducibility() -> None:
693702
"""Test that random_state ensures reproducible results."""
694703
va_cal1 = VennAbersCalibrator(
@@ -715,6 +724,7 @@ def test_va_random_state_reproducibility() -> None:
715724
("calib_size", 0.4),
716725
],
717726
)
727+
@pytest.mark.filterwarnings("ignore:: RuntimeWarning")
718728
def test_va_fit_parameters_override(param_name, override_value) -> None:
719729
"""Test that fit() parameters override constructor parameters."""
720730
va_cal = VennAbersCalibrator(
@@ -727,6 +737,7 @@ def test_va_fit_parameters_override(param_name, override_value) -> None:
727737

728738

729739
@pytest.mark.parametrize("cal_size", [0.2, 0.3, 0.5])
740+
@pytest.mark.filterwarnings("ignore:: RuntimeWarning")
730741
def test_va_different_calibration_sizes(cal_size: float) -> None:
731742
"""Test that different calibration sizes work correctly."""
732743
va_cal = VennAbersCalibrator(
@@ -739,6 +750,7 @@ def test_va_different_calibration_sizes(cal_size: float) -> None:
739750

740751

741752
@pytest.mark.parametrize("n_splits", [2, 3, 5])
753+
@pytest.mark.filterwarnings("ignore:: RuntimeWarning")
742754
def test_va_different_n_splits(n_splits: int) -> None:
743755
"""Test that different n_splits values work correctly."""
744756
va_cal = VennAbersCalibrator(
@@ -765,6 +777,7 @@ def test_va_n_splits_too_small_raises_error() -> None:
765777
va_cal.fit(X_binary_train, y_binary_train)
766778

767779

780+
@pytest.mark.filterwarnings("ignore:: RuntimeWarning")
768781
def test_va_pipeline_compatibility() -> None:
769782
"""Test that VennAbersCalibrator works with sklearn pipelines."""
770783
X_df = pd.DataFrame(
@@ -805,6 +818,7 @@ def test_va_pipeline_compatibility() -> None:
805818
(pd.DataFrame, np.ndarray),
806819
],
807820
)
821+
@pytest.mark.filterwarnings("ignore:: RuntimeWarning")
808822
def test_va_different_input_types(X_type, y_type) -> None:
809823
"""Test with different input data types."""
810824
X_train = X_type(X_binary_train) if X_type == pd.DataFrame else X_binary_train
@@ -856,6 +870,7 @@ def test_va_invalid_cal_size_raises_error(calib_size) -> None:
856870

857871

858872
@pytest.mark.parametrize("precision", [None, 2, 4])
873+
@pytest.mark.filterwarnings("ignore:: RuntimeWarning")
859874
def test_va_precision_parameter(precision: Optional[int]) -> None:
860875
"""Test that precision parameter works correctly."""
861876
va_cal = VennAbersCalibrator(
@@ -870,6 +885,7 @@ def test_va_precision_parameter(precision: Optional[int]) -> None:
870885
assert np.allclose(probs.sum(axis=1), 1.0)
871886

872887

888+
@pytest.mark.filterwarnings("ignore:: RuntimeWarning")
873889
def test_va_precision_parameter_multiclass() -> None:
874890
"""Test that precision parameter works correctly for multiclass."""
875891
va_cal = VennAbersCalibrator(
@@ -884,6 +900,7 @@ def test_va_precision_parameter_multiclass() -> None:
884900
assert np.allclose(probs.sum(axis=1), 1.0)
885901

886902

903+
@pytest.mark.filterwarnings("ignore:: RuntimeWarning")
887904
def test_va_integration_with_cross_validation() -> None:
888905
"""Test integration with sklearn's cross-validation utilities."""
889906
from sklearn.model_selection import cross_val_score
@@ -927,6 +944,7 @@ def test_va_check_is_fitted_after_fit() -> None:
927944
check_is_fitted(va_cal)
928945

929946

947+
@pytest.mark.filterwarnings("ignore:: RuntimeWarning")
930948
def test_va_predict_proba_prefitted_va_one_vs_all() -> None:
931949
"""Test predict_proba_prefitted_va with one_vs_all strategy."""
932950
X, y = make_classification(
@@ -967,6 +985,7 @@ def test_va_predict_proba_prefitted_va_invalid_type() -> None:
967985
predict_proba_prefitted_va(p_cal, y_train, p_test, va_tpe="invalid_type")
968986

969987

988+
@pytest.mark.filterwarnings("ignore:: RuntimeWarning")
970989
def test_va_venn_abers_basic() -> None:
971990
"""Test basic VennAbers functionality for binary classification."""
972991
X, y = make_classification(n_samples=500, n_classes=2, random_state=random_state_va)
@@ -990,6 +1009,7 @@ def test_va_venn_abers_basic() -> None:
9901009
assert np.allclose(p_prime.sum(axis=1), 1.0)
9911010

9921011

1012+
@pytest.mark.filterwarnings("ignore:: RuntimeWarning")
9931013
def test_va_venn_abers_cv_p0_p1_output() -> None:
9941014
"""Test VennAbersCV predict_proba with p0_p1_output=True."""
9951015
from mapie._venn_abers import VennAbersCV
@@ -1016,6 +1036,7 @@ def test_va_multiclass_cross_validation_requires_n_splits() -> None:
10161036
va_multi.fit(X_multi_train, y_multi_train)
10171037

10181038

1039+
@pytest.mark.filterwarnings("ignore:: RuntimeWarning")
10191040
def test_va_multiclass_p0_p1_output() -> None:
10201041
"""Test VennAbersMultiClass with p0_p1_output=True."""
10211042
n_samples, n_features, n_classes = 100, 4, 3
@@ -1037,6 +1058,7 @@ def test_va_multiclass_p0_p1_output() -> None:
10371058
assert len(p0_p1_list) == n_classes * (n_classes - 1) // 2
10381059

10391060

1061+
@pytest.mark.filterwarnings("ignore:: RuntimeWarning")
10401062
def test_va_prefit_predict_proba_without_single_estimator() -> None:
10411063
"""Test that predict_proba raises RuntimeError when single_estimator_ is None in prefit mode."""
10421064
clf = GaussianNB().fit(X_binary_proper, y_binary_proper)
@@ -1050,6 +1072,7 @@ def test_va_prefit_predict_proba_without_single_estimator() -> None:
10501072
va_cal.predict_proba(X_binary_test)
10511073

10521074

1075+
@pytest.mark.filterwarnings("ignore:: RuntimeWarning")
10531076
def test_va_prefit_predict_proba_without_n_classes() -> None:
10541077
"""Test that predict_proba raises RuntimeError when n_classes_ is None after fitting."""
10551078
clf = GaussianNB().fit(X_binary_proper, y_binary_proper)
@@ -1092,6 +1115,7 @@ def test_va_prefit_classes_none_after_fitting() -> None:
10921115

10931116

10941117
@pytest.mark.parametrize("cv_ensemble", [True, False])
1118+
@pytest.mark.filterwarnings("ignore:: RuntimeWarning")
10951119
def test_va_cv_ensemble_cross_binary(cv_ensemble) -> None:
10961120
"""Test cv_ensemble parameter with cross-validation mode."""
10971121
va_cal = VennAbersCalibrator(
@@ -1108,6 +1132,7 @@ def test_va_cv_ensemble_cross_binary(cv_ensemble) -> None:
11081132
assert np.allclose(proba.sum(axis=1), 1.0)
11091133

11101134

1135+
@pytest.mark.filterwarnings("ignore:: RuntimeWarning")
11111136
def test_va_venn_abers_cv_brier_loss() -> None:
11121137
"""Test VennAbersCV with Brier loss."""
11131138
va_cal = VennAbersCalibrator(
@@ -1123,6 +1148,7 @@ def test_va_venn_abers_cv_brier_loss() -> None:
11231148
assert np.allclose(probs_brier.sum(axis=1), 1.0)
11241149

11251150

1151+
@pytest.mark.filterwarnings("ignore:: RuntimeWarning")
11261152
def test_va_comprehensive_workflow() -> None:
11271153
"""Comprehensive test covering multiple aspects of VennAbersCalibrator."""
11281154
modes: list[tuple[str, dict[str, Any]]] = [
@@ -1171,6 +1197,7 @@ def test_va_comprehensive_workflow() -> None:
11711197
assert np.allclose(probs_prefit.sum(axis=1), 1.0)
11721198

11731199

1200+
@pytest.mark.filterwarnings("ignore:: RuntimeWarning")
11741201
def test_va_prefit_binary_va_calibrator_none_raises() -> None:
11751202
clf = GaussianNB().fit(X_binary_proper, y_binary_proper)
11761203
va_cal = VennAbersCalibrator(estimator=clf, cv="prefit")
@@ -1209,6 +1236,7 @@ def test_va_inductive_va_calibrator_wrong_type_raises() -> None:
12091236
va_cal.predict_proba(X_binary_test)
12101237

12111238

1239+
@pytest.mark.filterwarnings("ignore:: RuntimeWarning")
12121240
def test_va_inductive_loss_branch_and_else_branch() -> None:
12131241
va_cal = VennAbersCalibrator(
12141242
estimator=GaussianNB(), inductive=True, random_state=random_state_va

mapie/tests/test_classification.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1196,6 +1196,7 @@ def test_results_with_constant_sample_weights(strategy: str) -> None:
11961196

11971197

11981198
@pytest.mark.parametrize("strategy", [*STRATEGIES])
1199+
@pytest.mark.filterwarnings("ignore:.*The groups parameter is ignored.*:UserWarning")
11991200
def test_results_with_constant_groups(strategy: str) -> None:
12001201
"""
12011202
Test predictions when groups are None

mapie/tests/test_regression.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,7 @@ def test_results_with_constant_sample_weights(strategy: str) -> None:
553553

554554

555555
@pytest.mark.parametrize("strategy", [*STRATEGIES])
556+
@pytest.mark.filterwarnings("ignore:.*The groups parameter is ignored.*:UserWarning")
556557
def test_results_with_constant_groups(strategy: str) -> None:
557558
"""
558559
Test predictions when groups are None

0 commit comments

Comments
 (0)