Skip to content

Commit 1c29b89

Browse files
committed
fix test
1 parent 98435e2 commit 1c29b89

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

rdt/transformers/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,11 @@ def strings_from_regex(regex, max_repeat=16):
177177

178178
def _fill_nan_with_none_series(data):
179179
sentinel = object()
180-
if isinstance(data.dtype, pd.CategoricalDtype):
180+
dtype = data.dtype
181+
if isinstance(dtype, pd.CategoricalDtype):
181182
data = data.cat.add_categories([sentinel])
182183

183-
return data.fillna(sentinel).replace({sentinel: None})
184+
return data.fillna(sentinel).replace({sentinel: None}).astype(dtype)
184185

185186

186187
def fill_nan_with_none(data):

tests/unit/transformers/test_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,10 +181,10 @@ def test__fill_nan_with_none_series():
181181
# Assert
182182
expected_result = pd.Series([1.0, 2.0, 3.0, None], dtype='object')
183183
pd.testing.assert_series_equal(result, expected_result)
184-
expected_result_categorical = pd.Series(['a', 'b', 'c', 'd', None], dtype='object')
185-
pd.testing.assert_series_equal(
186-
result_categorical, expected_result_categorical, check_dtype=False
184+
expected_result_categorical = pd.Series(
185+
pd.Categorical(['a', 'b', 'c', 'd', None], categories=['a', 'b', 'c', 'd'])
187186
)
187+
pd.testing.assert_series_equal(result_categorical, expected_result_categorical)
188188

189189

190190
def test_fill_nan_with_none_series():

0 commit comments

Comments
 (0)