Skip to content

Commit a0d4748

Browse files
Merge pull request #373 from nercisla/target_encoding_heirarchical_columnwise
Target encoding heirarchical columnwise
2 parents 81bb01d + 41217ce commit a0d4748

File tree

8 files changed

+300
-20
lines changed

8 files changed

+300
-20
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from ._base import load_compass
2+
from ._base import load_postcodes
3+
4+
__all__ = [
5+
"load_compass",
6+
"load_postcodes",
7+
]
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
"""
2+
Base IO code for datasets
3+
"""
4+
5+
import pkg_resources
6+
import pandas as pd
7+
8+
def load_compass():
9+
"""Return a dataframe for target encoding with 16 rows of compass directions.
10+
11+
Contains the following fields:
12+
index 16 non-null int64
13+
compass 16 non-null object
14+
HIER_compass_1 16 non-null object
15+
target 16 non-null int64
16+
17+
Returns
18+
-------
19+
X: A pandas series containing features
20+
y: A pandas series containing the target variable
21+
22+
"""
23+
data_filename = "data/compass.csv"
24+
stream = pkg_resources.resource_filename(__name__, data_filename)
25+
26+
with open(stream) as f:
27+
df = pd.read_csv(f, encoding='latin-1')
28+
X = df[['index', 'compass', 'HIER_compass_1']]
29+
y = df['target']
30+
return X, y
31+
32+
33+
def load_postcodes(target_type='binary'):
34+
"""Return a dataframe for target encoding with 100 UK postcodes and hierarchy.
35+
36+
Contains the following fields:
37+
index 100 non-null int64
38+
postcode 100 non-null object
39+
HIER_postcode1 100 non-null object
40+
HIER_postcode2 100 non-null object
41+
HIER_postcode3 100 non-null object
42+
HIER_postcode4 100 non-null object
43+
target_binary 100 non-null int64
44+
target_non_binary 100 non-null int64
45+
target_categorical 100 non-null object
46+
47+
Parameters
48+
----------
49+
target_type : str, default='binary'
50+
Options are 'binary', 'non_binary', 'categorical'
51+
52+
Returns
53+
-------
54+
X: A pandas series containing features
55+
y: A pandas series containing the target variable
56+
57+
"""
58+
data_filename = "data/postcode_dataset_100.csv"
59+
stream = pkg_resources.resource_filename(__name__, data_filename)
60+
61+
with open(stream) as f:
62+
df = pd.read_csv(f, encoding='latin-1')
63+
X = df[df.columns[~df.columns.str.startswith('target')]]
64+
y = df[f'target_{target_type}']
65+
return X, y
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
index,compass,HIER_compass_1,target
2+
1,N,N,1
3+
2,N,N,0
4+
3,NE,N,1
5+
4,NE,N,1
6+
5,NE,N,1
7+
6,SE,S,0
8+
7,SE,S,0
9+
8,S,S,1
10+
9,S,S,0
11+
10,S,S,1
12+
11,S,S,0
13+
12,W,W,1
14+
13,W,W,0
15+
14,W,W,0
16+
15,W,W,0
17+
16,W,W,1
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
index,postcode,HIER_postcode_1,HIER_postcode_2,HIER_postcode_3,HIER_postcode_4,target_binary,target_non_binary,target_categorical
2+
0,S4 2FQ,S,S4,S4 2,S4 2F,1,1,dog
3+
1,S9 2UE,S,S9,S9 2,S9 2U,1,1,dog
4+
2,S4 6UA,S,S4,S4 6,S4 6U,1,1,dog
5+
3,S6 3SA,S,S6,S6 3,S6 3S,1,1,dog
6+
4,S8 3PT,S,S8,S8 3,S8 3P,1,1,dog
7+
5,L1 7FW,L,L1,L1 7,L1 7F,1,1,dog
8+
6,L1 3SF,L,L1,L1 3,L1 3S,1,1,dog
9+
7,L6 8HJ,L,L6,L6 8,L6 8H,1,1,dog
10+
8,L4 2UD,L,L4,L4 2,L4 2U,1,1,dog
11+
9,L4 6PH,L,L4,L4 6,L4 6P,1,1,dog
12+
10,S3 6JE,S,S3,S3 6,S3 6J,1,1,dog
13+
11,L2 3AJ,L,L2,L2 3,L2 3A,1,1,dog
14+
12,L9 5NB,L,L9,L9 5,L9 5N,1,1,dog
15+
13,S4 3DE,S,S4,S4 3,S4 3D,1,1,dog
16+
14,S3 7ZU,S,S3,S3 7,S3 7Z,1,1,dog
17+
15,S3 1WG,S,S3,S3 1,S3 1W,1,1,dog
18+
16,S4 8WP,S,S4,S4 8,S4 8W,1,1,dog
19+
17,L1 7FR,L,L1,L1 7,L1 7F,1,1,dog
20+
18,L8 0AR,L,L8,L8 0,L8 0A,1,2,cat
21+
19,S9 0SQ,S,S9,S9 0,S9 0S,1,2,cat
22+
20,S8 6PW,S,S8,S8 6,S8 6P,1,2,cat
23+
21,L6 7QS,L,L6,L6 7,L6 7Q,1,2,cat
24+
22,L5 1LH,L,L5,L5 1,L5 1L,1,2,cat
25+
23,L6 0SU,L,L6,L6 0,L6 0S,1,2,cat
26+
24,S2 1ZF,S,S2,S2 1,S2 1Z,1,2,cat
27+
25,S3 8AR,S,S3,S3 8,S3 8A,1,2,cat
28+
26,L3 9NU,L,L3,L3 9,L3 9N,1,2,cat
29+
27,S5 4LZ,S,S5,S5 4,S5 4L,1,2,cat
30+
28,L4 2HW,L,L4,L4 2,L4 2H,1,2,cat
31+
29,S9 2ES,S,S9,S9 2,S9 2E,1,2,cat
32+
30,L3 5LZ,L,L3,L3 5,L3 5L,1,3,mouse
33+
31,L5 6QA,L,L5,L5 6,L5 6Q,1,3,mouse
34+
32,S8 6ZB,S,S8,S8 6,S8 6Z,1,3,mouse
35+
33,S4 8RU,S,S4,S4 8,S4 8R,1,3,mouse
36+
34,L7 9NG,L,L7,L7 9,L7 9N,1,3,mouse
37+
35,L5 1QX,L,L5,L5 1,L5 1Q,1,3,mouse
38+
36,S2 7JN,S,S2,S2 7,S2 7J,1,3,mouse
39+
37,S5 3SP,S,S5,S5 3,S5 3S,1,3,mouse
40+
38,S5 1UL,S,S5,S5 1,S5 1U,1,3,mouse
41+
39,L5 6UN,L,L5,L5 6,L5 6U,1,3,mouse
42+
40,S7 9PL,S,S7,S7 9,S7 9P,1,3,mouse
43+
41,S5 1AB,S,S5,S5 1,S5 1A,1,3,mouse
44+
42,S6 3TB,S,S6,S6 3,S6 3T,1,3,mouse
45+
43,S8 2UZ,S,S8,S8 2,S8 2U,1,3,mouse
46+
44,S5 9GU,S,S5,S5 9,S5 9G,1,3,mouse
47+
45,L1 1DN,L,L1,L1 1,L1 1D,1,3,mouse
48+
46,L6 8YZ,L,L6,L6 8,L6 8Y,1,3,mouse
49+
47,S6 6GB,S,S6,S6 6,S6 6G,0,3,mouse
50+
48,L3 7BD,L,L3,L3 7,L3 7B,0,3,mouse
51+
49,L1 2JF,L,L1,L1 2,L1 2J,0,3,mouse
52+
50,L3 8GP,L,L3,L3 8,L3 8G,0,3,mouse
53+
51,S2 4PD,S,S2,S2 4,S2 4P,0,3,mouse
54+
52,L5 2XY,L,L5,L5 2,L5 2X,0,3,mouse
55+
53,L4 4DF,L,L4,L4 4,L4 4D,0,3,mouse
56+
54,S6 0QZ,S,S6,S6 0,S6 0Q,0,3,mouse
57+
55,S9 4DA,S,S9,S9 4,S9 4D,0,3,mouse
58+
56,L4 1RZ,L,L4,L4 1,L4 1R,0,3,mouse
59+
57,L1 1YX,L,L1,L1 1,L1 1Y,0,3,mouse
60+
58,L4 8JF,L,L4,L4 8,L4 8J,0,3,mouse
61+
59,L1 9SY,L,L1,L1 9,L1 9S,0,3,mouse
62+
60,S7 1DS,S,S7,S7 1,S7 1D,0,3,mouse
63+
61,S2 3SB,S,S2,S2 3,S2 3S,0,3,mouse
64+
62,S5 5EY,S,S5,S5 5,S5 5E,0,3,mouse
65+
63,L3 6SP,L,L3,L3 6,L3 6S,0,3,mouse
66+
64,S6 9LE,S,S6,S6 9,S6 9L,0,3,mouse
67+
65,S7 6GE,S,S7,S7 6,S7 6G,0,3,mouse
68+
66,S3 2XQ,S,S3,S3 2,S3 2X,0,4,rabbit
69+
67,L6 7RD,L,L6,L6 7,L6 7R,0,4,rabbit
70+
68,L4 5TB,L,L4,L4 5,L4 5T,0,4,rabbit
71+
69,S1 9ZY,S,S1,S1 9,S1 9Z,0,4,rabbit
72+
70,L8 3QT,L,L8,L8 3,L8 3Q,0,4,rabbit
73+
71,S8 1SB,S,S8,S8 1,S8 1S,0,4,rabbit
74+
72,L8 8PD,L,L8,L8 8,L8 8P,0,4,rabbit
75+
73,S8 0YX,S,S8,S8 0,S8 0Y,0,4,rabbit
76+
74,S4 9QH,S,S4,S4 9,S4 9Q,0,4,rabbit
77+
75,S6 4XJ,S,S6,S6 4,S6 4X,0,4,rabbit
78+
76,L8 6YG,L,L8,L8 6,L8 6Y,0,4,rabbit
79+
77,L7 7SP,L,L7,L7 7,L7 7S,0,4,rabbit
80+
78,L5 6TW,L,L5,L5 6,L5 6T,0,4,rabbit
81+
79,S5 7YX,S,S5,S5 7,S5 7Y,0,4,rabbit
82+
80,L1 4HG,L,L1,L1 4,L1 4H,0,5,hamster
83+
81,L6 7DB,L,L6,L6 7,L6 7D,0,5,hamster
84+
82,S6 2UA,S,S6,S6 2,S6 2U,0,5,hamster
85+
83,L9 4PJ,L,L9,L9 4,L9 4P,0,5,hamster
86+
84,L4 5DF,L,L4,L4 5,L4 5D,0,5,hamster
87+
85,L7 7UY,L,L7,L7 7,L7 7U,0,5,hamster
88+
86,S1 0FX,S,S1,S1 0,S1 0F,0,5,hamster
89+
87,S7 5RY,S,S7,S7 5,S7 5R,0,5,hamster
90+
88,S8 1YS,S,S8,S8 1,S8 1Y,0,5,hamster
91+
89,S4 2HB,S,S4,S4 2,S4 2H,0,5,hamster
92+
90,S6 6AZ,S,S6,S6 6,S6 6A,0,5,hamster
93+
91,L8 5YG,L,L8,L8 5,L8 5Y,0,5,hamster
94+
92,L1 5JW,L,L1,L1 5,L1 5J,0,5,hamster
95+
93,S5 8NP,S,S5,S5 8,S5 8N,0,5,hamster
96+
94,S5 5RS,S,S5,S5 5,S5 5R,0,5,hamster
97+
95,L1 9SZ,L,L1,L1 9,L1 9S,0,5,hamster
98+
96,L2 7ZH,L,L2,L2 7,L2 7Z,0,5,hamster
99+
97,L2 4RR,L,L2,L2 4,L2 4R,0,5,hamster
100+
98,S8 3EP,S,S8,S8 3,S8 3E,0,5,hamster
101+
99,L4 6ND,L,L4,L4 6,L4 6N,0,5,hamster

