Skip to content

Commit 370e4b8

Browse files
Merge pull request #461 from dennisobrien/issue-460-fix_targetencoder
Fix TargetEncoder compatibility with scikit-learn ColumnTransformer
2 parents 0289cab + 1e65143 commit 370e4b8

File tree

2 files changed

+86
-45
lines changed

2 files changed

+86
-45
lines changed

category_encoders/target_encoder.py

Lines changed: 51 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -159,48 +159,14 @@ def __init__(
159159
handle_unknown=handle_unknown,
160160
handle_missing=handle_missing,
161161
)
162-
self.ordinal_encoder = None
163162
self.min_samples_leaf = min_samples_leaf
164163
self.smoothing = smoothing
164+
self.hierarchy = hierarchy
165+
self.ordinal_encoder = None
165166
self.mapping = None
166167
self._mean = None
167-
# @ToDo create a function to check the hierarchy
168-
if isinstance(hierarchy, (dict, pd.DataFrame)) and cols is None:
169-
raise ValueError('Hierarchy is defined but no columns are named for encoding')
170-
if isinstance(hierarchy, dict):
171-
self.hierarchy = {}
172-
self.hierarchy_depth = {}
173-
for switch in hierarchy:
174-
flattened_hierarchy = util.flatten_reverse_dict(hierarchy[switch])
175-
hierarchy_check = self._check_dict_key_tuples(flattened_hierarchy)
176-
self.hierarchy_depth[switch] = hierarchy_check[1]
177-
if not hierarchy_check[0]:
178-
raise ValueError(
179-
'Hierarchy mapping contains different levels for key "' + switch + '"'
180-
)
181-
self.hierarchy[switch] = {
182-
(k if isinstance(t, tuple) else t): v
183-
for t, v in flattened_hierarchy.items()
184-
for k in t
185-
}
186-
elif isinstance(hierarchy, pd.DataFrame):
187-
self.hierarchy = hierarchy
188-
self.hierarchy_depth = {}
189-
for col in self.cols:
190-
HIER_cols = self.hierarchy.columns[
191-
self.hierarchy.columns.str.startswith(f'HIER_{col}')
192-
].tolist()
193-
HIER_levels = [int(i.replace(f'HIER_{col}_', '')) for i in HIER_cols]
194-
if np.array_equal(sorted(HIER_levels), np.arange(1, max(HIER_levels) + 1)):
195-
self.hierarchy_depth[col] = max(HIER_levels)
196-
else:
197-
raise ValueError(f'Hierarchy columns are not complete for column {col}')
198-
elif hierarchy is None:
199-
self.hierarchy = hierarchy
200-
else:
201-
raise ValueError('Given hierarchy mapping is neither a dictionary nor a dataframe')
202-
203-
self.cols_hier = []
168+
# Call this in the constructor only for the possible side effect of raising an exception.
169+
self._generate_inverted_hierarchy()
204170

205171
@staticmethod
206172
def _check_dict_key_tuples(dict_to_check: dict[Any, tuple]) -> tuple[bool, int]:
@@ -219,24 +185,25 @@ def _check_dict_key_tuples(dict_to_check: dict[Any, tuple]) -> tuple[bool, int]:
219185
return min_tuple_size == max_tuple_size, min_tuple_size
220186

221187
def _fit(self, X: util.X_type, y: util.y_type, **kwargs) -> None:
222-
if isinstance(self.hierarchy, dict):
188+
inverted_hierarchy, self.hierarchy_depth = self._generate_inverted_hierarchy()
189+
if isinstance(inverted_hierarchy, dict):
223190
X_hier = pd.DataFrame()
224-
for switch in self.hierarchy:
191+
for switch in inverted_hierarchy:
225192
if switch in self.cols:
226193
colnames = [
227194
f'HIER_{str(switch)}_{str(i + 1)}'
228195
for i in range(self.hierarchy_depth[switch])
229196
]
230197
df = pd.DataFrame(
231-
X[str(switch)].map(self.hierarchy[str(switch)]).tolist(),
198+
X[str(switch)].map(inverted_hierarchy[str(switch)]).tolist(),
232199
index=X.index,
233200
columns=colnames,
234201
)
235202
X_hier = pd.concat([X_hier, df], axis=1)
236-
elif isinstance(self.hierarchy, pd.DataFrame):
237-
X_hier = self.hierarchy
203+
elif isinstance(inverted_hierarchy, pd.DataFrame):
204+
X_hier = inverted_hierarchy
238205

