Skip to content

Commit 9cfdf01

Browse files
committed
Added a test test_hierarchy_with_scikit_learn_column_transformer to demonstrate the failure detailed in issue 460.
#460
1 parent 9a86233 commit 9cfdf01

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

tests/test_target_encoder.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66
import pandas as pd
77
from category_encoders.datasets import load_compass, load_postcodes
8+
from sklearn.compose import ColumnTransformer
89

910
import tests.helpers as th
1011

@@ -505,3 +506,37 @@ def test_hierarchy_mapping_cols_missing(self):
505506
)
506507
with self.assertRaises(ValueError):
507508
enc.fit_transform(X, y)
509+
510+
def test_hierarchy_with_scikit_learn_column_transformer(self):
511+
"""Test that the encoder works with a scikit-learn ColumnTransformer."""
512+
features: list[str] = ["cat_feature", "int_feature"]
513+
target: str = "target"
514+
515+
df: pd.DataFrame = pd.DataFrame({
516+
"cat_feature": ["aa", "ab", "ac", "ba", "bb", "bc"],
517+
"int_feature": [10, 11, 9, 8, 12, 10],
518+
"target": [1, 2, 1, 2, 1, 2]
519+
})
520+
hierarchical_map: dict = {
521+
"cat_feature": {
522+
"a": ("aa", "ab", "ac"),
523+
"b": ("ba", "bb", "bc"),
524+
},
525+
}
526+
527+
# Get the encoder values from a simple TargetEncoder
528+
target_encoder = encoders.TargetEncoder(cols=["cat_feature"], min_samples_leaf=2, smoothing=2, hierarchy=hierarchical_map)
529+
encoder_values = target_encoder.fit_transform(df[features], df[target])["cat_feature"].values
530+
531+
# Get the encoder values from a ColumnTransformer
532+
preprocessor = ColumnTransformer(
533+
transformers=[
534+
("categorical", encoders.TargetEncoder(cols=["cat_feature"], min_samples_leaf=2, smoothing=2, hierarchy=hierarchical_map), ["cat_feature"]),
535+
],
536+
remainder='passthrough',
537+
verbose_feature_names_out=False,
538+
).set_output(transform="pandas")
539+
columntransformer_encoder_values = preprocessor.fit_transform(df[features], df[target])["cat_feature"].values
540+
541+
# Compare the results
542+
np.testing.assert_array_almost_equal(encoder_values, columntransformer_encoder_values, decimal=4)

0 commit comments

Comments
 (0)