9
9
10
10
from sklearn .datasets import load_iris
11
11
from sklearn .utils .testing import assert_array_equal
12
+ from sklearn .utils .testing import assert_allclose_dense_sparse
12
13
13
14
from imblearn .datasets import make_imbalance
14
15
from imblearn .misc import FunctionSampler
25
26
def test_function_sampler_identity (X , y ):
26
27
sampler = FunctionSampler ()
27
28
X_res , y_res = sampler .fit_sample (X , y )
28
- if sparse .issparse (X ):
29
- X = X .toarray ()
30
- X_res = X_res .toarray ()
31
- assert_array_equal (X_res , X )
29
+ assert_allclose_dense_sparse (X_res , X )
32
30
assert_array_equal (y_res , y )
33
31
34
32
@@ -42,10 +40,7 @@ def func(X, y):
42
40
43
41
sampler = FunctionSampler (func = func )
44
42
X_res , y_res = sampler .fit_sample (X , y )
45
- if sparse .issparse (X ):
46
- X = X .toarray ()
47
- X_res = X_res .toarray ()
48
- assert_array_equal (X_res , X [:10 ])
43
+ assert_allclose_dense_sparse (X_res , X [:10 ])
49
44
assert_array_equal (y_res , y [:10 ])
50
45
51
46
@@ -62,9 +57,5 @@ def func(X, y, ratio, random_state):
62
57
'random_state' : 0 })
63
58
X_res , y_res = sampler .fit_sample (X , y )
64
59
X_res_2 , y_res_2 = RandomUnderSampler (random_state = 0 ).fit_sample (X , y )
65
- if sparse .issparse (X ):
66
- X = X .toarray ()
67
- X_res = X_res .toarray ()
68
- X_res_2 = X_res_2 .toarray ()
69
- assert_array_equal (X_res , X_res_2 )
60
+ assert_allclose_dense_sparse (X_res , X_res_2 )
70
61
assert_array_equal (y_res , y_res_2 )
0 commit comments