Skip to content
4 changes: 3 additions & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@ def lint(session):

@nox.session
@nox.parametrize('numpy', ['1.18.1', '1.19.4', '1.20.1'])
@nox.parametrize('sklearn', ['0.23.0', '0.24.2', '1.0'])
@nox.parametrize('scipy', ['1.5.4', '1.6.0'])
@nox.parametrize('pandas', ['1.1.4', '1.2.2'])
def tests(session, numpy, scipy, pandas):
def tests(session, numpy, sklearn, scipy, pandas):
session.install('pytest>=5.3.5',
'setuptools>=45.2',
'wheel>=0.34.2',
f'numpy=={numpy}',
f'scikit-learn=={sklearn}',
f'scipy=={scipy}',
f'pandas=={pandas}'
)
Expand Down
12 changes: 7 additions & 5 deletions sklearn_pandas/dataframe_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def _get_feature_names(estimator):
"""
if hasattr(estimator, 'classes_'):
return estimator.classes_
elif hasattr(estimator, 'get_feature_names_out'):
return estimator.get_feature_names_out()
elif hasattr(estimator, 'get_feature_names'):
return estimator.get_feature_names()
return None
Expand Down Expand Up @@ -290,11 +292,11 @@ def get_names(self, columns, transformer, x, alias=None, prefix='',
else:
names = _get_feature_names(transformer)

if names is not None and len(names) == num_cols:
output = [f"{name}_{o}" for o in names]
# otherwise, return name concatenated with '_1', '_2', etc.
else:
output = [name + '_' + str(o) for o in range(num_cols)]
if names is None or len(names) != num_cols:
# return name concatenated with '_0', '_1', etc.
names = range(num_cols)

output = [f"{name}_{o}" for o in names]
else:
output = [name]

Expand Down
18 changes: 18 additions & 0 deletions tests/test_dataframe_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,24 @@ def test_onehot_df():
assert cols[3] == 'target_x0_3'


def test_onehot_2cols_df():
    """
    One-hot encode two columns through a single mapper entry and verify
    the generated output column names carry the per-feature level ids.
    """
    frame = pd.DataFrame({
        'col': [0, 0, 1, 1, 2, 3, 0],
        'target': [0, 0, 1, 1, 2, 3, 0]
    })
    # Both columns are handed to one OneHotEncoder as a single feature group.
    feature_spec = (['col', 'target'], OneHotEncoder())
    mapper = DataFrameMapper([feature_spec], df_out=True)

    out_columns = mapper.fit_transform(frame).columns

    # Each input column holds the four levels 0..3, hence 8 encoded columns.
    assert len(out_columns) == 8
    # First encoded column of the first input feature.
    assert out_columns[0] == 'col_target_x0_0'
    # First encoded column of the second input feature.
    assert out_columns[4] == 'col_target_x1_0'


def test_customtransform_df():
"""
Check level ids from a transformer in which
Expand Down