44import pandas as pd
55import category_encoders .utils as util
66import warnings
7+ from typing import Dict , List , Union
78
89__author__ = 'willmcginnis'
910
@@ -30,7 +31,7 @@ class OrdinalEncoder(util.BaseEncoder, util.UnsupervisedTransformerMixin):
3031 a mapping of class to label to use for the encoding, optional.
3132 the dict contains the keys 'col' and 'mapping'.
3233 the value of 'col' should be the feature name.
33- the value of 'mapping' should be a dictionary of 'original_label' to 'encoded_label'.
34+ the value of 'mapping' should be a dictionary or pd.Series of 'original_label' to 'encoded_label'.
3435 example mapping: [
3536 {'col': 'col1', 'mapping': {None: 0, 'a': 1, 'b': 2}},
3637 {'col': 'col2', 'mapping': {None: 0, 'x': 1, 'y': 2}}
@@ -87,6 +88,8 @@ def __init__(self, verbose=0, mapping=None, cols=None, drop_invariant=False, ret
8788 super ().__init__ (verbose = verbose , cols = cols , drop_invariant = drop_invariant , return_df = return_df ,
8889 handle_unknown = handle_unknown , handle_missing = handle_missing )
8990 self .mapping_supplied = mapping is not None
91+ if self .mapping_supplied :
92+ mapping = self ._validate_supplied_mapping (mapping )
9093 self .mapping = mapping
9194
9295 @property
@@ -237,3 +240,28 @@ def ordinal_encoding(X_in, mapping=None, cols=None, handle_unknown='value', hand
237240 mapping_out .append ({'col' : col , 'mapping' : data , 'data_type' : X [col ].dtype }, )
238241
239242 return X , mapping_out
243+
244+ def _validate_supplied_mapping (self , supplied_mapping : List [Dict [str , Union [Dict , pd .Series ]]]) -> List [Dict [str , pd .Series ]]:
245+ """
246+ validate the supplied mapping and convert the actual mapping per column to a pandas series.
247+ :param supplied_mapping: mapping as list of dicts. They actual mapping can be either a dict or pd.Series
248+ :return: the mapping with all actual mappings being pandas series
249+ """
250+ msg = "Invalid supplied mapping, must be of type List[Dict[str, Union[Dict, pd.Series]]]." \
251+ "For an example refer to the documentation"
252+ if not isinstance (supplied_mapping , list ):
253+ raise ValueError (msg )
254+ for mapping_el in supplied_mapping :
255+ if not isinstance (mapping_el , dict ):
256+ raise ValueError (msg )
257+ if "col" not in mapping_el :
258+ raise KeyError ("Mapping must contain a key 'col' for each column to encode" )
259+ if "mapping" not in mapping_el :
260+ raise KeyError ("Mapping must contain a key 'mapping' for each column to encode" )
261+ mapping = mapping_el ["mapping" ]
262+ if isinstance (mapping_el , dict ):
263+ # convert to dict in order to standardise
264+ mapping_el ["mapping" ] = pd .Series (mapping )
265+ if "data_type" not in mapping_el :
266+ mapping_el ["data_type" ] = mapping_el ["mapping" ].index .dtype
267+ return supplied_mapping
0 commit comments