Skip to content

Commit fc272bf

Browse files
committed
Comments
1 parent e06247d commit fc272bf

File tree

3 files changed

+69
-30
lines changed

3 files changed

+69
-30
lines changed

rdt/transformers/numerical.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,10 @@ def __init__(
670670
max_value=1.0,
671671
learn_rounding_scheme=False,
672672
):
673+
if min_value == max_value:
674+
error_msg = 'The min_value and max_value for the logit function cannot be equal.'
675+
raise TransformerInputError(error_msg)
676+
673677
super().__init__(
674678
missing_value_replacement=missing_value_replacement,
675679
missing_value_generation=missing_value_generation,
@@ -707,18 +711,18 @@ def _transform(self, data):
707711
logit_vals = logit(transformed_vals, self.min_value, self.max_value)
708712
if transformed.ndim == 1:
709713
return logit_vals
710-
else:
711-
transformed[:, 0] = logit_vals
712-
return transformed
714+
715+
transformed[:, 0] = logit_vals
716+
return transformed
713717

714718
def _reverse_transform(self, data):
715719
if not isinstance(data, np.ndarray):
716720
data = data.to_numpy()
717721

718722
sampled_vals = data if data.ndim == 1 else data[:, 0]
719-
reversed = sigmoid(sampled_vals, self.min_value, self.max_value)
723+
reversed_values = sigmoid(sampled_vals, self.min_value, self.max_value)
720724
if data.ndim == 1:
721-
return super()._reverse_transform(reversed)
722-
else:
723-
data[:, 0] = reversed
724-
return super()._reverse_transform(data)
725+
return super()._reverse_transform(reversed_values)
726+
727+
data[:, 0] = reversed_values
728+
return super()._reverse_transform(data)

tests/integration/test_transformers.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,28 @@
4747
}
4848

4949

50+
def _create_transformer_args_from_data(transformer_args, data):
51+
"""Helper to extract transformer arguments that are data-dependent.
52+
53+
Args:
54+
transformer_args (dict):
55+
The transformer arguments.
56+
data (pd.Series):
57+
The data for the transformer.
58+
59+
Returns:
60+
dict:
61+
The transformer arguments with data-specific arguments added.
62+
"""
63+
if 'FROM_DATA' in transformer_args:
64+
transformer_args = {**transformer_args}
65+
args = transformer_args.pop('FROM_DATA')
66+
for arg, arg_func in args.items():
67+
transformer_args[arg] = arg_func(data)
68+
69+
return transformer_args
70+
71+
5072
def _validate_helper(validator_function, args, steps):
5173
"""Wrap around validation functions to either return a boolean or assert.
5274
@@ -157,11 +179,7 @@ def _test_transformer_with_dataset(transformer_class, input_data, steps):
157179
"""
158180

159181
transformer_args = TRANSFORMER_ARGS.get(transformer_class.__name__, {})
160-
if 'FROM_DATA' in transformer_args:
161-
transformer_args = {**transformer_args}
162-
args = transformer_args.pop('FROM_DATA')
163-
for arg, arg_func in args.items():
164-
transformer_args[arg] = arg_func(input_data[TEST_COL])
182+
transformer_args = _create_transformer_args_from_data(transformer_args, input_data[TEST_COL])
165183

