|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +from itertools import combinations, product |
| 4 | +from typing import Union |
| 5 | + |
3 | 6 | import numpy as np |
4 | 7 | import pytest |
5 | 8 |
|
@@ -76,6 +79,45 @@ def test_n_samples_none(n_resamplings: int) -> None: |
76 | 79 | assert len(val_set) == 0 |
77 | 80 |
|
78 | 81 |
|
@pytest.mark.parametrize("n_samples", [0.4, 0.6, 3, 6])
@pytest.mark.parametrize("n_resamplings", [2, 3, 4])
def test_split_samples_Subsample(n_resamplings: int,
                                 n_samples: Union[int, float]) -> None:
    """
    Test that outputs of subsamplings are all different.

    Every pair of training index sets, and every pair of test index sets,
    produced by one seeded ``Subsample`` split must differ.
    """
    X = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    cv = Subsample(n_resamplings=n_resamplings,
                   n_samples=n_samples, replace=False, random_state=0)
    # Materialise the splits once so each train set stays paired with the
    # test set coming from the same resampling.
    splits = list(cv.split(X))
    trains = [split[0] for split in splits]
    tests = [split[1] for split in splits]
    # Train pairs and test pairs are independent checks; crossing them with
    # itertools.product would only repeat the same comparisons.
    for label, index_sets in (("train", trains), ("test", tests)):
        for first, second in combinations(index_sets, 2):
            assert not np.array_equal(first, second), (
                f"Two resamplings share the same {label} indices."
            )
| 98 | + |
| 99 | + |
@pytest.mark.parametrize("n_samples", [0.4, 0.6, 3, 6])
@pytest.mark.parametrize("n_resamplings", [2, 3, 4])
def test_reproductibility_samples_Subsample(
    n_resamplings: int,
    n_samples: Union[int, float]
) -> None:
    """
    Check that two ``Subsample`` instances built with the same seed
    yield the same train and test splits.
    """
    X = np.arange(1, 11)
    params = dict(
        n_resamplings=n_resamplings,
        n_samples=n_samples,
        replace=False,
        random_state=0,
    )

    def collect(cv, position):
        # position 0 -> first element of each split, 1 -> second element.
        return [fold[position] for fold in cv.split(X)]

    # Each instance is created and fully consumed before the next one,
    # and train/test indices come from two separate split() calls.
    first = Subsample(**params)
    first_trains, first_tests = collect(first, 0), collect(first, 1)
    second = Subsample(**params)
    second_trains, second_tests = collect(second, 0), collect(second, 1)

    np.testing.assert_array_equal(first_trains, second_trains)
    np.testing.assert_array_equal(first_tests, second_tests)
| 119 | + |
| 120 | + |
79 | 121 | def test_default_parameters_BlockBootstrap() -> None: |
80 | 122 | """Test default values of Subsample.""" |
81 | 123 | cv = BlockBootstrap() |
@@ -131,3 +173,47 @@ def test_split_BlockBootstrap_error() -> None: |
131 | 173 | cv = BlockBootstrap() |
132 | 174 | with pytest.raises(ValueError, match=r".*Exactly one argument*"): |
133 | 175 | next(cv.split(X)) |
| 176 | + |
| 177 | + |
@pytest.mark.parametrize("length", [2, 3, 4])
@pytest.mark.parametrize("n_resamplings", [2, 3, 4])
def test_split_samples_BlockBootstrap(n_resamplings: int,
                                      length: int) -> None:
    """
    Test that outputs of block bootstrap resamplings are all different.

    Every pair of training index sets, and every pair of test index sets,
    produced by one seeded ``BlockBootstrap`` split must differ.
    """
    X = np.arange(31)
    cv = BlockBootstrap(n_resamplings=n_resamplings,
                        length=length, random_state=0)
    # Materialise the splits once so each train set stays paired with the
    # test set coming from the same resampling.
    splits = list(cv.split(X))
    trains = [split[0] for split in splits]
    tests = [split[1] for split in splits]
    # Train pairs and test pairs are independent checks; crossing them with
    # itertools.product would only repeat the same comparisons.
    for label, index_sets in (("train", trains), ("test", tests)):
        for first, second in combinations(index_sets, 2):
            assert not np.array_equal(first, second), (
                f"Two resamplings share the same {label} indices."
            )
| 194 | + |
| 195 | + |
@pytest.mark.parametrize("length", [2, 3, 4])
@pytest.mark.parametrize("n_resamplings", [2, 3, 4])
def test_reproductibility_samples_BlockBootstrap(
        n_resamplings: int,
        length: int) -> None:
    """
    Check that two ``BlockBootstrap`` instances built with the same seed
    yield the same train and test splits.
    """
    X = np.arange(15)
    params = dict(n_resamplings=n_resamplings, length=length, random_state=42)

    def collect(cv, position):
        # position 0 -> first element of each split, 1 -> second element.
        return [fold[position] for fold in cv.split(X)]

    # Each instance is created and fully consumed before the next one,
    # and train/test indices come from two separate split() calls.
    first = BlockBootstrap(**params)
    first_trains, first_tests = collect(first, 0), collect(first, 1)
    second = BlockBootstrap(**params)
    second_trains, second_tests = collect(second, 0), collect(second, 1)

    np.testing.assert_equal(first_trains, second_trains)
    np.testing.assert_equal(first_tests, second_tests)
0 commit comments