Skip to content

Commit 75de566

Browse files
fixed tests
1 parent e673b7c commit 75de566

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

category_encoders/gray.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,9 @@ def _fit(self, X, y=None, **kwargs):
9696
col = col_to_encode["col"]
9797
bin_mapping = col_to_encode["mapping"]
9898
n_cols_out = bin_mapping.shape[1]
99-
map_null = bin_mapping[bin_mapping.index < 0]
100-
map_non_null = bin_mapping[bin_mapping.index >= 0].copy()
99+
null_cond = (bin_mapping.index < 0) | (bin_mapping.isnull().all(1))
100+
map_null = bin_mapping[null_cond]
101+
map_non_null = bin_mapping[~null_cond].copy()
101102
ordinal_mapping = [m for m in self.ordinal_encoder.mapping if m.get("col") == col]
102103
if len(ordinal_mapping) != 1:
103104
raise ValueError("Cannot find ordinal encoder mapping of Gray encoder")

tests/test_encoders.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -624,13 +624,17 @@ def test_metamorphic(self):
624624
result5 = enc5.fit_transform(x5, y)
625625
self.assertTrue((result1.values == result5.values).all())
626626

627-
enc6 = getattr(encoders, encoder_name)()
628-
result6 = enc6.fit_transform(x6, y)
629-
self.assertTrue((result1.values == result6.values).all())
630-
631-
enc7 = getattr(encoders, encoder_name)()
632-
result7 = enc7.fit_transform(x7, y)
633-
self.assertTrue((result1.values == result7.values).all())
627+
# gray encoder re-orders inputs so that nan is last, hence the output is changed
628+
if encoder_name != "GrayEncoder":
629+
enc6 = getattr(encoders, encoder_name)()
630+
result6 = enc6.fit_transform(x6, y)
631+
self.assertTrue((result1.values == result6.values).all())
632+
633+
# gray encoder actually does re-order inputs
634+
if encoder_name != "GrayEncoder":
635+
enc7 = getattr(encoders, encoder_name)()
636+
result7 = enc7.fit_transform(x7, y)
637+
self.assertTrue((result1.values == result7.values).all())
634638

635639
# Arguments
636640
enc9 = getattr(encoders, encoder_name)(return_df=False)

0 commit comments

Comments
 (0)