|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import re |
3 | 4 | from typing import Any, Optional, Tuple |
4 | 5 |
|
5 | 6 | import numpy as np |
|
17 | 18 | check_array_inf, check_array_nan, check_arrays_length, |
18 | 19 | check_binary_zero_one, check_cv, check_gamma, |
19 | 20 | check_lower_upper_bounds, check_n_features_in, |
20 | | - check_n_jobs, check_no_agg_cv, check_null_weight, |
21 | | - check_number_bins, check_split_strategy, |
22 | | - check_verbose, compute_quantiles, fit_estimator, |
23 | | - get_binning_groups) |
| 21 | + check_n_jobs, check_n_samples, check_no_agg_cv, |
| 22 | + check_null_weight, check_number_bins, |
| 23 | + check_split_strategy, check_verbose, |
| 24 | + compute_quantiles, fit_estimator, get_binning_groups) |
24 | 25 |
|
25 | 26 | X_toy = np.array([0, 1, 2, 3, 4, 5]).reshape(-1, 1) |
26 | 27 | y_toy = np.array([5, 7, 9, 11, 13, 15]) |
@@ -508,3 +509,51 @@ def test_check_no_agg_cv_value_error(cv: Any) -> None: |
508 | 509 | match=r"Allowed values must have the `get_n_splits` method" |
509 | 510 | ): |
510 | 511 | check_no_agg_cv(X_toy, cv, array) |
| 512 | + |
| 513 | + |
| 514 | +@pytest.mark.parametrize("n_samples", [-4, -2, -1]) |
| 515 | +def test_invalid_n_samples_int_negative(n_samples: int) -> None: |
| 516 | + """Test that invalid n_samples raise errors.""" |
| 517 | + X = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) |
| 518 | + indices = X.copy() |
| 519 | + with pytest.raises( |
| 520 | + ValueError, |
| 521 | + match=re.escape( |
| 522 | + r"Invalid n_samples. Allowed values " |
| 523 | + r"are float in the range (0.0, 1.0) or" |
| 524 | + r" int in the range [1, inf)" |
| 525 | + ) |
| 526 | + ): |
| 527 | + check_n_samples(X=X, n_samples=n_samples, indices=indices) |
| 528 | + |
| 529 | + |
| 530 | +@pytest.mark.parametrize("n_samples", [0.002, 0.003, 0.04]) |
| 531 | +def test_invalid_n_samples_int_zero(n_samples: int) -> None: |
| 532 | + """Test that invalid n_samples raise errors.""" |
| 533 | + X = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) |
| 534 | + indices = X.copy() |
| 535 | + with pytest.raises( |
| 536 | + ValueError, |
| 537 | + match=re.escape( |
| 538 | + r"The value of n_samples is too small. " |
| 539 | + r"You need to increase it so that n_samples*X.shape[0] > 1" |
| 540 | + r"otherwise n_samples should be an int" |
| 541 | + ) |
| 542 | + ): |
| 543 | + check_n_samples(X=X, n_samples=n_samples, indices=indices) |
| 544 | + |
| 545 | + |
| 546 | +@pytest.mark.parametrize("n_samples", [-5.5, -4.3, -0.2, 1.2, 2.5, 3.4]) |
| 547 | +def test_invalid_n_samples_float(n_samples: float) -> None: |
| 548 | + """Test that invalid n_samples raise errors.""" |
| 549 | + X = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) |
| 550 | + indices = X.copy() |
| 551 | + with pytest.raises( |
| 552 | + ValueError, |
| 553 | + match=re.escape( |
| 554 | + r"Invalid n_samples. Allowed values " |
| 555 | + r"are float in the range (0.0, 1.0) or" |
| 556 | + r" int in the range [1, inf)" |
| 557 | + ) |
| 558 | + ): |
| 559 | + check_n_samples(X=X, n_samples=n_samples, indices=indices) |
0 commit comments