Skip to content

Commit 364df9a

Browse files
ENH: change warning into infolog (#584)
ENH: change warning into infolog
1 parent 4561bcc commit 364df9a

File tree

2 files changed

+13
-52
lines changed

2 files changed

+13
-52
lines changed

mapie/tests/test_utils.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import logging
34
import re
45
from typing import Any, Optional, Tuple
56

@@ -228,22 +229,24 @@ def test_valid_verbose(verbose: Any) -> None:
228229
check_verbose(verbose)
229230

230231

231-
def test_initial_low_high_pred() -> None:
232+
def test_initial_low_high_pred(caplog) -> None:
232233
"""Test lower/upper predictions of the quantiles regression crossing"""
233234
y_preds = np.array([[4, 5, 2], [4, 4, 4], [2, 3, 4]])
234-
with pytest.warns(UserWarning, match=r"WARNING: The predictions are*"):
235+
with caplog.at_level(logging.INFO):
235236
check_lower_upper_bounds(y_preds[0], y_preds[1], y_preds[2])
237+
assert "The predictions are ill-sorted" in caplog.text
236238

237239

238-
def test_final_low_high_pred() -> None:
240+
def test_final_low_high_pred(caplog) -> None:
239241
"""Test lower/upper predictions crossing"""
240242
y_preds = np.array(
241243
[[4, 3, 2], [3, 3, 3], [2, 3, 4]]
242244
)
243245
y_pred_low = np.array([4, 7, 2])
244246
y_pred_up = np.array([3, 3, 3])
245-
with pytest.warns(UserWarning, match=r"WARNING: The predictions are*"):
247+
with caplog.at_level(logging.INFO):
246248
check_lower_upper_bounds(y_pred_low, y_pred_up, y_preds[2])
249+
assert "The predictions are ill-sorted" in caplog.text
247250

248251

249252
def test_ensemble_in_predict() -> None:
@@ -331,19 +334,6 @@ def test_quantile_prefit_non_iterable(estimator: Any) -> None:
331334
mapie_reg.fit([1, 2, 3], [4, 5, 6])
332335

333336

334-
# def test_calib_set_no_Xy_but_sample_weight() -> None:
335-
# """Test warning message if sample weight provided but no X y in calib."""
336-
# X = np.array([4, 5, 6])
337-
# y = np.array([4, 3, 2])
338-
# sample_weight = np.array([4, 4, 4])
339-
# sample_weight_calib = np.array([4, 3, 4])
340-
# with pytest.warns(UserWarning, match=r"WARNING: sample weight*"):
341-
# check_calib_set(
342-
# X=X, y=y, sample_weight=sample_weight,
343-
# sample_weight_calib=sample_weight_calib
344-
# )
345-
346-
347337
@pytest.mark.parametrize("strategy", ["quantile", "uniform", "array split"])
348338
def test_binning_group_strategies(strategy: str) -> None:
349339
"""Test that different strategies have the correct outputs."""

mapie/utils.py

Lines changed: 6 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
import warnings
23
from inspect import signature
34
from typing import Any, Iterable, Optional, Tuple, Union, cast
@@ -573,39 +574,6 @@ def check_lower_upper_bounds(
573574
y_pred_up: NDArray,
574575
y_preds: NDArray
575576
) -> None:
576-
"""
577-
Check if lower or upper bounds and prediction are consistent.
578-
579-
Parameters
580-
----------
581-
y_pred_low: NDArray of shape (n_samples,)
582-
Lower bound prediction.
583-
584-
y_pred_up: NDArray of shape (n_samples,)
585-
Upper bound prediction.
586-
587-
y_preds: NDArray of shape (n_samples,)
588-
Prediction.
589-
590-
Raises
591-
------
592-
Warning
593-
If any of the predictions are ill-sorted.
594-
595-
Examples
596-
--------
597-
>>> import warnings
598-
>>> warnings.filterwarnings("error")
599-
>>> import numpy as np
600-
>>> from mapie.utils import check_lower_upper_bounds
601-
>>> y_preds = np.array([[4, 3, 2], [4, 4, 4], [2, 3, 4]])
602-
>>> try:
603-
... check_lower_upper_bounds(y_preds[0], y_preds[1], y_preds[2])
604-
... except Exception as exception:
605-
... print(exception)
606-
...
607-
WARNING: The predictions are ill-sorted.
608-
"""
609577
y_pred_low = column_or_1d(y_pred_low)
610578
y_pred_up = column_or_1d(y_pred_up)
611579
y_preds = column_or_1d(y_preds)
@@ -617,9 +585,12 @@ def check_lower_upper_bounds(
617585
)
618586

619587
if any_inversion:
620-
warnings.warn(
621-
"WARNING: The predictions are ill-sorted."
588+
initial_logger_level = logging.root.level
589+
logging.basicConfig(level=logging.INFO)
590+
logging.info(
591+
"The predictions are ill-sorted."
622592
)
593+
logging.basicConfig(level=initial_logger_level)
623594

624595

625596
def check_defined_variables_predict_cqr(

0 commit comments

Comments
 (0)