Skip to content

Commit 9a8e11b

Browse files
fix issue #202 broken inverse transform in ordinal encoder
1 parent 87a9377 commit 9a8e11b

File tree

4 files changed

+202
-104
lines changed

4 files changed

+202
-104
lines changed

category_encoders/ordinal.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pandas as pd
55
import category_encoders.utils as util
66
import warnings
7+
from typing import Dict, List, Union
78

89
__author__ = 'willmcginnis'
910

@@ -30,7 +31,7 @@ class OrdinalEncoder(util.BaseEncoder, util.UnsupervisedTransformerMixin):
3031
a mapping of class to label to use for the encoding, optional.
3132
the dict contains the keys 'col' and 'mapping'.
3233
the value of 'col' should be the feature name.
33-
the value of 'mapping' should be a dictionary of 'original_label' to 'encoded_label'.
34+
the value of 'mapping' should be a dictionary or pd.Series of 'original_label' to 'encoded_label'.
3435
example mapping: [
3536
{'col': 'col1', 'mapping': {None: 0, 'a': 1, 'b': 2}},
3637
{'col': 'col2', 'mapping': {None: 0, 'x': 1, 'y': 2}}
@@ -87,6 +88,8 @@ def __init__(self, verbose=0, mapping=None, cols=None, drop_invariant=False, ret
8788
super().__init__(verbose=verbose, cols=cols, drop_invariant=drop_invariant, return_df=return_df,
8889
handle_unknown=handle_unknown, handle_missing=handle_missing)
8990
self.mapping_supplied = mapping is not None
91+
if self.mapping_supplied:
92+
mapping = self._validate_supplied_mapping(mapping)
9093
self.mapping = mapping
9194

9295
@property
@@ -237,3 +240,28 @@ def ordinal_encoding(X_in, mapping=None, cols=None, handle_unknown='value', hand
237240
mapping_out.append({'col': col, 'mapping': data, 'data_type': X[col].dtype}, )
238241

239242
return X, mapping_out
243+
244+
def _validate_supplied_mapping(self, supplied_mapping: List[Dict[str, Union[Dict, pd.Series]]]) -> List[Dict[str, pd.Series]]:
    """
    validate the supplied mapping and convert the actual mapping per column to a pandas series.

    The supplied list and its dictionaries are updated in place; the very same list object is returned.

    :param supplied_mapping: mapping as list of dicts. The actual mapping can be either a dict or pd.Series
    :return: the mapping with all actual mappings being pandas series
    :raises ValueError: if the supplied mapping is not a list of dicts, or if the value stored
        under 'mapping' is neither a dict nor a pd.Series
    :raises KeyError: if an element lacks the key 'col' or the key 'mapping'
    """
    msg = "Invalid supplied mapping, must be of type List[Dict[str, Union[Dict, pd.Series]]]. " \
          "For an example refer to the documentation"
    if not isinstance(supplied_mapping, list):
        raise ValueError(msg)
    for mapping_el in supplied_mapping:
        if not isinstance(mapping_el, dict):
            raise ValueError(msg)
        if "col" not in mapping_el:
            raise KeyError("Mapping must contain a key 'col' for each column to encode")
        if "mapping" not in mapping_el:
            raise KeyError("Mapping must contain a key 'mapping' for each column to encode")
        mapping = mapping_el["mapping"]
        # Inspect the per-column mapping itself; the enclosing element is already known to be a dict.
        if isinstance(mapping, dict):
            # convert to pd.Series in order to standardise
            mapping_el["mapping"] = pd.Series(mapping)
        elif not isinstance(mapping, pd.Series):
            # anything else (list, scalar, None, ...) cannot express original_label -> encoded_label
            raise ValueError(msg)
        if "data_type" not in mapping_el:
            # default the original data type to the dtype of the labels (the series index)
            mapping_el["data_type"] = mapping_el["mapping"].index.dtype
    return supplied_mapping

category_encoders/rankhot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def inverse_transform(self, X_in):
195195
orig_dtype = ordinal_mapping.get("data_type")
196196
reencode2 = reencode.replace(inv_map).astype(orig_dtype)
197197
if np.any(reencode2[:] == 0):
198-
reencode2[reencode2[:] == 0] = "None"
198+
reencode2[reencode2[:] == 0] = np.nan
199199

200200
X = self.create_dataframe(X, reencode2, col)
201201

0 commit comments

Comments
 (0)