Skip to content

Commit d28b409

Browse files
added more tests
1 parent 9a8e11b commit d28b409

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

category_encoders/ordinal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def ordinal_encoding(X_in, mapping=None, cols=None, handle_unknown='value', hand
241241

242242
return X, mapping_out
243243

244-
def _validate_supplied_mapping(self, supplied_mapping: List[Dict[str, Union[Dict, pd.Series]]]) -> List[Dict[str, pd.Series]]:
244+
def _validate_supplied_mapping(self, supplied_mapping: List[Dict[str, Union[str, Dict, pd.Series]]]) -> List[Dict[str, Union[str, pd.Series]]]:
245245
"""
246246
validate the supplied mapping and convert the actual mapping per column to a pandas series.
247247
:param supplied_mapping: mapping as list of dicts. They actual mapping can be either a dict or pd.Series

tests/test_ordinal.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,27 @@ def test_inverse_with_mapping(self):
316316
return_df=True,
317317
)
318318
df[categoricals] = enc.fit_transform(df[categoricals])
319-
print("ord mapping after fit")
320-
print(enc.mapping)
321319
recovered = enc.inverse_transform(df[categoricals])
322320
pd.testing.assert_frame_equal(X[categoricals], recovered)
321+
322+
def test_validate_mapping(self):
323+
custom_mapping = [
324+
{
325+
"col": "col1",
326+
"mapping": {np.NaN: 0, "a": 1, "b": 2},
327+
}, # The mapping from the documentation
328+
{"col": "col2", "mapping": {np.NaN: -3, "x": 11, "y": 2}},
329+
]
330+
expected_valid_mapping = [
331+
{
332+
"col": "col1",
333+
"mapping": pd.Series({np.NaN: 0, "a": 1, "b": 2}),
334+
}, # The mapping from the documentation
335+
{"col": "col2", "mapping": pd.Series({np.NaN: -3, "x": 11, "y": 2})},
336+
]
337+
enc = encoders.OrdinalEncoder()
338+
actual_valid_mapping = enc._validate_supplied_mapping(custom_mapping)
339+
self.assertEqual(len(actual_valid_mapping), len(expected_valid_mapping))
340+
for idx in range(len(actual_valid_mapping)):
341+
self.assertEqual(actual_valid_mapping[idx]["col"], expected_valid_mapping[idx]["col"])
342+
pd.testing.assert_series_equal(actual_valid_mapping[idx]["mapping"], expected_valid_mapping[idx]["mapping"])

0 commit comments

Comments
 (0)