239-
if isinstance(self.hierarchy, (dict, pd.DataFrame)):
206+
if isinstance(inverted_hierarchy, (dict, pd.DataFrame)):
240207
enc_hier = OrdinalEncoder(
241208
verbose=self.verbose,
242209
cols=X_hier.columns,
@@ -251,7 +218,7 @@ def _fit(self, X: util.X_type, y: util.y_type, **kwargs) -> None:
251218
)
252219
self.ordinal_encoder = self.ordinal_encoder.fit(X)
253220
X_ordinal = self.ordinal_encoder.transform(X)
254-
if self.hierarchy is not None:
221+
if inverted_hierarchy is not None:
255222
self.mapping = self.fit_target_encoding(
256223
pd.concat([X_ordinal, X_hier_ordinal], axis=1), y
257224
)
@@ -344,3 +311,42 @@ def _weighting(self, n: int) -> float:
344311
# monotonically increasing function of n bounded between 0 and 1
345312
# sigmoid in this case, using scipy.expit for numerical stability
346313
return expit((n - self.min_samples_leaf) / self.smoothing)
314+
315+
def _generate_inverted_hierarchy(self) -> tuple[dict | pd.DataFrame | None, dict]:
316+
if isinstance(self.hierarchy, (dict, pd.DataFrame)) and self.cols is None:
317+
raise ValueError('Hierarchy is defined but no columns are named for encoding')
318+
if isinstance(self.hierarchy, dict):
319+
inverted_hierarchy = {}
320+
hierarchy_depth = {}
321+
for switch in self.hierarchy:
322+
flattened_hierarchy = util.flatten_reverse_dict(self.hierarchy[switch])
323+
hierarchy_check = self._check_dict_key_tuples(flattened_hierarchy)
324+
hierarchy_depth[switch] = hierarchy_check[1]
325+
if not hierarchy_check[0]:
326+
raise ValueError(
327+
'Hierarchy mapping contains different levels for key "' + switch + '"'
328+
)
329+
inverted_hierarchy[switch] = {
330+
(k if isinstance(t, tuple) else t): v
331+
for t, v in flattened_hierarchy.items()
332+
for k in t
333+
}
334+
elif isinstance(self.hierarchy, pd.DataFrame):
335+
inverted_hierarchy = self.hierarchy
336+
hierarchy_depth = {}
337+
for col in self.cols:
338+
HIER_cols = inverted_hierarchy.columns[
339+
inverted_hierarchy.columns.str.startswith(f'HIER_{col}')
340+
].tolist()
341+
HIER_levels = [int(i.replace(f'HIER_{col}_', '')) for i in HIER_cols]
342+
if np.array_equal(sorted(HIER_levels), np.arange(1, max(HIER_levels) + 1)):
343+
hierarchy_depth[col] = max(HIER_levels)
344+
else:
345+
raise ValueError(f'Hierarchy columns are not complete for column {col}')
346+
elif self.hierarchy is None:
347+
inverted_hierarchy = None
348+
hierarchy_depth = {}
349+
else:
350+
raise ValueError('Given hierarchy mapping is neither a dictionary nor a dataframe')
351+
return inverted_hierarchy, hierarchy_depth
352+

tests/test_target_encoder.py

Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66
import pandas as pd
77
from category_encoders.datasets import load_compass, load_postcodes
8+
from sklearn.compose import ColumnTransformer
89

910
import tests.helpers as th
1011

@@ -505,3 +506,37 @@ def test_hierarchy_mapping_cols_missing(self):
505506
)
506507
with self.assertRaises(ValueError):
507508
enc.fit_transform(X, y)
509+
510+
def test_hierarchy_with_scikit_learn_column_transformer(self) -> None:
    """A hierarchical TargetEncoder must encode identically inside a ColumnTransformer."""
    frame = pd.DataFrame({
        "cat_feature": ["aa", "ab", "ac", "ba", "bb", "bc"],
        "int_feature": [10, 11, 9, 8, 12, 10],
        "target": [1, 2, 1, 2, 1, 2]
    })
    feature_names = ["cat_feature", "int_feature"]
    target_name = "target"
    hierarchy = {
        "cat_feature": {
            "a": ("aa", "ab", "ac"),
            "b": ("ba", "bb", "bc"),
        },
    }

    def build_encoder():
        # Both code paths use an identically configured, freshly built encoder.
        return encoders.TargetEncoder(
            cols=["cat_feature"],
            min_samples_leaf=2,
            smoothing=2,
            hierarchy=hierarchy,
        )

    # Reference result: the encoder applied directly to the feature frame.
    direct_output = build_encoder().fit_transform(frame[feature_names], frame[target_name])
    expected = direct_output["cat_feature"].values

    # Same encoder wrapped in a ColumnTransformer that emits pandas output.
    column_transformer = ColumnTransformer(
        transformers=[("categorical", build_encoder(), ["cat_feature"])],
        remainder='passthrough',
        verbose_feature_names_out=False,
    )
    column_transformer = column_transformer.set_output(transform="pandas")
    wrapped_output = column_transformer.fit_transform(frame[feature_names], frame[target_name])
    actual = wrapped_output["cat_feature"].values

    # Both routes must agree to four decimal places.
    np.testing.assert_array_almost_equal(expected, actual, decimal=4)

0 commit comments

Comments
 (0)