Skip to content

Commit f69e2cb

Browse files
committed
improve fill_nan_with_none
1 parent 837186f commit f69e2cb

File tree

2 files changed

+50
-6
lines changed

2 files changed

+50
-6
lines changed

rdt/transformers/utils.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -175,22 +175,28 @@ def strings_from_regex(regex, max_repeat=16):
175175
return _from_generators(generators, max_repeat), np.prod(sizes, dtype=np.complex128).real
176176

177177

178+
def _fill_nan_with_none_series(data):
179+
sentinel = object()
180+
if isinstance(data.dtype, pd.CategoricalDtype):
181+
data = data.cat.add_categories([sentinel])
182+
183+
return data.fillna(sentinel).replace({sentinel: None})
184+
185+
178186
def fill_nan_with_none(data):
179187
"""Replace all nan values with None.
180188
181189
Args:
182-
data (pd.Series)
190+
data (pd.DataFrame or pd.Series)
183191
184192
Returns:
185193
data:
186194
Original data with nan values replaced by None.
187195
"""
188-
sentinel = object()
189-
if isinstance(data.dtype, pd.CategoricalDtype):
190-
return data.where(~pd.isna(data), None)
196+
if isinstance(data, pd.DataFrame):
197+
return data.apply(_fill_nan_with_none_series)
191198

192-
data = data.fillna(sentinel)
193-
return data.replace({sentinel: None})
199+
return _fill_nan_with_none_series(data)
194200

195201

196202
def flatten_column_list(column_list):

tests/unit/transformers/test_utils.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
_any,
1717
_cast_to_type,
1818
_extract_timezone_from_a_string,
19+
_fill_nan_with_none_series,
1920
_get_utc_offset,
2021
_handle_enforce_uniqueness_and_cardinality_rule,
2122
_max_repeat,
@@ -167,6 +168,43 @@ def test_fill_nan_with_none_no_warning():
167168
pd.testing.assert_series_equal(result, expected)
168169

169170

171+
def test__fill_nan_with_none_series():
172+
"""Test the ``_fill_nan_with_none_series`` method."""
173+
# Setup
174+
series = pd.Series([1.0, 2.0, 3.0, np.nan], dtype='object')
175+
categorical_serie = pd.Series(['a', 'b', 'c', 'd', np.nan], dtype='category')
176+
177+
# Run
178+
result = _fill_nan_with_none_series(series)
179+
result_categorical = _fill_nan_with_none_series(categorical_serie)
180+
181+
# Assert
182+
expected_result = pd.Series([1.0, 2.0, 3.0, None], dtype='object')
183+
pd.testing.assert_series_equal(result, expected_result)
184+
expected_result_categorical = pd.Series(['a', 'b', 'c', 'd', None], dtype='category')
185+
pd.testing.assert_series_equal(result_categorical, expected_result_categorical)
186+
187+
188+
def test_fill_nan_with_none_series():
189+
"""Test the `fill_nan_with_none_series` function."""
190+
# Setup
191+
series = pd.Series([1.0, 2.0, 3.0, np.nan], dtype='object')
192+
data = pd.DataFrame({'col1': series})
193+
data_2 = pd.DataFrame({'col1': series, 'col2': ['a', 'b', 'c', np.nan]})
194+
195+
# Run
196+
result_series = _fill_nan_with_none_series(series)
197+
result_data = fill_nan_with_none(data)
198+
result_data_2 = fill_nan_with_none(data_2)
199+
200+
# Assert
201+
expected_result = pd.Series([1.0, 2.0, 3.0, None], dtype='object')
202+
expected_result_data_2 = pd.DataFrame({'col1': expected_result, 'col2': ['a', 'b', 'c', None]})
203+
pd.testing.assert_series_equal(result_series, expected_result)
204+
pd.testing.assert_frame_equal(result_data, pd.DataFrame({'col1': expected_result}))
205+
pd.testing.assert_frame_equal(result_data_2, expected_result_data_2)
206+
207+
170208
def test_check_nan_in_transform():
171209
"""Test ``check_nan_in_transform`` method.
172210

0 commit comments

Comments
 (0)