diff --git a/rdt/transformers/categorical.py b/rdt/transformers/categorical.py index 210c86e06..2fdfec112 100644 --- a/rdt/transformers/categorical.py +++ b/rdt/transformers/categorical.py @@ -59,12 +59,12 @@ def _order_categories(self, unique_data): nans = pd.isna(unique_data) if self.order_by == 'alphabetical': # pylint: disable=invalid-unary-operand-type - if any(map(lambda item: not isinstance(item, str), unique_data[~nans])): # noqa: C417 + if any(not isinstance(item, str) for item in unique_data[~nans]): raise TransformerInputError( "The data must be of type string if order_by is 'alphabetical'." ) elif self.order_by == 'numerical_value': - if not np.issubdtype(unique_data.dtype.type, np.number): + if any(not np.issubdtype(type(item), np.number) for item in unique_data[~nans]): raise TransformerInputError( "The data must be numerical if order_by is 'numerical_value'." ) diff --git a/rdt/transformers/utils.py b/rdt/transformers/utils.py index cc67c9ec3..63e1d7201 100644 --- a/rdt/transformers/utils.py +++ b/rdt/transformers/utils.py @@ -176,14 +176,14 @@ def strings_from_regex(regex, max_repeat=16): def _fill_nan_with_none_series(data): - sentinel = object() dtype = data.dtype if isinstance(dtype, pd.CategoricalDtype): + sentinel = object() data = data.cat.add_categories([sentinel]) data = data.fillna(sentinel).replace({sentinel: None}) return pd.Series(pd.Categorical(data, categories=dtype.categories), index=data.index) - return data.fillna(sentinel).replace({sentinel: None}) + return data.astype('object').where(~data.isna(), None) def fill_nan_with_none(data): diff --git a/tests/unit/transformers/test_categorical.py b/tests/unit/transformers/test_categorical.py index 8f99f1274..f084d782f 100644 --- a/tests/unit/transformers/test_categorical.py +++ b/tests/unit/transformers/test_categorical.py @@ -198,6 +198,24 @@ def test__fit(self): assert transformer.frequencies == expected_frequencies assert transformer.intervals == expected_intervals + def test_fit_with_nullable_integer_dtype(self): + """Test that the ``fit`` method works with nullable integer columns.""" + # Setup + data = pd.DataFrame({'example': [1, 2, 3, None]}, dtype='Int64') + transformer = UniformEncoder() + + # Run + transformer.fit(data=data, column='example') + + # Assert + expected_frequencies = { + 1: 0.25, + 2: 0.25, + 3: 0.25, + None: 0.25, + } + assert transformer.frequencies == expected_frequencies + def test__set_fitted_parameters(self): """Test the ``_set_fitted_parameters`` method.""" # Setup