Skip to content

Commit 6539a53

Browse files
committed
Add input validation + tests for checking rounding scheme
1 parent fc272bf commit 6539a53

File tree

3 files changed

+71
-3
lines changed

3 files changed

+71
-3
lines changed

rdt/transformers/numerical.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,11 @@ def __init__(
670670
max_value=1.0,
671671
learn_rounding_scheme=False,
672672
):
673+
if not (isinstance(min_value, int) or isinstance(min_value, float)) or not (
674+
isinstance(max_value, int) or isinstance(max_value, float)
675+
):
676+
error_msg = 'The min_value and max_value must be of type int or float.'
677+
raise TransformerInputError(error_msg)
673678
if min_value == max_value:
674679
error_msg = 'The min_value and max_value for the logit function cannot be equal.'
675680
raise TransformerInputError(error_msg)

tests/integration/transformers/test_numerical.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
ClusterBasedNormalizer,
88
FloatFormatter,
99
GaussianNormalizer,
10+
LogitScaler,
1011
)
1112

1213

@@ -560,3 +561,59 @@ def test_out_of_bounds_reverse_transform(self):
560561

561562
# Assert
562563
assert isinstance(reverse, pd.DataFrame)
564+
565+
566+
class TestLogitScaler:
567+
def test_learn_rounding(self):
568+
"""Test that transformer learns rounding scheme from data."""
569+
# Setup
570+
data = pd.DataFrame({'test': [0.2, np.nan, 1.0]})
571+
transformer = LogitScaler(
572+
missing_value_generation=None,
573+
missing_value_replacement='mean',
574+
learn_rounding_scheme=True,
575+
)
576+
expected = pd.DataFrame({'test': [0.2, 0.6, 1.0]})
577+
578+
# Run
579+
transformer.fit(data, 'test')
580+
transformed = transformer.transform(data)
581+
reversed_values = transformer.reverse_transform(transformed)
582+
583+
# Assert
584+
np.testing.assert_array_equal(reversed_values, expected)
585+
586+
def test_missing_value_generation_from_column(self):
587+
"""Test from_column missing value generation with nans."""
588+
# Setup
589+
data = pd.DataFrame({'test': [0.2, np.nan, 1.0]})
590+
transformer = LogitScaler(
591+
missing_value_generation='from_column',
592+
missing_value_replacement='mean',
593+
)
594+
595+
# Run
596+
transformer.fit(data, 'test')
597+
transformed = transformer.transform(data)
598+
reversed_values = transformer.reverse_transform(transformed)
599+
600+
# Assert
601+
np.testing.assert_array_almost_equal(reversed_values, data)
602+
603+
def test_missing_value_generation_random(self):
604+
"""Test random missing_value_generation with nans."""
605+
# Setup
606+
data = pd.DataFrame({'test': [0.2, np.nan, 1.0, 1.0]})
607+
transformer = LogitScaler(
608+
missing_value_generation='random',
609+
missing_value_replacement='mode',
610+
)
611+
expected = pd.DataFrame({'test': [0.2, np.nan, 1.0, np.nan]})
612+
613+
# Run
614+
transformer.fit(data, 'test')
615+
transformed = transformer.transform(data)
616+
reversed_values = transformer.reverse_transform(transformed)
617+
618+
# Assert
619+
np.testing.assert_array_almost_equal(reversed_values, expected)

tests/unit/transformers/test_numerical.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1892,13 +1892,19 @@ def test___init__(self):
18921892
def test___init___invalid_inputs(self):
18931893
"""Test super() arguments are properly passed and set as attributes."""
18941894
# Setup
1895-
min_value = 10.0
1896-
max_value = 10.0
1895+
same_min_value = 10.0
1896+
same_max_value = 10.0
1897+
bad_min_value = '10.0'
1898+
bad_max_value = (100.0,)
18971899

18981900
# Run / Assert
18991901
expected_msg = 'The min_value and max_value for the logit function cannot be equal.'
19001902
with pytest.raises(TransformerInputError, match=re.escape(expected_msg)):
1901-
LogitScaler(max_value=max_value, min_value=min_value)
1903+
LogitScaler(max_value=same_max_value, min_value=same_min_value)
1904+
1905+
expected_msg = 'The min_value and max_value must be of type int or float.'
1906+
with pytest.raises(TransformerInputError, match=re.escape(expected_msg)):
1907+
LogitScaler(max_value=bad_max_value, min_value=bad_min_value)
19021908

19031909
def test__validate_logit_inputs_with_default_settings(self):
19041910
"""Test validating data against input arguments."""

0 commit comments

Comments
 (0)