Skip to content

Commit 5fca040

Browse files
committed
Update4 : taking comments into account
1 parent 7c6621f commit 5fca040

File tree

1 file changed

+8
-24
lines changed

1 file changed

+8
-24
lines changed

mapie/tests/test_utils.py

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from __future__ import annotations
22

3+
import re
34
from typing import Any, Optional, Tuple
5+
46
import numpy as np
57
import pytest
6-
import re
78
from numpy.random import RandomState
89
from sklearn.datasets import make_regression
910
from sklearn.linear_model import LinearRegression
@@ -17,11 +18,10 @@
1718
check_array_inf, check_array_nan, check_arrays_length,
1819
check_binary_zero_one, check_cv, check_gamma,
1920
check_lower_upper_bounds, check_n_features_in,
20-
check_n_jobs, check_no_agg_cv, check_n_samples,
21-
check_null_weight,
22-
check_number_bins, check_split_strategy,
23-
check_verbose, compute_quantiles, fit_estimator,
24-
get_binning_groups)
21+
check_n_jobs, check_n_samples, check_no_agg_cv,
22+
check_null_weight, check_number_bins,
23+
check_split_strategy, check_verbose,
24+
compute_quantiles, fit_estimator, get_binning_groups)
2525

2626
X_toy = np.array([0, 1, 2, 3, 4, 5]).reshape(-1, 1)
2727
y_toy = np.array([5, 7, 9, 11, 13, 15])
@@ -543,24 +543,8 @@ def test_invalid_n_samples_int_zero(n_samples: int) -> None:
543543
check_n_samples(X=X, n_samples=n_samples, indices=indices)
544544

545545

546-
@pytest.mark.parametrize("n_samples", [-5.5, -4.3, -0.2])
547-
def test_invalid_n_samples_float_negative(n_samples: float) -> None:
548-
"""Test that invalid n_samples raise errors."""
549-
X = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
550-
indices = X.copy()
551-
with pytest.raises(
552-
ValueError,
553-
match=re.escape(
554-
r"Invalid n_samples. Allowed values "
555-
r"are float in the range (0.0, 1.0) or"
556-
r" int in the range [1, inf)"
557-
)
558-
):
559-
check_n_samples(X=X, n_samples=n_samples, indices=indices)
560-
561-
562-
@pytest.mark.parametrize("n_samples", [1.2, 2.5, 3.4])
563-
def test_invalid_n_samples_float_greater_than_1(n_samples: float) -> None:
546+
@pytest.mark.parametrize("n_samples", [-5.5, -4.3, -0.2, 1.2, 2.5, 3.4])
547+
def test_invalid_n_samples_float(n_samples: float) -> None:
564548
"""Test that invalid n_samples raise errors."""
565549
X = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
566550
indices = X.copy()

0 commit comments

Comments
 (0)