Skip to content

Commit 68aacd6

Browse files
committed
fix tests
1 parent f69e2cb commit 68aacd6

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

rdt/transformers/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,11 @@ def strings_from_regex(regex, max_repeat=16):
177177

178178
def _fill_nan_with_none_series(data):
179179
sentinel = object()
180-
if isinstance(data.dtype, pd.CategoricalDtype):
180+
dtype = data.dtype
181+
if isinstance(dtype, pd.CategoricalDtype):
181182
data = data.cat.add_categories([sentinel])
183+
data = data.fillna(sentinel).replace({sentinel: None})
184+
return pd.Series(pd.Categorical(data, categories=dtype.categories), index=data.index)
182185

183186
return data.fillna(sentinel).replace({sentinel: None})
184187

tests/unit/transformers/test_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,9 @@ def test__fill_nan_with_none_series():
181181
# Assert
182182
expected_result = pd.Series([1.0, 2.0, 3.0, None], dtype='object')
183183
pd.testing.assert_series_equal(result, expected_result)
184-
expected_result_categorical = pd.Series(['a', 'b', 'c', 'd', None], dtype='category')
184+
expected_result_categorical = pd.Series(
185+
pd.Categorical(['a', 'b', 'c', 'd', None], categories=['a', 'b', 'c', 'd'])
186+
)
185187
pd.testing.assert_series_equal(result_categorical, expected_result_categorical)
186188

187189

0 commit comments

Comments
 (0)