Skip to content

Commit 072217f

Browse files
committed
Merge branch 'target_encoding_heirarchical_columnwise' of https://github.com/nercisla/category_encoders into target_encoding_heirarchical_columnwise
2 parents dae530d + b77cd40 commit 072217f

File tree

3 files changed: 17 additions & 7 deletions

3 files changed: 17 additions & 7 deletions

category_encoders/leave_one_out.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -128,7 +128,7 @@ def fit_column_map(self, series, y):
128128
codes[codes == -1] = len(categories)
129129
categories = np.append(categories, np.nan)
130130

131-
return_map = pd.Series(dict([(code, category) for code, category in enumerate(categories)]))
131+
return_map = pd.Series({code: category for code, category in enumerate(categories)})
132132

133133
result = y.groupby(codes).agg(['sum', 'count'])
134134
return result.rename(return_map)

category_encoders/target_encoder.py

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -152,10 +152,7 @@ def __init__(self, verbose=0, cols=None, drop_invariant=False, return_df=True, h
152152
def _check_dict_key_tuples(self, d):
153153
min_tuple_size = min(len(v) for v in d.values())
154154
max_tuple_size = max(len(v) for v in d.values())
155-
if min_tuple_size == max_tuple_size:
156-
return True, min_tuple_size
157-
else:
158-
return False, min_tuple_size
155+
return min_tuple_size == max_tuple_size, min_tuple_size
159156

160157
def _fit(self, X, y, **kwargs):
161158
if isinstance(self.hierarchy, dict):

tests/test_target_encoder.py

Lines changed: 15 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -258,6 +258,20 @@ def test_hierarchy_error(self):
258258
encoders.TargetEncoder(verbose=1, smoothing=2, min_samples_leaf=2, hierarchy=hierarchical_map,
259259
cols=['Plant'])
260260

261+
def test_trivial_hierarchy(self):
262+
trivial_hierarchical_map = {
263+
'Plant': {
264+
'Plant': ('Rose', 'Daisy', 'Daffodil', 'Bluebell')
265+
}
266+
}
267+
268+
enc_hier = encoders.TargetEncoder(verbose=1, smoothing=2, min_samples_leaf=2, hierarchy=trivial_hierarchical_map,
269+
cols=['Plant'])
270+
result_hier = enc_hier.fit_transform(self.hierarchical_cat_example, self.hierarchical_cat_example['target'])
271+
enc_no_hier = encoders.TargetEncoder(verbose=1, smoothing=2, min_samples_leaf=2, cols=['Plant'])
272+
result_no_hier = enc_no_hier.fit_transform(self.hierarchical_cat_example, self.hierarchical_cat_example['target'])
273+
pd.testing.assert_series_equal(result_hier["Plant"], result_no_hier["Plant"])
274+
261275
def test_hierarchy_multi_level(self):
262276
hierarchy_multi_level_df = pd.DataFrame(
263277
{
@@ -291,7 +305,6 @@ def test_hierarchy_multi_level(self):
291305
self.assertAlmostEqual(0.2466, values[13], delta=1e-4)
292306
self.assertAlmostEqual(0.4741, values[14], delta=1e-4)
293307

294-
295308
def test_hierarchy_columnwise_compass(self):
296309
X, y = load_compass()
297310
cols = X.columns[~X.columns.str.startswith('HIER')]
@@ -341,4 +354,4 @@ def test_hierarchy_mapping_cols_missing(self):
341354
enc = encoders.TargetEncoder(verbose=1, smoothing=2, min_samples_leaf=2, hierarchy=hierarchical_map,
342355
cols=['Compass'])
343356
with self.assertRaises(ValueError):
344-
enc.fit_transform(X, y)
357+
enc.fit_transform(X, y)

0 commit comments

Comments
 (0)