Skip to content

Commit 106af76

Browse files
committed
Fixed an issue with TargetEncoder when used with scikit-learn methods that clone the object and validate that the parameters are unchanged.
This moves the validation of hierarchy and the generation of an inverted hiearchy to the helper function _generate_inverted_hierarchy. This function is called in __init__ only to keep compatibility with existing tests by throwing an exception from the constructure rather than waiting for the call to fit. We create a variable inverted_hierarchy rather than overwriting self.hierarchy in order for the clone method to pass validation.
1 parent 5b3a539 commit 106af76

File tree

1 file changed

+52
-45
lines changed

1 file changed

+52
-45
lines changed

category_encoders/target_encoder.py

Lines changed: 52 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -159,48 +159,14 @@ def __init__(
159159
handle_unknown=handle_unknown,
160160
handle_missing=handle_missing,
161161
)
162-
self.ordinal_encoder = None
163162
self.min_samples_leaf = min_samples_leaf
164163
self.smoothing = smoothing
164+
self.hierarchy = hierarchy
165+
self.ordinal_encoder = None
165166
self.mapping = None
166167
self._mean = None
167-
# @ToDo create a function to check the hierarchy
168-
if isinstance(hierarchy, (dict, pd.DataFrame)) and cols is None:
169-
raise ValueError('Hierarchy is defined but no columns are named for encoding')
170-
if isinstance(hierarchy, dict):
171-
self.hierarchy = {}
172-
self.hierarchy_depth = {}
173-
for switch in hierarchy:
174-
flattened_hierarchy = util.flatten_reverse_dict(hierarchy[switch])
175-
hierarchy_check = self._check_dict_key_tuples(flattened_hierarchy)
176-
self.hierarchy_depth[switch] = hierarchy_check[1]
177-
if not hierarchy_check[0]:
178-
raise ValueError(
179-
'Hierarchy mapping contains different levels for key "' + switch + '"'
180-
)
181-
self.hierarchy[switch] = {
182-
(k if isinstance(t, tuple) else t): v
183-
for t, v in flattened_hierarchy.items()
184-
for k in t
185-
}
186-
elif isinstance(hierarchy, pd.DataFrame):
187-
self.hierarchy = hierarchy
188-
self.hierarchy_depth = {}
189-
for col in self.cols:
190-
HIER_cols = self.hierarchy.columns[
191-
self.hierarchy.columns.str.startswith(f'HIER_{col}')
192-
].tolist()
193-
HIER_levels = [int(i.replace(f'HIER_{col}_', '')) for i in HIER_cols]
194-
if np.array_equal(sorted(HIER_levels), np.arange(1, max(HIER_levels) + 1)):
195-
self.hierarchy_depth[col] = max(HIER_levels)
196-
else:
197-
raise ValueError(f'Hierarchy columns are not complete for column {col}')
198-
elif hierarchy is None:
199-
self.hierarchy = hierarchy
200-
else:
201-
raise ValueError('Given hierarchy mapping is neither a dictionary nor a dataframe')
202-
203-
self.cols_hier = []
168+
# Call this in the constructor only for the possible side effect of raising an exception.
169+
self._generate_inverted_hierarchy()
204170

205171
@staticmethod
206172
def _check_dict_key_tuples(dict_to_check: dict[Any, tuple]) -> tuple[bool, int]:
@@ -219,24 +185,25 @@ def _check_dict_key_tuples(dict_to_check: dict[Any, tuple]) -> tuple[bool, int]:
219185
return min_tuple_size == max_tuple_size, min_tuple_size
220186

