Skip to content

Commit 325a570

Browse files
authored
Merge pull request #169 from datarian/fix-binary-encoder-for-columntransformer
Fix binary encoder for columntransformer
2 parents 5616f6f + 8c71668 commit 325a570

File tree

5 files changed

+34
-6
lines changed

5 files changed

+34
-6
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ env:
1616
# The versions should match the minimal requirements in requirements.txt and setup.py
1717
- DISTRIB="conda" PYTHON_VERSION="2.7" CYTHON_VERSION="0.21"
1818
NUMPY_VERSION="1.11.1" PANDAS_VERSION="0.21.1" PATSY_VERSION="0.4.1"
19-
SCIKIT_VERSION="0.17.1" SCIPY_VERSION="0.17.0" STATSMODELS_VERSION="0.6.1"
19+
SCIKIT_VERSION="0.20.2" SCIPY_VERSION="0.17.0" STATSMODELS_VERSION="0.6.1"
2020
- DISTRIB="conda" PYTHON_VERSION="3.5" COVERAGE="true" CYTHON_VERSION="0.23.4"
2121
NUMPY_VERSION="1.11.1" PANDAS_VERSION="0.21.1" PATSY_VERSION="0.4.1"
2222
SCIKIT_VERSION="0.17.1" SCIPY_VERSION="0.17.0" STATSMODELS_VERSION="0.6.1"

category_encoders/binary.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,16 @@ class BinaryEncoder(BaseEstimator, TransformerMixin):
7171

7272
def __init__(self, verbose=0, cols=None, mapping=None, drop_invariant=False, return_df=True,
7373
handle_unknown='value', handle_missing='value'):
74-
self.base_n_encoder = ce.BaseNEncoder(base=2, verbose=verbose, cols=cols, mapping=mapping,
75-
drop_invariant=drop_invariant, return_df=return_df,
76-
handle_unknown=handle_unknown, handle_missing=handle_missing)
74+
self.verbose = verbose
75+
self.cols = cols
76+
self.mapping = mapping
77+
self.drop_invariant = drop_invariant
78+
self.return_df = return_df
79+
self.handle_unknown = handle_unknown
80+
self.handle_missing = handle_missing
81+
self.base_n_encoder = ce.BaseNEncoder(base=2, verbose=self.verbose, cols=self.cols, mapping=self.mapping,
82+
drop_invariant=self.drop_invariant, return_df=self.return_df,
83+
handle_unknown=self.handle_unknown, handle_missing=self.handle_missing)
7784

7885
def fit(self, X, y=None, **kwargs):
7986
"""Fit encoder according to X and y.

category_encoders/tests/test_encoders.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import sklearn
99
import category_encoders.tests.helpers as th
1010
from sklearn.utils.estimator_checks import check_transformer_general, check_transformers_unfitted
11+
from sklearn.compose import ColumnTransformer
1112
from unittest2 import TestSuite, TextTestRunner, TestCase # or `from unittest import ...` if on Python 3.4+
1213

1314
import category_encoders as encoders
@@ -419,3 +420,23 @@ def test_truncated_index(self):
419420
enc2 = getattr(encoders, encoder_name)()
420421
result2 = enc2.fit_transform(data2.x, data2.y)
421422
self.assertTrue((result.values == result2.values).all())
423+
424+
def test_column_transformer(self):
425+
# see issue #169
426+
for encoder_name in (set(encoders.__all__) - {'HashingEncoder'}): # HashingEncoder does not accept handle_missing parameter
427+
with self.subTest(encoder_name=encoder_name):
428+
429+
# we can only test one data type at once. Here, we test string columns.
430+
tested_columns = ['unique_str', 'invariant', 'underscore', 'none', 'extra']
431+
432+
# ColumnTransformer clones the encoder before fitting (the clone is rebuilt from get_params()) -> every constructor argument must be stored as an attribute, otherwise the settings are lost
433+
ct = ColumnTransformer([
434+
("dummy_encoder_name", getattr(encoders, encoder_name)(handle_missing="return_nan"), tested_columns)
435+
])
436+
obtained = ct.fit_transform(X, y)
437+
438+
# the old-school approach
439+
enc = getattr(encoders, encoder_name)(handle_missing="return_nan", return_df=False)
440+
expected = enc.fit_transform(X[tested_columns], y)
441+
442+
np.testing.assert_array_equal(obtained, expected)

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
numpy>=1.11.1
2-
scikit-learn>=0.17.1
2+
scikit-learn>=0.20.2
33
scipy>=0.17.0
44
statsmodels>=0.6.1
55
pandas>=0.21.1

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
author='Will McGinnis',
3131
install_requires=[
3232
'numpy>=1.11.1',
33-
'scikit-learn>=0.17.1',
33+
'scikit-learn>=0.20.2',
3434
'scipy>=0.17.0',
3535
'statsmodels>=0.6.1',
3636
'pandas>=0.21.1',

0 commit comments

Comments
 (0)