Skip to content

Commit 2e33bc0

Browse files
committed
validate constant and invert params
1 parent 3c3b211 commit 2e33bc0

File tree

4 files changed

+32
-9
lines changed

4 files changed

+32
-9
lines changed

rdt/transformers/numerical.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -668,12 +668,19 @@ def __init__(
668668
self,
669669
missing_value_replacement='mean',
670670
missing_value_generation='random',
671-
constant: float = 0,
671+
constant: float = 0.0,
672672
invert: bool = False,
673673
learn_rounding_scheme: bool = False,
674674
):
675-
self.constant = constant
676-
self.invert = invert
675+
if isinstance(constant, float):
676+
self.constant = constant
677+
else:
678+
raise ValueError('The constant parameter must be a float.')
679+
if isinstance(invert, bool):
680+
self.invert = invert
681+
else:
682+
raise ValueError('The invert parameter must be a bool.')
683+
677684
super().__init__(
678685
missing_value_replacement=missing_value_replacement,
679686
missing_value_generation=missing_value_generation,

tests/integration/test_transformers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
'FloatFormatter': {'missing_value_generation': 'from_column'},
2727
'GaussianNormalizer': {'missing_value_generation': 'from_column'},
2828
'ClusterBasedNormalizer': {'missing_value_generation': 'from_column'},
29-
'LogScaler': {'constant': INT64_MIN, 'missing_value_generation': 'from_column'},
29+
'LogScaler': {'constant': float(INT64_MIN), 'missing_value_generation': 'from_column'},
3030
}
3131

3232
# Mapping of rdt sdtype to dtype

tests/integration/transformers/test_numerical.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,7 @@ def test_missing_value_generation_random(self):
608608
missing_value_generation='random',
609609
missing_value_replacement='mode',
610610
invert=True,
611-
constant=3,
611+
constant=3.0,
612612
)
613613
expected = pd.DataFrame({'test': [np.nan, 1.5, 1.5, 1.5]})
614614

tests/unit/transformers/test_numerical.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1888,6 +1888,14 @@ def test___init__constant(self):
18881888
assert ls_set.constant == 2.5
18891889
assert ls_default.constant == 0.0
18901890

1891+
def test__init__validates_constant(self):
1892+
"""Test __init__ validates constat parameter."""
1893+
# Setup
1894+
message = 'The constant parameter must be a float.'
1895+
# Run and Assert
1896+
with pytest.raises(ValueError, match=message):
1897+
LogScaler(constant=2)
1898+
18911899
def test___init__invert(self):
18921900
"""Test invert parameter is set as an attribute."""
18931901
# Setup
@@ -1898,6 +1906,14 @@ def test___init__invert(self):
18981906
assert ls_set.invert
18991907
assert not ls_default.invert
19001908

1909+
def test__init__validates_invert(self):
1910+
"""Test __init__ validates constat parameter."""
1911+
# Setup
1912+
message = 'The invert parameter must be a bool.'
1913+
# Run and Assert
1914+
with pytest.raises(ValueError, match=message):
1915+
LogScaler(invert=2)
1916+
19011917
def test__validate_data(self):
19021918
"""Test the ``_validate_data`` method"""
19031919
# Setup
@@ -1987,7 +2003,7 @@ def test__transform(self):
19872003
def test__transform_invert(self):
19882004
"""Test the ``_transform`` method with ``invert=True``"""
19892005
# Setup
1990-
ls = LogScaler(constant=3, invert=True, missing_value_replacement='from_column')
2006+
ls = LogScaler(constant=3.0, invert=True, missing_value_replacement='from_column')
19912007
ls._validate_data = Mock()
19922008
ls.null_transformer = NullTransformer(
19932009
missing_value_replacement='mean', missing_value_generation='from_column'
@@ -2027,7 +2043,7 @@ def test__transform_null_values(self):
20272043
def test__transform_null_values_invert(self):
20282044
"""Test the ``_transform`` method with ``invert=True``"""
20292045
# Setup
2030-
ls = LogScaler(constant=3, invert=True, missing_value_replacement='from_column')
2046+
ls = LogScaler(constant=3.0, invert=True, missing_value_replacement='from_column')
20312047
ls._validate_data = Mock()
20322048
ls.null_transformer = NullTransformer(
20332049
missing_value_replacement='mean', missing_value_generation='from_column'
@@ -2117,7 +2133,7 @@ def test__reverse_transform_invert(self):
21172133
[0, 0, 1.0],
21182134
]).T
21192135
expected = pd.Series([0.1, 1.0, np.nan])
2120-
ls = LogScaler(constant=3, invert=True)
2136+
ls = LogScaler(constant=3.0, invert=True)
21212137
ls.null_transformer = NullTransformer(
21222138
missing_value_replacement='mean',
21232139
missing_value_generation='from_column',
@@ -2158,7 +2174,7 @@ def test__reverse_transform_invert_missing_value_generation(self):
21582174
# Setup
21592175
data = np.array([1.06471, 0.69315, 0])
21602176
expected = pd.Series([0.1, 1.0, 2.0])
2161-
ls = LogScaler(constant=3, invert=True)
2177+
ls = LogScaler(constant=3.0, invert=True)
21622178
ls.null_transformer = NullTransformer(None, missing_value_generation='random')
21632179

21642180
# Run

0 commit comments

Comments
 (0)