221187
def _fit(self, X: util.X_type, y: util.y_type, **kwargs) -> None:
222-
if isinstance(self.hierarchy, dict):
188+
inverted_hierarchy, self.hierarchy_depth = self._generate_inverted_hierarchy()
189+
if isinstance(inverted_hierarchy, dict):
223190
X_hier = pd.DataFrame()
224-
for switch in self.hierarchy:
191+
for switch in inverted_hierarchy:
225192
if switch in self.cols:
226193
colnames = [
227194
f'HIER_{str(switch)}_{str(i + 1)}'
228195
for i in range(self.hierarchy_depth[switch])
229196
]
230197
df = pd.DataFrame(
231-
X[str(switch)].map(self.hierarchy[str(switch)]).tolist(),
198+
X[str(switch)].map(inverted_hierarchy[str(switch)]).tolist(),
232199
index=X.index,
233200
columns=colnames,
234201
)
235202
X_hier = pd.concat([X_hier, df], axis=1)
236-
elif isinstance(self.hierarchy, pd.DataFrame):
237-
X_hier = self.hierarchy
203+
elif isinstance(inverted_hierarchy, pd.DataFrame):
204+
X_hier = inverted_hierarchy
238205

239-
if isinstance(self.hierarchy, (dict, pd.DataFrame)):
206+
if isinstance(inverted_hierarchy, (dict, pd.DataFrame)):
240207
enc_hier = OrdinalEncoder(
241208
verbose=self.verbose,
242209
cols=X_hier.columns,
@@ -251,7 +218,7 @@ def _fit(self, X: util.X_type, y: util.y_type, **kwargs) -> None:
251218
)
252219
self.ordinal_encoder = self.ordinal_encoder.fit(X)
253220
X_ordinal = self.ordinal_encoder.transform(X)
254-
if self.hierarchy is not None:
221+
if inverted_hierarchy is not None:
255222
self.mapping = self.fit_target_encoding(
256223
pd.concat([X_ordinal, X_hier_ordinal], axis=1), y
257224
)
@@ -344,3 +311,43 @@ def _weighting(self, n: int) -> float:
344311
# monotonically increasing function of n bounded between 0 and 1
345312
# sigmoid in this case, using scipy.expit for numerical stability
346313
return expit((n - self.min_samples_leaf) / self.smoothing)
314+
315+
def _generate_inverted_hierarchy(self) -> tuple[dict | pd.DataFrame, dict]:
316+
# @ToDo create a function to check the hierarchy
317+
if isinstance(self.hierarchy, (dict, pd.DataFrame)) and self.cols is None:
318+
raise ValueError('Hierarchy is defined but no columns are named for encoding')
319+
if isinstance(self.hierarchy, dict):
320+
inverted_hierarchy = {}
321+
hierarchy_depth = {}
322+
for switch in self.hierarchy:
323+
flattened_hierarchy = util.flatten_reverse_dict(self.hierarchy[switch])
324+
hierarchy_check = self._check_dict_key_tuples(flattened_hierarchy)
325+
hierarchy_depth[switch] = hierarchy_check[1]
326+
if not hierarchy_check[0]:
327+
raise ValueError(
328+
'Hierarchy mapping contains different levels for key "' + switch + '"'
329+
)
330+
inverted_hierarchy[switch] = {
331+
(k if isinstance(t, tuple) else t): v
332+
for t, v in flattened_hierarchy.items()
333+
for k in t
334+
}
335+
elif isinstance(self.hierarchy, pd.DataFrame):
336+
inverted_hierarchy = self.hierarchy
337+
hierarchy_depth = {}
338+
for col in self.cols:
339+
HIER_cols = inverted_hierarchy.columns[
340+
inverted_hierarchy.columns.str.startswith(f'HIER_{col}')
341+
].tolist()
342+
HIER_levels = [int(i.replace(f'HIER_{col}_', '')) for i in HIER_cols]
343+
if np.array_equal(sorted(HIER_levels), np.arange(1, max(HIER_levels) + 1)):
344+
hierarchy_depth[col] = max(HIER_levels)
345+
else:
346+
raise ValueError(f'Hierarchy columns are not complete for column {col}')
347+
elif self.hierarchy is None:
348+
inverted_hierarchy = None
349+
hierarchy_depth = {}
350+
else:
351+
raise ValueError('Given hierarchy mapping is neither a dictionary nor a dataframe')
352+
return inverted_hierarchy, hierarchy_depth
353+

0 commit comments

Comments
 (0)