Skip to content

Commit a8f80a6

Browse files
Merge pull request #468 from scikit-learn-contrib/290-unit-tests-for-different-subsamples
Unit tests for different subsamples
2 parents e017317 + c41a484 commit a8f80a6

File tree

2 files changed

+87
-0
lines changed

2 files changed

+87
-0
lines changed

HISTORY.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ History
55
0.8.x (2024-xx-xx)
66
------------------
77

8+
* Building unit tests for different `Subsample` and `BlockBooststrap` instances
89
* Change the sign of C_k in the `Kolmogorov-Smirnov` test documentation
910
* Building a training set with a fraction between 0 and 1 with `n_samples` attribute when using `split` method from `Subsample` class.
1011

mapie/tests/test_subsample.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from __future__ import annotations
22

3+
from itertools import combinations, product
4+
from typing import Union
5+
36
import numpy as np
47
import pytest
58

@@ -76,6 +79,45 @@ def test_n_samples_none(n_resamplings: int) -> None:
7679
assert len(val_set) == 0
7780

7881

82+
@pytest.mark.parametrize("n_samples", [0.4, 0.6, 3, 6])
83+
@pytest.mark.parametrize("n_resamplings", [2, 3, 4])
84+
def test_split_samples_Subsample(n_resamplings: int,
85+
n_samples: Union[int, float]) -> None:
86+
"""Test that outputs of subsamplings are all different."""
87+
X = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
88+
cv = Subsample(n_resamplings=n_resamplings,
89+
n_samples=n_samples, replace=False, random_state=0)
90+
trains = [x[0] for x in cv.split(X)]
91+
tests = [x[1] for x in cv.split(X)]
92+
for (train1, train2), (test1, test2) in product(
93+
combinations(trains, 2), combinations(tests, 2)):
94+
with np.testing.assert_raises(AssertionError):
95+
np.testing.assert_equal(train1, train2)
96+
with np.testing.assert_raises(AssertionError):
97+
np.testing.assert_equal(test1, test2)
98+
99+
100+
@pytest.mark.parametrize("n_samples", [0.4, 0.6, 3, 6])
101+
@pytest.mark.parametrize("n_resamplings", [2, 3, 4])
102+
def test_reproductibility_samples_Subsample(
103+
n_resamplings: int,
104+
n_samples: Union[int, float]
105+
) -> None:
106+
"""This test ensures that each split between
107+
two instances is the same for a given seed."""
108+
X = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
109+
cv1 = Subsample(n_resamplings=n_resamplings,
110+
n_samples=n_samples, replace=False, random_state=0)
111+
trains1 = [x[0] for x in cv1.split(X)]
112+
tests1 = [x[1] for x in cv1.split(X)]
113+
cv2 = Subsample(n_resamplings=n_resamplings,
114+
n_samples=n_samples, replace=False, random_state=0)
115+
trains2 = [x[0] for x in cv2.split(X)]
116+
tests2 = [x[1] for x in cv2.split(X)]
117+
np.testing.assert_array_equal(trains1, trains2)
118+
np.testing.assert_array_equal(tests1, tests2)
119+
120+
79121
def test_default_parameters_BlockBootstrap() -> None:
80122
"""Test default values of Subsample."""
81123
cv = BlockBootstrap()
@@ -131,3 +173,47 @@ def test_split_BlockBootstrap_error() -> None:
131173
cv = BlockBootstrap()
132174
with pytest.raises(ValueError, match=r".*Exactly one argument*"):
133175
next(cv.split(X))
176+
177+
178+
@pytest.mark.parametrize("length", [2, 3, 4])
179+
@pytest.mark.parametrize("n_resamplings", [2, 3, 4])
180+
def test_split_samples_BlockBootstrap(n_resamplings: int,
181+
length: int) -> None:
182+
"""Test that outputs of subsamplings are all different."""
183+
X = np.arange(31)
184+
cv = BlockBootstrap(n_resamplings=n_resamplings,
185+
length=length, random_state=0)
186+
trains = [x[0] for x in cv.split(X)]
187+
tests = [x[1] for x in cv.split(X)]
188+
for (train1, train2), (test1, test2) in product(
189+
combinations(trains, 2), combinations(tests, 2)):
190+
with np.testing.assert_raises(AssertionError):
191+
np.testing.assert_equal(train1, train2)
192+
with np.testing.assert_raises(AssertionError):
193+
np.testing.assert_equal(test1, test2)
194+
195+
196+
@pytest.mark.parametrize("length", [2, 3, 4])
197+
@pytest.mark.parametrize("n_resamplings", [2, 3, 4])
198+
def test_reproductibility_samples_BlockBootstrap(
199+
n_resamplings: int,
200+
length: int) -> None:
201+
"""This test ensures that each split between
202+
two instances is the same for a given seed."""
203+
X = np.arange(15)
204+
cv1 = BlockBootstrap(
205+
n_resamplings=n_resamplings,
206+
length=length,
207+
random_state=42
208+
)
209+
trains1 = [x[0] for x in list(cv1.split(X))]
210+
tests1 = [x[1] for x in list(cv1.split(X))]
211+
cv2 = BlockBootstrap(
212+
n_resamplings=n_resamplings,
213+
length=length,
214+
random_state=42
215+
)
216+
trains2 = [x[0] for x in list(cv2.split(X))]
217+
tests2 = [x[1] for x in list(cv2.split(X))]
218+
np.testing.assert_equal(trains1, trains2)
219+
np.testing.assert_equal(tests1, tests2)

0 commit comments

Comments
 (0)