Skip to content

Commit 4c25001

Browse files
Merge pull request #464 from scikit-learn-contrib/238-giving-a-fraction-of-samples-instead-of-a-number-of-samples-in-the-subsample-class
238 giving a fraction of samples instead of a number of samples in the subsample class
2 parents 9ffb70b + 0085eac commit 4c25001

File tree

5 files changed

+158
-10
lines changed

5 files changed

+158
-10
lines changed

HISTORY.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ History
55
0.8.x (2024-xx-xx)
66
------------------
77

8+
* Building a training set with a fraction between 0 and 1 with `n_samples` attribute when using `split` method from `Subsample` class.
9+
810
0.8.6 (2024-06-14)
911
------------------
1012

mapie/subsample.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from sklearn.utils.validation import _num_samples
1111

1212
from ._typing import NDArray
13+
from .utils import check_n_samples
1314

1415

1516
class Subsample(BaseCrossValidator):
@@ -22,9 +23,10 @@ class Subsample(BaseCrossValidator):
2223
----------
2324
n_resamplings : int
2425
Number of resamplings. By default ``30``.
25-
n_samples: int
26+
n_samples: Union[int, float]
2627
Number of samples in each resampling. By default ``None``,
27-
the size of the training set.
28+
the size of the training set. If it is between 0 and 1,
29+
it becomes the fraction of samples
2830
replace: bool
2931
Whether to replace samples in resamplings or not. By default ``True``.
3032
random_state: Optional[Union[int, RandomState]]
@@ -46,7 +48,7 @@ class Subsample(BaseCrossValidator):
4648
def __init__(
4749
self,
4850
n_resamplings: int = 30,
49-
n_samples: Optional[int] = None,
51+
n_samples: Optional[Union[int, float]] = None,
5052
replace: bool = True,
5153
random_state: Optional[Union[int, RandomState]] = None,
5254
) -> None:
@@ -74,9 +76,7 @@ def split(
7476
The testing set indices for that split.
7577
"""
7678
indices = np.arange(_num_samples(X))
77-
n_samples = (
78-
self.n_samples if self.n_samples is not None else len(indices)
79-
)
79+
n_samples = check_n_samples(X, self.n_samples, indices)
8080
random_state = check_random_state(self.random_state)
8181
for k in range(self.n_resamplings):
8282
train_index = resample(

mapie/tests/test_subsample.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,50 @@ def test_split_SubSample() -> None:
3232
np.testing.assert_equal(tests, tests_expected)
3333

3434

35+
@pytest.mark.parametrize("n_samples", [4, 6, 8, 10])
36+
@pytest.mark.parametrize("n_resamplings", [1, 2, 3])
37+
def test_n_samples_int(n_samples: int,
38+
n_resamplings: int) -> None:
39+
"""Test outputs of subsamplings when n_samples is a int"""
40+
X = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
41+
cv = Subsample(n_resamplings=n_resamplings, random_state=0,
42+
n_samples=n_samples, replace=False)
43+
train_set = np.concatenate([x[0] for x in cv.split(X)])
44+
val_set = np.concatenate([x[1] for x in cv.split(X)])
45+
assert len(train_set) == n_samples*n_resamplings
46+
assert len(val_set) == (X.shape[0] - n_samples)*n_resamplings
47+
48+
49+
@pytest.mark.parametrize("n_samples", [0.4, 0.6, 0.8, 0.9])
50+
@pytest.mark.parametrize("n_resamplings", [1, 2, 3])
51+
def test_n_samples_float(n_samples: float,
52+
n_resamplings: int) -> None:
53+
"""Test outputs of subsamplings when n_samples is a
54+
float between 0 and 1."""
55+
X = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
56+
cv = Subsample(n_resamplings=n_resamplings, random_state=0,
57+
n_samples=n_samples, replace=False)
58+
train_set = np.concatenate([x[0] for x in cv.split(X)])
59+
val_set = np.concatenate([x[1] for x in cv.split(X)])
60+
assert len(train_set) == int(np.floor(n_samples*X.shape[0]))*n_resamplings
61+
assert len(val_set) == (
62+
(X.shape[0] - int(np.floor(n_samples * X.shape[0]))) *
63+
n_resamplings
64+
)
65+
66+
67+
@pytest.mark.parametrize("n_resamplings", [1, 2, 3])
68+
def test_n_samples_none(n_resamplings: int) -> None:
69+
"""Test outputs of subsamplings when n_samples is None."""
70+
X = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
71+
cv = Subsample(n_resamplings=n_resamplings, random_state=0,
72+
replace=False)
73+
train_set = np.concatenate([x[0] for x in cv.split(X)])
74+
val_set = np.concatenate([x[1] for x in cv.split(X)])
75+
assert len(train_set) == X.shape[0]*n_resamplings
76+
assert len(val_set) == 0
77+
78+
3579
def test_default_parameters_BlockBootstrap() -> None:
3680
"""Test default values of Subsample."""
3781
cv = BlockBootstrap()

mapie/tests/test_utils.py

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import re
34
from typing import Any, Optional, Tuple
45

56
import numpy as np
@@ -17,10 +18,10 @@
1718
check_array_inf, check_array_nan, check_arrays_length,
1819
check_binary_zero_one, check_cv, check_gamma,
1920
check_lower_upper_bounds, check_n_features_in,
20-
check_n_jobs, check_no_agg_cv, check_null_weight,
21-
check_number_bins, check_split_strategy,
22-
check_verbose, compute_quantiles, fit_estimator,
23-
get_binning_groups)
21+
check_n_jobs, check_n_samples, check_no_agg_cv,
22+
check_null_weight, check_number_bins,
23+
check_split_strategy, check_verbose,
24+
compute_quantiles, fit_estimator, get_binning_groups)
2425

2526
X_toy = np.array([0, 1, 2, 3, 4, 5]).reshape(-1, 1)
2627
y_toy = np.array([5, 7, 9, 11, 13, 15])
@@ -508,3 +509,51 @@ def test_check_no_agg_cv_value_error(cv: Any) -> None:
508509
match=r"Allowed values must have the `get_n_splits` method"
509510
):
510511
check_no_agg_cv(X_toy, cv, array)
512+
513+
514+
@pytest.mark.parametrize("n_samples", [-4, -2, -1])
515+
def test_invalid_n_samples_int_negative(n_samples: int) -> None:
516+
"""Test that invalid n_samples raise errors."""
517+
X = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
518+
indices = X.copy()
519+
with pytest.raises(
520+
ValueError,
521+
match=re.escape(
522+
r"Invalid n_samples. Allowed values "
523+
r"are float in the range (0.0, 1.0) or"
524+
r" int in the range [1, inf)"
525+
)
526+
):
527+
check_n_samples(X=X, n_samples=n_samples, indices=indices)
528+
529+
530+
@pytest.mark.parametrize("n_samples", [0.002, 0.003, 0.04])
531+
def test_invalid_n_samples_int_zero(n_samples: int) -> None:
532+
"""Test that invalid n_samples raise errors."""
533+
X = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
534+
indices = X.copy()
535+
with pytest.raises(
536+
ValueError,
537+
match=re.escape(
538+
r"The value of n_samples is too small. "
539+
r"You need to increase it so that n_samples*X.shape[0] > 1"
540+
r"otherwise n_samples should be an int"
541+
)
542+
):
543+
check_n_samples(X=X, n_samples=n_samples, indices=indices)
544+
545+
546+
@pytest.mark.parametrize("n_samples", [-5.5, -4.3, -0.2, 1.2, 2.5, 3.4])
547+
def test_invalid_n_samples_float(n_samples: float) -> None:
548+
"""Test that invalid n_samples raise errors."""
549+
X = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
550+
indices = X.copy()
551+
with pytest.raises(
552+
ValueError,
553+
match=re.escape(
554+
r"Invalid n_samples. Allowed values "
555+
r"are float in the range (0.0, 1.0) or"
556+
r" int in the range [1, inf)"
557+
)
558+
):
559+
check_n_samples(X=X, n_samples=n_samples, indices=indices)

mapie/utils.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1355,3 +1355,56 @@ def check_arrays_length(*arrays: NDArray) -> None:
13551355
raise ValueError(
13561356
"There are arrays with different length"
13571357
)
1358+
1359+
1360+
def check_n_samples(
1361+
X: NDArray,
1362+
n_samples: Optional[Union[float, int]],
1363+
indices: NDArray
1364+
) -> int:
1365+
"""
1366+
Check alpha and prepare it as a ArrayLike.
1367+
1368+
Parameters
1369+
----------
1370+
n_samples: Union[float, int]
1371+
Can be a float between 0 and 1 or a int
1372+
Between 0 and 1, represent the part of data in the train sample
1373+
When n_samples is a int, it represents the number of elements
1374+
in the train sample
1375+
1376+
Returns
1377+
-------
1378+
int
1379+
n_samples
1380+
1381+
Raises
1382+
------
1383+
ValueError
1384+
If n_samples is not an int in the range [1, inf)
1385+
or a float in the range (0.0, 1.0)
1386+
"""
1387+
if n_samples is None:
1388+
n_samples = len(indices)
1389+
elif isinstance(n_samples, float):
1390+
if 0 < n_samples < 1:
1391+
n_samples = int(np.floor(n_samples * X.shape[0]))
1392+
if n_samples == 0:
1393+
raise ValueError(
1394+
"The value of n_samples is too small. "
1395+
"You need to increase it so that n_samples*X.shape[0] > 1"
1396+
"otherwise n_samples should be an int"
1397+
)
1398+
else:
1399+
raise ValueError(
1400+
"Invalid n_samples. Allowed values "
1401+
"are float in the range (0.0, 1.0) or"
1402+
" int in the range [1, inf)"
1403+
)
1404+
elif isinstance(n_samples, int) and n_samples <= 0:
1405+
raise ValueError(
1406+
"Invalid n_samples. Allowed values "
1407+
"are float in the range (0.0, 1.0) or"
1408+
" int in the range [1, inf)"
1409+
)
1410+
return int(n_samples)

0 commit comments

Comments
 (0)