3
3
# Authors: Guillaume Lemaitre <[email protected] >
4
4
# License: MIT
5
5
6
+ import pytest
7
+
8
+ from scipy import sparse
9
+
6
10
from sklearn .datasets import load_iris
7
11
from sklearn .utils .testing import assert_array_equal
8
12
15
19
random_state = 0 )
16
20
17
21
18
- def test_function_sampler_identity ():
22
+ @pytest .mark .parametrize ("X,y" , [(X , y ),
23
+ (sparse .csr_matrix (X ), y ),
24
+ (sparse .csc_matrix (X ), y )])
25
+ def test_function_sampler_identity (X , y ):
19
26
sampler = FunctionSampler ()
20
27
X_res , y_res = sampler .fit_sample (X , y )
28
+ if sparse .issparse (X ):
29
+ X = X .toarray ()
30
+ X_res = X_res .toarray ()
21
31
assert_array_equal (X_res , X )
22
32
assert_array_equal (y_res , y )
23
33
24
34
25
- def test_function_sampler_func ():
35
+ @pytest .mark .parametrize ("X,y" , [(X , y ),
36
+ (sparse .csr_matrix (X ), y ),
37
+ (sparse .csc_matrix (X ), y )])
38
+ def test_function_sampler_func (X , y ):
26
39
27
40
def func (X , y ):
28
41
return X [:10 ], y [:10 ]
29
42
30
43
sampler = FunctionSampler (func = func )
31
44
X_res , y_res = sampler .fit_sample (X , y )
45
+ if sparse .issparse (X ):
46
+ X = X .toarray ()
47
+ X_res = X_res .toarray ()
32
48
assert_array_equal (X_res , X [:10 ])
33
49
assert_array_equal (y_res , y [:10 ])
34
50
35
51
36
- def test_function_sampler_func_kwargs ():
52
+ @pytest .mark .parametrize ("X,y" , [(X , y ),
53
+ (sparse .csr_matrix (X ), y ),
54
+ (sparse .csc_matrix (X ), y )])
55
+ def test_function_sampler_func_kwargs (X , y ):
37
56
38
57
def func (X , y , ratio , random_state ):
39
58
rus = RandomUnderSampler (ratio = ratio , random_state = random_state )
@@ -43,5 +62,9 @@ def func(X, y, ratio, random_state):
43
62
'random_state' : 0 })
44
63
X_res , y_res = sampler .fit_sample (X , y )
45
64
X_res_2 , y_res_2 = RandomUnderSampler (random_state = 0 ).fit_sample (X , y )
65
+ if sparse .issparse (X ):
66
+ X = X .toarray ()
67
+ X_res = X_res .toarray ()
68
+ X_res_2 = X_res_2 .toarray ()
46
69
assert_array_equal (X_res , X_res_2 )
47
70
assert_array_equal (y_res , y_res_2 )
0 commit comments