category_encoders/target_encoder.py

Lines changed: 52 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,17 @@ class TargetEncoder(util.BaseEncoder, util.SupervisedTransformerMixin):
4141
smoothing: float
4242
smoothing effect to balance categorical average vs prior. Higher value means stronger regularization.
4343
The value must be strictly bigger than 0. Higher values mean a flatter S-curve (see min_samples_leaf).
44-
hierarchy: dict
45-
a dictionary of columns to map into hierarchies. Dictionary key(s) should be the column name from X
44+
hierarchy: dict or dataframe
45+
A dictionary or a dataframe to define the hierarchy for mapping.
46+
47+
If a dictionary, this contains a dict of columns to map into hierarchies. Dictionary key(s) should be the column name from X
4648
which requires mapping. For multiple hierarchical maps, this should be a dictionary of dictionaries.
49+
50+
If dataframe: a dataframe defining columns to be used for the hierarchies. Column names must take the form:
51+
HIER_colA_1, ... HIER_colA_N, HIER_colB_1, ... HIER_colB_M, ...
52+
where [colA, colB, ...] are given columns in cols list.
53+
1:N and 1:M define the hierarchy for each column where 1 is the highest hierarchy (top of the tree). A single column or multiple
54+
can be used, as relevant.
4755
4856
Examples
4957
-------
@@ -75,16 +83,24 @@ class TargetEncoder(util.BaseEncoder, util.SupervisedTransformerMixin):
7583
dtypes: float64(13)
7684
memory usage: 51.5 KB
7785
None
78-
79-
>>> X = ['N', 'N', 'NE', 'NE', 'NE', 'SE', 'SE', 'S', 'S', 'S', 'S', 'W', 'W', 'W', 'W', 'W']
80-
>>> hierarchical_map = {'Compass': {'N': ('N', 'NE'), 'S': ('S', 'SE'), 'W': 'W'}}
81-
>>> y = [1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1]
82-
>>> enc = TargetEncoder(hierarchy=hierarchical_map).fit(X, y)
83-
>>> hierarchy_dataset = enc.transform(X)
84-
>>> print(hierarchy_dataset[0].values)
85-
[0.5 0.5 0.94039854 0.94039854 0.94039854 0.13447071
86-
0.13447071 0.5 0.5 0.5 0.5 0.40179862
87-
0.40179862 0.40179862 0.40179862 0.40179862]
86+
87+
>>> from category_encoders.datasets import load_compass
88+
>>> X, y = load_compass()
89+
>>> hierarchical_map = {'compass': {'N': ('N', 'NE'), 'S': ('S', 'SE'), 'W': 'W'}}
90+
>>> enc = TargetEncoder(verbose=1, smoothing=2, min_samples_leaf=2, hierarchy=hierarchical_map, cols=['compass']).fit(X.loc[:,['compass']], y)
91+
>>> hierarchy_dataset = enc.transform(X.loc[:,['compass']])
92+
>>> print(hierarchy_dataset['compass'].values)
93+
[0.62263617 0.62263617 0.90382995 0.90382995 0.90382995 0.17660024
94+
0.17660024 0.46051953 0.46051953 0.46051953 0.46051953 0.40332791
95+
0.40332791 0.40332791 0.40332791 0.40332791]
96+
>>> X, y = load_postcodes('binary')
97+
>>> cols = ['postcode']
98+
>>> HIER_cols = ['HIER_postcode_1','HIER_postcode_2','HIER_postcode_3','HIER_postcode_4']
99+
>>> enc = TargetEncoder(verbose=1, smoothing=2, min_samples_leaf=2, hierarchy=X[HIER_cols], cols=['postcode']).fit(X['postcode'], y)
100+
>>> hierarchy_dataset = enc.transform(X['postcode'])
101+
>>> print(hierarchy_dataset.loc[0:10, 'postcode'].values)
102+
[0.75063473 0.90208756 0.88328833 0.77041254 0.68891504 0.85012847
103+
0.76772574 0.88742357 0.7933824 0.63776756 0.9019973 ]
88104
89105
References
90106
----------
@@ -113,19 +129,33 @@ def __init__(self, verbose=0, cols=None, drop_invariant=False, return_df=True, h
113129
category=FutureWarning)
114130
self.mapping = None
115131
self._mean = None
116-
if hierarchy:
132+
if isinstance(hierarchy, (dict, pd.DataFrame)) and cols is None:
133+
raise ValueError('Hierarchy is defined but no columns are named for encoding')
134+
if isinstance(hierarchy, dict):
117135
self.hierarchy = {}
118136
self.hierarchy_depth = {}
119137
for switch in hierarchy:
120138
flattened_hierarchy = util.flatten_reverse_dict(hierarchy[switch])
121-
122139
hierarchy_check = self._check_dict_key_tuples(flattened_hierarchy)
123140
self.hierarchy_depth[switch] = hierarchy_check[1]
124141
if not hierarchy_check[0]:
125142
raise ValueError('Hierarchy mapping contains different levels for key "' + switch + '"')
126143
self.hierarchy[switch] = {(k if isinstance(t, tuple) else t): v for t, v in flattened_hierarchy.items() for k in t}
127-
else:
144+
elif isinstance(hierarchy, pd.DataFrame):
128145
self.hierarchy = hierarchy
146+
self.hierarchy_depth = {}
147+
for col in self.cols:
148+
HIER_cols = self.hierarchy.columns[self.hierarchy.columns.str.startswith(f'HIER_{col}')].values
149+
HIER_levels = [int(i.replace(f'HIER_{col}_', '')) for i in HIER_cols]
150+
if np.array_equal(sorted(HIER_levels), np.arange(1, max(HIER_levels)+1)):
151+
self.hierarchy_depth[col] = max(HIER_levels)
152+
else:
153+
raise ValueError(f'Hierarchy columns are not complete for column {col}')
154+
elif hierarchy is None:
155+
self.hierarchy = hierarchy
156+
else:
157+
raise ValueError('Given hierarchy mapping is neither a dictionary nor a dataframe')
158+
129159
self.cols_hier = []
130160