166184
transformer = transformer_class(**transformer_args)
167185
# Fit
@@ -217,12 +235,9 @@ def _test_transformer_with_hypertransformer(transformer_class, input_data, steps
217235
transformer_args = TRANSFORMER_ARGS.get(transformer_class.__name__, {})
218236
hypertransformer = HyperTransformer()
219237
if transformer_args:
220-
if 'FROM_DATA' in transformer_args:
221-
transformer_args = {**transformer_args}
222-
args = transformer_args.pop('FROM_DATA')
223-
for arg, arg_func in args.items():
224-
transformer_args[arg] = arg_func(input_data[TEST_COL])
225-
238+
transformer_args = _create_transformer_args_from_data(
239+
transformer_args, input_data[TEST_COL]
240+
)
226241
field_transformers = {TEST_COL: transformer_class(**transformer_args)}
227242

228243
else:

tests/unit/transformers/test_numerical.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1889,7 +1889,18 @@ def test___init__(self):
18891889
assert ls.max_value == 100.0
18901890
assert ls.min_value == 2.0
18911891

1892-
def test__validate_logit_inputs(self):
1892+
def test___init___invalid_inputs(self):
1893+
"""Test super() arguments are properly passed and set as attributes."""
1894+
# Setup
1895+
min_value = 10.0
1896+
max_value = 10.0
1897+
1898+
# Run / Assert
1899+
expected_msg = 'The min_value and max_value for the logit function cannot be equal.'
1900+
with pytest.raises(TransformerInputError, match=re.escape(expected_msg)):
1901+
LogitScaler(max_value=max_value, min_value=min_value)
1902+
1903+
def test__validate_logit_inputs_with_default_settings(self):
18931904
"""Test validating data against input arguments."""
18941905
# Setup
18951906
ls = LogitScaler()
@@ -1898,6 +1909,15 @@ def test__validate_logit_inputs(self):
18981909
# Run and Assert
18991910
ls._validate_logit_inputs(data)
19001911

1912+
def test__validate_logit_inputs_with_custom_inputs(self):
1913+
"""Test validating data against input arguments."""
1914+
# Setup
1915+
ls = LogitScaler(min_value=0, max_value=100)
1916+
data = pd.Series([0.0, 10.1, 20.2, 30.3, 100])
1917+
1918+
# Run and Assert
1919+
ls._validate_logit_inputs(data)
1920+
19011921
def test__validate_logit_inputs_errors_invalid_value(self):
19021922
"""Test error message contains invalid values."""
19031923
# Setup
@@ -1944,7 +1964,7 @@ def test__fit(self):
19441964
def test__transform(self, mock_logit):
19451965
"""Test the ``transform`` method."""
19461966
# Setup
1947-
min_value = (1.0,)
1967+
min_value = 1.0
19481968
max_value = 50.0
19491969
ls = LogitScaler(min_value=min_value, max_value=max_value)
19501970
ls._validate_logit_inputs = Mock()
@@ -1965,7 +1985,7 @@ def test__transform(self, mock_logit):
19651985
def test__transform_multi_column(self, mock_logit):
19661986
"""Test the ``transform`` method with multiple columns."""
19671987
# Setup
1968-
min_value = (1.0,)
1988+
min_value = 1.0
19691989
max_value = 50.0
19701990
ls = LogitScaler(min_value=min_value, max_value=max_value)
19711991
ls._validate_logit_inputs = Mock()
@@ -1988,7 +2008,7 @@ def test__transform_multi_column(self, mock_logit):
19882008
def test__reverse_transform(self, mock_sigmoid, ff_reverse_transform_mock):
19892009
"""Test the ``transform`` method."""
19902010
# Setup
1991-
min_value = (1.0,)
2011+
min_value = 1.0
19922012
max_value = 50.0
19932013
ls = LogitScaler(min_value=min_value, max_value=max_value)
19942014
data = pd.Series([1.0, 1.1, 1.2, 1.3, 2.0, 3.0, 4.0])
@@ -1997,40 +2017,40 @@ def test__reverse_transform(self, mock_sigmoid, ff_reverse_transform_mock):
19972017
ls.null_transformer = null_transformer_mock
19982018

19992019
# Run
2000-
reversed = ls._reverse_transform(data)
2020+
reversed_values = ls._reverse_transform(data)
20012021

20022022
# Assert
20032023
mock_sigmoid_args = mock_sigmoid.call_args[0]
20042024
np.testing.assert_array_equal(mock_sigmoid_args[0], data.to_numpy())
20052025
assert mock_sigmoid_args[1] == ls.min_value
20062026
assert mock_sigmoid_args[2] == ls.max_value
20072027
ff_reverse_transform_mock.assert_called_once_with(mock_sigmoid.return_value)
2008-
assert reversed == ff_reverse_transform_mock.return_value
2028+
assert reversed_values == ff_reverse_transform_mock.return_value
20092029

20102030
@patch('rdt.transformers.numerical.FloatFormatter._reverse_transform')
20112031
@patch('rdt.transformers.numerical.sigmoid')
20122032
def test__reverse_transform_multi_column(self, mock_sigmoid, ff_reverse_transform_mock):
20132033
"""Test the ``transform`` method with multiple columns."""
20142034
# Setup
2015-
min_value = (1.0,)
2035+
min_value = 1.0
20162036
max_value = 50.0
20172037
ls = LogitScaler(min_value=min_value, max_value=max_value)
20182038
sampled_data = np.array([1.0, 1.1, 1.2, 1.3, 2.0, 3.0, 4.0])
20192039
is_null = np.array([0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0])
20202040
data = pd.DataFrame({'column': sampled_data, 'column.is_null': is_null})
20212041
null_transformer_mock = Mock()
2022-
reversed = np.array([1.0, 1.1, np.nan, np.nan, 2.0, np.nan, np.nan])
2023-
null_transformer_mock.reverse_transform.return_value = reversed
2042+
reversed_values = np.array([1.0, 1.1, np.nan, np.nan, 2.0, np.nan, np.nan])
2043+
null_transformer_mock.reverse_transform.return_value = reversed_values
20242044
ls.null_transformer = null_transformer_mock
20252045
sigmoid_vals = np.array([3.0, 3.1, 3.3, 3.4, 2.1, 4.0, 4.6])
20262046
mock_sigmoid.return_value = sigmoid_vals
20272047

20282048
# Run
2029-
reversed = ls._reverse_transform(data)
2049+
reversed_values = ls._reverse_transform(data)
20302050

20312051
# Assert
20322052
ff_reverse_transform_args = ff_reverse_transform_mock.call_args[0]
20332053
np.testing.assert_array_equal(
20342054
ff_reverse_transform_args[0], np.array([sigmoid_vals, is_null]).T
20352055
)
2036-
assert reversed == ff_reverse_transform_mock.return_value
2056+
assert reversed_values == ff_reverse_transform_mock.return_value

0 commit comments

Comments
 (0)