Skip to content

Commit e06247d

Browse files
committed
Fix broken tests
1 parent 07ac708 commit e06247d

File tree

4 files changed

+77
-8
lines changed

4 files changed

+77
-8
lines changed

rdt/transformers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
ClusterBasedNormalizer,
2929
FloatFormatter,
3030
GaussianNormalizer,
31+
LogitScaler,
3132
)
3233
from rdt.transformers.pii.anonymizer import (
3334
AnonymizedFaker,

rdt/transformers/numerical.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -702,9 +702,23 @@ def _fit(self, data):
702702

703703
def _transform(self, data):
704704
transformed = super()._transform(data)
705-
self._validate_logit_inputs(transformed)
706-
return logit(transformed, self.min_value, self.max_value)
705+
transformed_vals = transformed if transformed.ndim == 1 else transformed[:, 0]
706+
self._validate_logit_inputs(transformed_vals)
707+
logit_vals = logit(transformed_vals, self.min_value, self.max_value)
708+
if transformed.ndim == 1:
709+
return logit_vals
710+
else:
711+
transformed[:, 0] = logit_vals
712+
return transformed
707713

708714
def _reverse_transform(self, data):
709-
reversed = sigmoid(data, self.min_value, self.max_value)
710-
return super()._reverse_transform(reversed)
715+
if not isinstance(data, np.ndarray):
716+
data = data.to_numpy()
717+
718+
sampled_vals = data if data.ndim == 1 else data[:, 0]
719+
reversed = sigmoid(sampled_vals, self.min_value, self.max_value)
720+
if data.ndim == 1:
721+
return super()._reverse_transform(reversed)
722+
else:
723+
data[:, 0] = reversed
724+
return super()._reverse_transform(data)

tests/integration/test_transformers.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,11 @@
2525
'GaussianNormalizer': {'missing_value_generation': 'from_column'},
2626
'ClusterBasedNormalizer': {'missing_value_generation': 'from_column'},
2727
'LogitScaler': {
28+
'missing_value_generation': 'from_column',
2829
'FROM_DATA': {
29-
'min_value': lambda x: np.nanmin(x) - 1,
30-
'max_value': lambda x: np.nanmax(x) + 1,
31-
}
30+
'min_value': lambda x: np.nanmin(x) - 0.01,
31+
'max_value': lambda x: np.nanmax(x) + 0.01,
32+
},
3233
},
3334
}
3435

tests/unit/transformers/test_numerical.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1961,6 +1961,28 @@ def test__transform(self, mock_logit):
19611961
mock_logit.assert_called_once_with(data, ls.min_value, ls.max_value)
19621962
assert transformed == mock_logit.return_value
19631963

1964+
@patch('rdt.transformers.numerical.logit')
1965+
def test__transform_multi_column(self, mock_logit):
1966+
"""Test the ``transform`` method with multiple columns."""
1967+
# Setup
1968+
min_value = (1.0,)
1969+
max_value = 50.0
1970+
ls = LogitScaler(min_value=min_value, max_value=max_value)
1971+
ls._validate_logit_inputs = Mock()
1972+
data = pd.Series([1.0, 1.1, 1.2, 1.3, 2.0, 3.0, 4.0])
1973+
null_transformer_mock = Mock()
1974+
is_null = np.array([0, 0, 0, 1, 0, 1, 0])
1975+
null_transformer_mock.transform.return_value = np.array([data.to_numpy(), is_null]).T
1976+
ls.null_transformer = null_transformer_mock
1977+
logit_values = np.array([0.0, 0.1, 0.2, 0.3, 0.3, 1.4, 2.5])
1978+
mock_logit.return_value = logit_values
1979+
1980+
# Run
1981+
transformed = ls._transform(data)
1982+
1983+
# Assert
1984+
np.testing.assert_array_equal(transformed, np.array([logit_values, is_null]).T)
1985+
19641986
@patch('rdt.transformers.numerical.FloatFormatter._reverse_transform')
19651987
@patch('rdt.transformers.numerical.sigmoid')
19661988
def test__reverse_transform(self, mock_sigmoid, ff_reverse_transform_mock):
@@ -1978,6 +2000,37 @@ def test__reverse_transform(self, mock_sigmoid, ff_reverse_transform_mock):
19782000
reversed = ls._reverse_transform(data)
19792001

19802002
# Assert
1981-
mock_sigmoid.assert_called_once_with(data, ls.min_value, ls.max_value)
2003+
mock_sigmoid_args = mock_sigmoid.call_args[0]
2004+
np.testing.assert_array_equal(mock_sigmoid_args[0], data.to_numpy())
2005+
assert mock_sigmoid_args[1] == ls.min_value
2006+
assert mock_sigmoid_args[2] == ls.max_value
19822007
ff_reverse_transform_mock.assert_called_once_with(mock_sigmoid.return_value)
19832008
assert reversed == ff_reverse_transform_mock.return_value
2009+
2010+
@patch('rdt.transformers.numerical.FloatFormatter._reverse_transform')
2011+
@patch('rdt.transformers.numerical.sigmoid')
2012+
def test__reverse_transform_multi_column(self, mock_sigmoid, ff_reverse_transform_mock):
2013+
"""Test the ``reverse_transform`` method with multiple columns."""
2014+
# Setup
2015+
min_value = (1.0,)
2016+
max_value = 50.0
2017+
ls = LogitScaler(min_value=min_value, max_value=max_value)
2018+
sampled_data = np.array([1.0, 1.1, 1.2, 1.3, 2.0, 3.0, 4.0])
2019+
is_null = np.array([0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0])
2020+
data = pd.DataFrame({'column': sampled_data, 'column.is_null': is_null})
2021+
null_transformer_mock = Mock()
2022+
reversed = np.array([1.0, 1.1, np.nan, np.nan, 2.0, np.nan, np.nan])
2023+
null_transformer_mock.reverse_transform.return_value = reversed
2024+
ls.null_transformer = null_transformer_mock
2025+
sigmoid_vals = np.array([3.0, 3.1, 3.3, 3.4, 2.1, 4.0, 4.6])
2026+
mock_sigmoid.return_value = sigmoid_vals
2027+
2028+
# Run
2029+
reversed = ls._reverse_transform(data)
2030+
2031+
# Assert
2032+
ff_reverse_transform_args = ff_reverse_transform_mock.call_args[0]
2033+
np.testing.assert_array_equal(
2034+
ff_reverse_transform_args[0], np.array([sigmoid_vals, is_null]).T
2035+
)
2036+
assert reversed == ff_reverse_transform_mock.return_value

0 commit comments

Comments
 (0)