Skip to content

Commit 5cfe457

Browse files
Merge pull request #423 from tvdboom/add_dtypes_default_fit_cols
add category and string dtypes to default fit columns
2 parents 80b4a9b + 721323b commit 5cfe457

File tree

3 files changed: 30 additions (+30) and 21 deletions (-21).

category_encoders/quantile_encoder.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -281,7 +281,7 @@ def fit(self, X, y):
281281
self.n_features_in_ = len(self.feature_names_in_)
282282

283283
if self.use_default_cols:
284-
self.cols = util.get_obj_cols(X)
284+
self.cols = util.get_categorical_cols(X)
285285
else:
286286
self.cols = util.convert_cols_to_list(self.cols)
287287

category_encoders/utils.py

Lines changed: 22 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
import pandas as pd
77
import numpy as np
88
import sklearn.base
9+
from pandas.api.types import is_object_dtype, is_string_dtype
910
from pandas.core.dtypes.dtypes import CategoricalDtype
1011
from sklearn.base import BaseEstimator, TransformerMixin
1112
from sklearn.exceptions import NotFittedError
@@ -32,14 +33,15 @@ def convert_cols_to_list(cols):
3233
return cols
3334

3435

35-
def get_obj_cols(df):
36+
def get_categorical_cols(df):
3637
"""
37-
Returns names of 'object' columns in the DataFrame.
38+
Returns names of categorical columns in the DataFrame. These
39+
include columns of types: object, category, string, string[pyarrow].
3840
"""
3941
obj_cols = []
40-
for idx, dt in enumerate(df.dtypes):
41-
if dt == 'object' or is_category(dt):
42-
obj_cols.append(df.columns.values[idx])
42+
for col, dtype in df.dtypes.items():
43+
if is_object_dtype(dtype) or is_category(dtype) or is_string_dtype(dtype):
44+
obj_cols.append(col)
4345

4446
if not obj_cols:
4547
print("Warning: No categorical columns found. Calling 'transform' will only return input data.")
@@ -99,10 +101,11 @@ def convert_input(X, columns=None, deep=False, index=None):
99101
if isinstance(X, pd.Series):
100102
X = pd.DataFrame(X, copy=deep)
101103
else:
102-
if columns is not None and np.size(X,1) != len(columns):
104+
if columns is not None and np.size(X, 1) != len(columns):
103105
raise ValueError('The count of the column names does not correspond to the count of the columns')
104106
if isinstance(X, list):
105-
X = pd.DataFrame(X, columns=columns, copy=deep, index=index) # lists are always copied, but for consistency, we still pass the argument
107+
X = pd.DataFrame(X, columns=columns, copy=deep,
108+
index=index) # lists are always copied, but for consistency, we still pass the argument
106109
elif isinstance(X, (np.generic, np.ndarray)):
107110
X = pd.DataFrame(X, columns=columns, copy=deep, index=index)
108111
elif isinstance(X, csr_matrix):
@@ -126,34 +129,34 @@ def convert_input_vector(y, index):
126129
if isinstance(y, pd.Series):
127130
return y
128131
elif isinstance(y, np.ndarray):
129-
if len(np.shape(y))==1: # vector
132+
if len(np.shape(y)) == 1: # vector
130133
return pd.Series(y, name='target', index=index)
131-
elif len(np.shape(y))==2 and np.shape(y)[0]==1: # single row in a matrix
134+
elif len(np.shape(y)) == 2 and np.shape(y)[0] == 1: # single row in a matrix
132135
return pd.Series(y[0, :], name='target', index=index)
133-
elif len(np.shape(y))==2 and np.shape(y)[1]==1: # single column in a matrix
136+
elif len(np.shape(y)) == 2 and np.shape(y)[1] == 1: # single column in a matrix
134137
return pd.Series(y[:, 0], name='target', index=index)
135138
else:
136139
raise ValueError(f'Unexpected input shape: {np.shape(y)}')
137140
elif np.isscalar(y):
138141
return pd.Series([y], name='target', index=index)
139142
elif isinstance(y, list):
140-
if len(y)==0: # empty list
143+
if len(y) == 0: # empty list
141144
return pd.Series(y, name='target', index=index, dtype=float)
142-
elif len(y)>0 and not isinstance(y[0], list): # vector
145+
elif len(y) > 0 and not isinstance(y[0], list): # vector
143146
return pd.Series(y, name='target', index=index)
144-
elif len(y)>0 and isinstance(y[0], list) and len(y[0])==1: # single row in a matrix
147+
elif len(y) > 0 and isinstance(y[0], list) and len(y[0]) == 1: # single row in a matrix
145148
flatten = lambda y: [item for sublist in y for item in sublist]
146149
return pd.Series(flatten(y), name='target', index=index)
147-
elif len(y)==1 and len(y[0])==0 and isinstance(y[0], list): # single empty column in a matrix
150+
elif len(y) == 1 and len(y[0]) == 0 and isinstance(y[0], list): # single empty column in a matrix
148151
return pd.Series(y[0], name='target', index=index, dtype=float)
149-
elif len(y)==1 and isinstance(y[0], list): # single column in a matrix
152+
elif len(y) == 1 and isinstance(y[0], list): # single column in a matrix
150153
return pd.Series(y[0], name='target', index=index, dtype=type(y[0][0]))
151154
else:
152155
raise ValueError('Unexpected input shape')
153156
elif isinstance(y, pd.DataFrame):
154-
if len(list(y))==0: # empty DataFrame
157+
if len(list(y)) == 0: # empty DataFrame
155158
return pd.Series(name='target', index=index, dtype=float)
156-
if len(list(y))==1: # a single column
159+
if len(list(y)) == 1: # a single column
157160
return y.iloc[:, 0]
158161
else:
159162
raise ValueError(f'Unexpected input shape: {y.shape}')
@@ -274,7 +277,7 @@ def __init__(self, verbose=0, cols=None, drop_invariant=False, return_df=True,
274277

275278
def fit(self, X, y=None, **kwargs):
276279
"""Fits the encoder according to X and y.
277-
280+
278281
Parameters
279282
----------
280283
@@ -355,7 +358,7 @@ def _determine_fit_columns(self, X: pd.DataFrame) -> None:
355358
"""
356359
# if columns aren't passed, just use every string column
357360
if self.use_default_cols:
358-
self.cols = get_obj_cols(X)
361+
self.cols = get_categorical_cols(X)
359362
else:
360363
self.cols = convert_cols_to_list(self.cols)
361364

tests/test_utils.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
from unittest import TestCase # or `from unittest import ...` if on Python 3.4+
2-
from category_encoders.utils import convert_input_vector, convert_inputs
2+
from category_encoders.utils import convert_input_vector, convert_inputs, get_categorical_cols
33
import pandas as pd
44
import numpy as np
55

@@ -114,3 +114,9 @@ def test_convert_inputs(self):
114114

115115
# shape mismatch
116116
self.assertRaises(ValueError, convert_inputs, barray, [1, 2, 3, 4])
117+
118+
def test_get_categorical_cols(self):
119+
df = pd.DataFrame({"col": ["a", "b"]})
120+
self.assertEqual(get_categorical_cols(df.astype("object")), ["col"])
121+
self.assertEqual(get_categorical_cols(df.astype("category")), ["col"])
122+
self.assertEqual(get_categorical_cols(df.astype("string")), ["col"])

0 commit comments

Comments
 (0)