Skip to content

Commit 9e976a4

Browse files
ts2095timschulzglemaitre
authored
ENH preserve dtype and type when providing a dataframe with sparse dtype (#1054)
Co-authored-by: timschulz <[email protected]> Co-authored-by: Guillaume Lemaitre <[email protected]>
1 parent c7a1838 commit 9e976a4

File tree

3 files changed

+40
-1
lines changed

3 files changed

+40
-1
lines changed

doc/whats_new/v0.12.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,9 @@ Deprecations
4242
- Deprecate `kind_sel` in :class:`~imblearn.under_sampling.NeighbourhoodCleaningRule`.
4343
It will be removed in 0.14. The parameter does not have any effect.
4444
:pr:`1012` by :user:`Guillaume Lemaitre <glemaitre>`.
45+
46+
Enhancements
47+
............
48+
49+
- Samplers now output a dataframe with sparse dtypes when a sparse dataframe is provided as input.
50+
:pr:`1059` by :user:`ts2095 <ts2095>`.

imblearn/utils/_validation.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from numbers import Integral, Real
1111

1212
import numpy as np
13+
from scipy.sparse import issparse
1314
from sklearn.base import clone
1415
from sklearn.neighbors import NearestNeighbors
1516
from sklearn.utils import check_array, column_or_1d
@@ -61,7 +62,10 @@ def _transfrom_one(self, array, props):
6162
elif type_ == "dataframe":
6263
import pandas as pd
6364

64-
ret = pd.DataFrame(array, columns=props["columns"])
65+
if issparse(array):
66+
ret = pd.DataFrame.sparse.from_spmatrix(array, columns=props["columns"])
67+
else:
68+
ret = pd.DataFrame(array, columns=props["columns"])
6569
ret = ret.astype(props["dtypes"])
6670
elif type_ == "series":
6771
import pandas as pd

imblearn/utils/estimator_checks.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def _yield_sampler_checks(sampler):
108108
yield check_samplers_sparse
109109
if "dataframe" in tags["X_types"]:
110110
yield check_samplers_pandas
111+
yield check_samplers_pandas_sparse
111112
if "string" in tags["X_types"]:
112113
yield check_samplers_string
113114
if tags["allow_nan"]:
@@ -312,6 +313,34 @@ def check_samplers_sparse(name, sampler_orig):
312313
assert_allclose(y_res_sparse, y_res)
313314

314315

316+
def check_samplers_pandas_sparse(name, sampler_orig):
317+
pd = pytest.importorskip("pandas")
318+
sampler = clone(sampler_orig)
319+
# Check that the samplers preserve the sparse dtype of a pandas dataframe and handle pandas series
320+
X, y = sample_dataset_generator()
321+
X_df = pd.DataFrame(
322+
X, columns=[str(i) for i in range(X.shape[1])], dtype=pd.SparseDtype(float, 0)
323+
)
324+
y_s = pd.Series(y, name="class")
325+
326+
X_res_df, y_res_s = sampler.fit_resample(X_df, y_s)
327+
X_res, y_res = sampler.fit_resample(X, y)
328+
329+
# check that we return the same type for dataframes or series types
330+
assert isinstance(X_res_df, pd.DataFrame)
331+
assert isinstance(y_res_s, pd.Series)
332+
333+
for column_dtype in X_res_df.dtypes:
334+
assert isinstance(column_dtype, pd.SparseDtype)
335+
336+
assert X_df.columns.tolist() == X_res_df.columns.tolist()
337+
assert y_s.name == y_res_s.name
338+
339+
# FIXME: we should use to_numpy with pandas >= 0.25
340+
assert_allclose(X_res_df.values, X_res)
341+
assert_allclose(y_res_s.values, y_res)
342+
343+
315344
def check_samplers_pandas(name, sampler_orig):
316345
pd = pytest.importorskip("pandas")
317346
sampler = clone(sampler_orig)

0 commit comments

Comments
 (0)