131161
def _check_dict_key_tuples(self, d):
@@ -134,14 +164,17 @@ def _check_dict_key_tuples(self, d):
134164
return min_tuple_size == max_tuple_size, min_tuple_size
135165

136166
def _fit(self, X, y, **kwargs):
137-
if self.hierarchy:
167+
if isinstance(self.hierarchy, dict):
138168
X_hier = pd.DataFrame()
139169
for switch in self.hierarchy:
140170
if switch in self.cols:
141171
colnames = [f'HIER_{str(switch)}_{str(i + 1)}' for i in range(self.hierarchy_depth[switch])]
142172
df = pd.DataFrame(X[str(switch)].map(self.hierarchy[str(switch)]).tolist(), index=X.index, columns=colnames)
143173
X_hier = pd.concat([X_hier, df], axis=1)
174+
elif isinstance(self.hierarchy, pd.DataFrame):
175+
X_hier = self.hierarchy
144176

177+
if isinstance(self.hierarchy, (dict, pd.DataFrame)):
145178
enc_hier = OrdinalEncoder(
146179
verbose=self.verbose,
147180
cols=X_hier.columns,
@@ -159,7 +192,7 @@ def _fit(self, X, y, **kwargs):
159192
)
160193
self.ordinal_encoder = self.ordinal_encoder.fit(X)
161194
X_ordinal = self.ordinal_encoder.transform(X)
162-
if self.hierarchy:
195+
if self.hierarchy is not None:
163196
self.mapping = self.fit_target_encoding(pd.concat([X_ordinal, X_hier_ordinal], axis=1), y)
164197
else:
165198
self.mapping = self.fit_target_encoding(X_ordinal, y)
@@ -174,7 +207,8 @@ def fit_target_encoding(self, X, y):
174207
values = switch.get('mapping')
175208

176209
scalar = prior
177-
if self.hierarchy and col in self.hierarchy:
210+
if (isinstance(self.hierarchy, dict) and col in self.hierarchy) or \
211+
(isinstance(self.hierarchy, pd.DataFrame)):
178212
for i in range(self.hierarchy_depth[col]):
179213
col_hier = 'HIER_'+str(col)+'_'+str(i+1)
180214
col_hier_m1 = col if i == self.hierarchy_depth[col]-1 else 'HIER_'+str(col)+'_'+str(i+2)

category_encoders/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,9 @@ def fit(self, X, y=None, **kwargs):
295295
self._dim = X.shape[1]
296296
self._get_fit_columns(X)
297297

298+
if not set(self.cols).issubset(X.columns):
299+
raise ValueError('X does not contain the columns listed in cols')
300+
298301
if self.handle_missing == 'error':
299302
if X[self.cols].isnull().any().any():
300303
raise ValueError('Columns to be encoded can not contain null')

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,5 +36,6 @@
3636
'pandas>=1.0.5',
3737
'patsy>=0.5.1',
3838
],
39-
author_email='[email protected]'
39+
author_email='[email protected]',
40+
package_data={'': ['datasets/data/*.csv']},
4041
)

0 commit comments

Comments
 (0)