44import numpy as np
55import pandas as pd
66import category_encoders .utils as util
7+ from category_encoders .ordinal import OrdinalEncoder
78
89from copy import copy
910from sklearn .base import BaseEstimator , TransformerMixin
1011
1112
1213__author__ = 'joshua t. dunn'
1314
14- # COUNT_ENCODER BRANCH
15+
1516class CountEncoder (BaseEstimator , TransformerMixin ):
17+
1618 def __init__ (self , verbose = 0 , cols = None , drop_invariant = False ,
1719 return_df = True , handle_unknown = 'value' ,
1820 handle_missing = 'value' ,
@@ -118,6 +120,7 @@ def __init__(self, verbose=0, cols=None, drop_invariant=False,
118120 self .min_group_name = min_group_name
119121 self .combine_min_nan_groups = combine_min_nan_groups
120122 self .feature_names = None
123+ self .ordinal_encoder = None
121124
122125 self ._check_set_create_attrs ()
123126
@@ -157,9 +160,17 @@ def fit(self, X, y=None, **kwargs):
157160 else :
158161 self .cols = util .convert_cols_to_list (self .cols )
159162
163+ self .ordinal_encoder = OrdinalEncoder (
164+ verbose = self .verbose ,
165+ cols = self .cols ,
166+ handle_unknown = 'value' ,
167+ handle_missing = 'value'
168+ )
169+ self .ordinal_encoder = self .ordinal_encoder .fit (X )
170+ X_ordinal = self .ordinal_encoder .transform (X )
160171 self ._check_set_create_dict_attrs ()
161172
162- self ._fit_count_encode (X , y )
173+ self ._fit_count_encode (X_ordinal , y )
163174
164175 X_temp = self .transform (X , override_return_df = True )
165176 self .feature_names = list (X_temp .columns )
@@ -235,28 +246,11 @@ def _fit_count_encode(self, X_in, y):
235246 self .mapping = {}
236247
237248 for col in self .cols :
238- if X [col ].isnull ().any ():
239- if self ._handle_missing [col ] == 'error' :
240- raise ValueError (
241- 'Missing data found in column %s at fit time.'
242- % (col ,)
243- )
244-
245- elif self ._handle_missing [col ] not in ['value' , 'return_nan' , 'error' , None ]:
246- raise ValueError (
247- '%s key in `handle_missing` should be one of: '
248- ' `value`, `return_nan` and `error` not `%s`.'
249- % (col , str (self ._handle_missing [col ]))
250- )
251-
252- self .mapping [col ] = X [col ].value_counts (
253- normalize = self ._normalize [col ],
254- dropna = False
255- )
256-
257- self .mapping [col ].index = self .mapping [col ].index .astype (object )
258-
259-
249+ mapping_values = X [col ].value_counts (normalize = self ._normalize [col ])
250+ ordinal_encoding = [m ["mapping" ] for m in self .ordinal_encoder .mapping if m ["col" ] == col ][0 ]
251+ reversed_ordinal_enc = {v : k for k , v in ordinal_encoding .to_dict ().items ()}
252+ mapping_values .index = mapping_values .index .map (reversed_ordinal_enc )
253+ self .mapping [col ] = mapping_values
260254
261255 if self ._handle_missing [col ] == 'return_nan' :
262256 self .mapping [col ][np .NaN ] = np .NaN
@@ -272,15 +266,15 @@ def _transform_count_encode(self, X_in, y):
272266 X = X_in .copy (deep = True )
273267
274268 for col in self .cols :
275-
276- X [col ] = X .fillna (value = np .nan )[col ]
269+ # Treat None as np.nan
270+ X [col ] = pd .Series ([el if el is not None else np .NaN for el in X [col ]], index = X [col ].index )
271+ if self .handle_missing == "value" :
272+ if not util .is_category (X [col ].dtype ):
273+ X [col ] = X [col ].fillna (np .nan )
277274
278275 if self ._min_group_size is not None :
279276 if col in self ._min_group_categories .keys ():
280- X [col ] = (
281- X [col ].map (self ._min_group_categories [col ])
282- .fillna (X [col ])
283- )
277+ X [col ] = X [col ].map (self ._min_group_categories [col ]).fillna (X [col ])
284278
285279 X [col ] = X [col ].astype (object ).map (self .mapping [col ])
286280 if isinstance (self ._handle_unknown [col ], (int , np .integer )):
0 commit comments