Skip to content

Commit b71dc21

Browse files
committed
TST used sklearn dense sparse test function
1 parent 6a5ba45 commit b71dc21

File tree

1 file changed

+4
-13
lines changed

1 file changed

+4
-13
lines changed

imblearn/tests/test_misc.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from sklearn.datasets import load_iris
1111
from sklearn.utils.testing import assert_array_equal
12+
from sklearn.utils.testing import assert_allclose_dense_sparse
1213

1314
from imblearn.datasets import make_imbalance
1415
from imblearn.misc import FunctionSampler
@@ -25,10 +26,7 @@
2526
def test_function_sampler_identity(X, y):
2627
sampler = FunctionSampler()
2728
X_res, y_res = sampler.fit_sample(X, y)
28-
if sparse.issparse(X):
29-
X = X.toarray()
30-
X_res = X_res.toarray()
31-
assert_array_equal(X_res, X)
29+
assert_allclose_dense_sparse(X_res, X)
3230
assert_array_equal(y_res, y)
3331

3432

@@ -42,10 +40,7 @@ def func(X, y):
4240

4341
sampler = FunctionSampler(func=func)
4442
X_res, y_res = sampler.fit_sample(X, y)
45-
if sparse.issparse(X):
46-
X = X.toarray()
47-
X_res = X_res.toarray()
48-
assert_array_equal(X_res, X[:10])
43+
assert_allclose_dense_sparse(X_res, X[:10])
4944
assert_array_equal(y_res, y[:10])
5045

5146

@@ -62,9 +57,5 @@ def func(X, y, ratio, random_state):
6257
'random_state': 0})
6358
X_res, y_res = sampler.fit_sample(X, y)
6459
X_res_2, y_res_2 = RandomUnderSampler(random_state=0).fit_sample(X, y)
65-
if sparse.issparse(X):
66-
X = X.toarray()
67-
X_res = X_res.toarray()
68-
X_res_2 = X_res_2.toarray()
69-
assert_array_equal(X_res, X_res_2)
60+
assert_allclose_dense_sparse(X_res, X_res_2)
7061
assert_array_equal(y_res, y_res_2)

0 commit comments

Comments
 (0)