Skip to content

Commit 8725059

Browse files
Merge pull request #454 from PaulWestenthanner/fix-453
Fixes #453. Categorical targets.
2 parents 06fcf04 + a0aa54e commit 8725059

File tree

3 files changed

+38
-2
lines changed

3 files changed

+38
-2
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
v.2.8.1
2+
=======
3+
4+
* Fix: Support and test string targets and `pd.Categorical` targets.
5+
* Fix: Docs typo.
6+
17
v.2.8.0
28
=======
39

category_encoders/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ def fit(self, X: X_type, y: y_type | None = None, **kwargs):
457457
if self.__sklearn_tags__().target_tags.required:
458458
if not is_numeric_dtype(y):
459459
self.lab_encoder_ = LabelEncoder()
460-
y = self.lab_encoder_.fit_transform(y)
460+
y = pd.Series(self.lab_encoder_.fit_transform(y), index=y.index)
461461
else:
462462
self.lab_encoder_ = None
463463

@@ -621,7 +621,7 @@ def transform(self, X: X_type, y: y_type | None = None, override_return_df: bool
621621
X, y = convert_inputs(X, y, deep=True)
622622
self._check_transform_inputs(X)
623623
if y is not None and self.lab_encoder_ is not None:
624-
y = self.lab_encoder_.transform(y)
624+
y = pd.Series(self.lab_encoder_.transform(y), index=y.index)
625625

626626
if not list(self.cols):
627627
return X

tests/test_encoders.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,36 @@ def test_types(self):
436436
encoder = getattr(encoders, encoder_name)()
437437
encoder.fit_transform(X, y)
438438

439+
def test_string_targets(self):
440+
"""Test encoders with targets of type pd.Categorical or string."""
441+
X = pd.DataFrame({'feature': ['A', 'B', 'A', 'C']})
442+
y_string = pd.Series(['yes', 'no', 'yes', 'no'])
443+
444+
for encoder_name in encoders.__all__:
445+
with self.subTest(encoder_name=encoder_name):
446+
enc = getattr(encoders, encoder_name)()
447+
448+
# Test with string target
449+
enc.fit(X, y_string)
450+
transformed = enc.transform(X)
451+
th.verify_numeric(transformed)
452+
self.assertEqual(len(transformed), 4)
453+
def test_categorical_targets(self):
454+
"""Test encoders with targets of type pd.Categorical or string."""
455+
X = pd.DataFrame({'feature': ['A', 'B', 'A', 'C']})
456+
y_categorical = pd.Categorical([1, 0, 1, 0])
457+
458+
for encoder_name in encoders.__all__:
459+
with self.subTest(encoder_name=encoder_name):
460+
enc = getattr(encoders, encoder_name)()
461+
462+
# Test with pd.Categorical target
463+
enc.fit(X, y_categorical)
464+
transformed = enc.transform(X)
465+
th.verify_numeric(transformed)
466+
self.assertEqual(len(transformed), 4)
467+
468+
439469
def test_preserve_column_order(self):
440470
"""Test that the encoder preserves the column order."""
441471
binary_cat_example = pd.DataFrame(

0 commit comments

Comments
 (0)