|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import re |
3 | 4 | from typing import Any, Optional, Tuple |
| 5 | + |
4 | 6 | import numpy as np |
5 | 7 | import pytest |
6 | | -import re |
7 | 8 | from numpy.random import RandomState |
8 | 9 | from sklearn.datasets import make_regression |
9 | 10 | from sklearn.linear_model import LinearRegression |
|
17 | 18 | check_array_inf, check_array_nan, check_arrays_length, |
18 | 19 | check_binary_zero_one, check_cv, check_gamma, |
19 | 20 | check_lower_upper_bounds, check_n_features_in, |
20 | | - check_n_jobs, check_no_agg_cv, check_n_samples, |
21 | | - check_null_weight, |
22 | | - check_number_bins, check_split_strategy, |
23 | | - check_verbose, compute_quantiles, fit_estimator, |
24 | | - get_binning_groups) |
| 21 | + check_n_jobs, check_n_samples, check_no_agg_cv, |
| 22 | + check_null_weight, check_number_bins, |
| 23 | + check_split_strategy, check_verbose, |
| 24 | + compute_quantiles, fit_estimator, get_binning_groups) |
25 | 25 |
|
26 | 26 | X_toy = np.array([0, 1, 2, 3, 4, 5]).reshape(-1, 1) |
27 | 27 | y_toy = np.array([5, 7, 9, 11, 13, 15]) |
@@ -543,24 +543,8 @@ def test_invalid_n_samples_int_zero(n_samples: int) -> None: |
543 | 543 | check_n_samples(X=X, n_samples=n_samples, indices=indices) |
544 | 544 |
|
545 | 545 |
|
546 | | -@pytest.mark.parametrize("n_samples", [-5.5, -4.3, -0.2]) |
547 | | -def test_invalid_n_samples_float_negative(n_samples: float) -> None: |
548 | | - """Test that invalid n_samples raise errors.""" |
549 | | - X = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) |
550 | | - indices = X.copy() |
551 | | - with pytest.raises( |
552 | | - ValueError, |
553 | | - match=re.escape( |
554 | | - r"Invalid n_samples. Allowed values " |
555 | | - r"are float in the range (0.0, 1.0) or" |
556 | | - r" int in the range [1, inf)" |
557 | | - ) |
558 | | - ): |
559 | | - check_n_samples(X=X, n_samples=n_samples, indices=indices) |
560 | | - |
561 | | - |
562 | | -@pytest.mark.parametrize("n_samples", [1.2, 2.5, 3.4]) |
563 | | -def test_invalid_n_samples_float_greater_than_1(n_samples: float) -> None: |
| 546 | +@pytest.mark.parametrize("n_samples", [-5.5, -4.3, -0.2, 1.2, 2.5, 3.4]) |
| 547 | +def test_invalid_n_samples_float(n_samples: float) -> None: |
564 | 548 | """Test that invalid n_samples raise errors.""" |
565 | 549 | X = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) |
566 | 550 | indices = X.copy() |
|
0 commit comments