66import pandas as pd
77import numpy as np
88import sklearn .base
9- from pandas .api .types import is_object_dtype , is_string_dtype
9+ from pandas .api .types import is_object_dtype , is_string_dtype , is_numeric_dtype
1010from pandas .core .dtypes .dtypes import CategoricalDtype
1111from sklearn .base import BaseEstimator , TransformerMixin
1212from sklearn .exceptions import NotFittedError
1313from typing import Dict , List , Optional , Union
1414from scipy .sparse import csr_matrix
15+ from sklearn .preprocessing import LabelEncoder
1516
1617__author__ = 'willmcginnis'
1718
@@ -294,11 +295,18 @@ def fit(self, X, y=None, **kwargs):
294295 Returns self.
295296
296297 """
297- self ._check_fit_inputs (X , y )
298298 X , y = convert_inputs (X , y )
299+ self ._check_fit_inputs (X , y )
299300 self .feature_names_in_ = X .columns .tolist ()
300301 self .n_features_in_ = len (self .feature_names_in_ )
301302
303+ if self ._get_tags ().get ('supervised_encoder' ):
304+ if not is_numeric_dtype (y ):
305+ self .lab_encoder_ = LabelEncoder ()
306+ y = self .lab_encoder_ .fit_transform (y )
307+ else :
308+ self .lab_encoder_ = None
309+
302310 self ._dim = X .shape [1 ]
303311 self ._determine_fit_columns (X )
304312
@@ -324,8 +332,12 @@ def fit(self, X, y=None, **kwargs):
324332 return self
325333
326334 def _check_fit_inputs (self , X , y ):
327- if self ._get_tags ().get ('supervised_encoder' ) and y is None :
328- raise ValueError ('Supervised encoders need a target for the fitting. The target cannot be None' )
335+ if self ._get_tags ().get ('supervised_encoder' ):
336+ if y is None :
337+ raise ValueError ('Supervised encoders need a target for the fitting. The target cannot be None' )
338+ else :
339+ if y .isna ().any (): # Target column should never have missing values
340+ raise ValueError ("The target column y must not contain missing values." )
329341
330342 def _check_transform_inputs (self , X ):
331343 if self .handle_missing == 'error' :
@@ -435,6 +447,8 @@ def transform(self, X, y=None, override_return_df=False):
435447 # first check the type
436448 X , y = convert_inputs (X , y , deep = True )
437449 self ._check_transform_inputs (X )
450+ if y is not None and self .lab_encoder_ is not None :
451+ y = self .lab_encoder_ .transform (y )
438452
439453 if not list (self .cols ):
440454 return X
0 commit comments