44import numpy as np
55import pandas as pd
66import category_encoders .utils as util
7+ from category_encoders .ordinal import OrdinalEncoder
78
89from copy import copy
910from sklearn .base import BaseEstimator , TransformerMixin
1011
1112
1213__author__ = 'joshua t. dunn'
1314
14- # COUNT_ENCODER BRANCH
15+
1516class CountEncoder (BaseEstimator , TransformerMixin ):
17+
1618 def __init__ (self , verbose = 0 , cols = None , drop_invariant = False ,
1719 return_df = True , handle_unknown = 'value' ,
1820 handle_missing = 'value' ,
@@ -118,6 +120,7 @@ def __init__(self, verbose=0, cols=None, drop_invariant=False,
118120 self .min_group_name = min_group_name
119121 self .combine_min_nan_groups = combine_min_nan_groups
120122 self .feature_names = None
123+ self .ordinal_encoder = None
121124
122125 self ._check_set_create_attrs ()
123126
@@ -157,9 +160,17 @@ def fit(self, X, y=None, **kwargs):
157160 else :
158161 self .cols = util .convert_cols_to_list (self .cols )
159162
163+ self .ordinal_encoder = OrdinalEncoder (
164+ verbose = self .verbose ,
165+ cols = self .cols ,
166+ handle_unknown = 'value' ,
167+ handle_missing = 'value'
168+ )
169+ self .ordinal_encoder = self .ordinal_encoder .fit (X )
170+ X_ordinal = self .ordinal_encoder .transform (X )
160171 self ._check_set_create_dict_attrs ()
161172
162- self ._fit_count_encode (X , y )
173+ self ._fit_count_encode (X_ordinal , y )
163174
164175 X_temp = self .transform (X , override_return_df = True )
165176 self .feature_names = list (X_temp .columns )
@@ -236,28 +247,11 @@ def _fit_count_encode(self, X_in, y):
236247 self .mapping = {}
237248
238249 for col in self .cols :
239- if X [col ].isnull ().any ():
240- if self ._handle_missing [col ] == 'error' :
241- raise ValueError (
242- 'Missing data found in column %s at fit time.'
243- % (col ,)
244- )
245-
246- elif self ._handle_missing [col ] not in ['value' , 'return_nan' , 'error' , None ]:
247- raise ValueError (
248- '%s key in `handle_missing` should be one of: '
249- ' `value`, `return_nan` and `error` not `%s`.'
250- % (col , str (self ._handle_missing [col ]))
251- )
252-
253- self .mapping [col ] = X [col ].value_counts (
254- normalize = self ._normalize [col ],
255- dropna = False
256- )
257-
258- self .mapping [col ].index = self .mapping [col ].index .astype (object )
259-
260-
250+ mapping_values = X [col ].value_counts (normalize = self ._normalize [col ])
251+ ordinal_encoding = [m ["mapping" ] for m in self .ordinal_encoder .mapping if m ["col" ] == col ][0 ]
252+ reversed_ordinal_enc = {v : k for k , v in ordinal_encoding .to_dict ().items ()}
253+ mapping_values .index = mapping_values .index .map (reversed_ordinal_enc )
254+ self .mapping [col ] = mapping_values
261255
262256 if self ._handle_missing [col ] == 'return_nan' :
263257 self .mapping [col ][np .NaN ] = np .NaN
@@ -273,15 +267,15 @@ def _transform_count_encode(self, X_in, y):
273267 X = X_in .copy (deep = True )
274268
275269 for col in self .cols :
276-
277- X [col ] = X .fillna (value = np .nan )[col ]
270+ # Treat None as np.nan
271+ X [col ] = pd .Series ([el if el is not None else np .NaN for el in X [col ]], index = X [col ].index )
272+ if self .handle_missing == "value" :
273+ if not util .is_category (X [col ].dtype ):
274+ X [col ] = X [col ].fillna (np .nan )
278275
279276 if self ._min_group_size is not None :
280277 if col in self ._min_group_categories .keys ():
281- X [col ] = (
282- X [col ].map (self ._min_group_categories [col ])
283- .fillna (X [col ])
284- )
278+ X [col ] = X [col ].map (self ._min_group_categories [col ]).fillna (X [col ])
285279
286280 X [col ] = X [col ].astype (object ).map (self .mapping [col ])
287281 if isinstance (self ._handle_unknown [col ], (int , np .integer )):
0 commit comments