|
5 | 5 | import numpy as np |
6 | 6 | import pandas as pd |
7 | 7 | from category_encoders.datasets import load_compass, load_postcodes |
| 8 | +from sklearn.compose import ColumnTransformer |
8 | 9 |
|
9 | 10 | import tests.helpers as th |
10 | 11 |
|
@@ -505,3 +506,37 @@ def test_hierarchy_mapping_cols_missing(self): |
505 | 506 | ) |
506 | 507 | with self.assertRaises(ValueError): |
507 | 508 | enc.fit_transform(X, y) |
| 509 | + |
| 510 | + def test_hierarchy_with_scikit_learn_column_transformer(self): |
| 511 | + """Test that the encoder works with a scikit-learn ColumnTransformer.""" |
| 512 | + features: list[str] = ["cat_feature", "int_feature"] |
| 513 | + target: str = "target" |
| 514 | + |
| 515 | + df: pd.DataFrame = pd.DataFrame({ |
| 516 | + "cat_feature": ["aa", "ab", "ac", "ba", "bb", "bc"], |
| 517 | + "int_feature": [10, 11, 9, 8, 12, 10], |
| 518 | + "target": [1, 2, 1, 2, 1, 2] |
| 519 | + }) |
| 520 | + hierarchical_map: dict = { |
| 521 | + "cat_feature": { |
| 522 | + "a": ("aa", "ab", "ac"), |
| 523 | + "b": ("ba", "bb", "bc"), |
| 524 | + }, |
| 525 | + } |
| 526 | + |
| 527 | + # Get the encoder values from a simple TargetEncoder |
| 528 | + target_encoder = encoders.TargetEncoder(cols=["cat_feature"], min_samples_leaf=2, smoothing=2, hierarchy=hierarchical_map) |
| 529 | + encoder_values = target_encoder.fit_transform(df[features], df[target])["cat_feature"].values |
| 530 | + |
| 531 | + # Get the encoder values from a ColumnTransformer |
| 532 | + preprocessor = ColumnTransformer( |
| 533 | + transformers=[ |
| 534 | + ("categorical", encoders.TargetEncoder(cols=["cat_feature"], min_samples_leaf=2, smoothing=2, hierarchy=hierarchical_map), ["cat_feature"]), |
| 535 | + ], |
| 536 | + remainder='passthrough', |
| 537 | + verbose_feature_names_out=False, |
| 538 | + ).set_output(transform="pandas") |
| 539 | + columntransformer_encoder_values = preprocessor.fit_transform(df[features], df[target])["cat_feature"].values |
| 540 | + |
| 541 | + # Compare the results |
| 542 | + np.testing.assert_array_almost_equal(encoder_values, columntransformer_encoder_values, decimal=4) |
0 commit comments