Skip to content

Commit 6a5ba45

Browse files
committed
TST add test for sparse matrices
1 parent 80f3f24 commit 6a5ba45

File tree

1 file changed

+26
-3
lines changed

1 file changed

+26
-3
lines changed

imblearn/tests/test_misc.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
# Authors: Guillaume Lemaitre <[email protected]>
44
# License: MIT
55

6+
import pytest
7+
8+
from scipy import sparse
9+
610
from sklearn.datasets import load_iris
711
from sklearn.utils.testing import assert_array_equal
812

@@ -15,25 +19,40 @@
1519
random_state=0)
1620

1721

18-
def test_function_sampler_identity():
22+
@pytest.mark.parametrize("X,y", [(X, y),
23+
(sparse.csr_matrix(X), y),
24+
(sparse.csc_matrix(X), y)])
25+
def test_function_sampler_identity(X, y):
1926
sampler = FunctionSampler()
2027
X_res, y_res = sampler.fit_sample(X, y)
28+
if sparse.issparse(X):
29+
X = X.toarray()
30+
X_res = X_res.toarray()
2131
assert_array_equal(X_res, X)
2232
assert_array_equal(y_res, y)
2333

2434

25-
def test_function_sampler_func():
35+
@pytest.mark.parametrize("X,y", [(X, y),
36+
(sparse.csr_matrix(X), y),
37+
(sparse.csc_matrix(X), y)])
38+
def test_function_sampler_func(X, y):
2639

2740
def func(X, y):
2841
return X[:10], y[:10]
2942

3043
sampler = FunctionSampler(func=func)
3144
X_res, y_res = sampler.fit_sample(X, y)
45+
if sparse.issparse(X):
46+
X = X.toarray()
47+
X_res = X_res.toarray()
3248
assert_array_equal(X_res, X[:10])
3349
assert_array_equal(y_res, y[:10])
3450

3551

36-
def test_function_sampler_func_kwargs():
52+
@pytest.mark.parametrize("X,y", [(X, y),
53+
(sparse.csr_matrix(X), y),
54+
(sparse.csc_matrix(X), y)])
55+
def test_function_sampler_func_kwargs(X, y):
3756

3857
def func(X, y, ratio, random_state):
3958
rus = RandomUnderSampler(ratio=ratio, random_state=random_state)
@@ -43,5 +62,9 @@ def func(X, y, ratio, random_state):
4362
'random_state': 0})
4463
X_res, y_res = sampler.fit_sample(X, y)
4564
X_res_2, y_res_2 = RandomUnderSampler(random_state=0).fit_sample(X, y)
65+
if sparse.issparse(X):
66+
X = X.toarray()
67+
X_res = X_res.toarray()
68+
X_res_2 = X_res_2.toarray()
4669
assert_array_equal(X_res, X_res_2)
4770
assert_array_equal(y_res, y_res_2)

0 commit comments

Comments
 (0)