Skip to content

Commit 07ac708

Browse files
committed
Add LogitScaler transformer
1 parent d6ea5d1 commit 07ac708

File tree

5 files changed

+367
-3
lines changed

5 files changed

+367
-3
lines changed

rdt/transformers/numerical.py

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
import pandas as pd
99
import scipy
1010

11-
from rdt.errors import TransformerInputError
11+
from rdt.errors import InvalidDataError, TransformerInputError
1212
from rdt.transformers.base import BaseTransformer
1313
from rdt.transformers.null import NullTransformer
14-
from rdt.transformers.utils import learn_rounding_digits
14+
from rdt.transformers.utils import learn_rounding_digits, logit, sigmoid
1515

1616
EPSILON = np.finfo(np.float32).eps
1717
INTEGER_BOUNDS = {
@@ -626,3 +626,85 @@ def _reverse_transform(self, data):
626626
recovered_data = np.stack([recovered_data, data[:, -1]], axis=1) # noqa: PD013
627627

628628
return super()._reverse_transform(recovered_data)
629+
630+
631+
class LogitScaler(FloatFormatter):
632+
"""Transformer for numerical data by applying a logit function.
633+
634+
This transformer works by replacing the values with a scaled
635+
version and then applying a logit function. The reverse transform
636+
applies a sigmoid to the data and then scales it back to the original space.
637+
638+
Null values are replaced using a ``NullTransformer``.
639+
640+
Args:
641+
missing_value_replacement (object):
642+
Indicate what to replace the null values with. If an integer or float is given,
643+
replace them with the given value. If the strings ``'mean'`` or ``'mode'``
644+
are given, replace them with the corresponding aggregation and if ``'random'``
645+
replace each null value with a random value in the data range. Defaults to ``mean``.
646+
missing_value_generation (str or None):
647+
The way missing values are being handled. There are three strategies:
648+
649+
* ``random``: Randomly generates missing values based on the percentage of
650+
missing values.
651+
* ``from_column``: Creates a binary column that describes whether the original
652+
value was missing. Then use it to recreate missing values.
653+
* ``None``: Do nothing with the missing values on the reverse transform. Simply
654+
pass whatever data we get through.
655+
min_value (float):
656+
The min value for the logit function. Defaults to 0.
657+
max_value (float):
658+
max_value (float): The max value for the logit function. Defaults to 1.0.
659+
learn_rounding_scheme (bool):
660+
Whether or not to learn what place to round to based on the data seen during ``fit``.
661+
If ``True``, the data returned by ``reverse_transform`` will be rounded to that place.
662+
Defaults to ``False``.
663+
"""
664+
665+
def __init__(
666+
self,
667+
missing_value_replacement='mean',
668+
missing_value_generation='random',
669+
min_value=0.0,
670+
max_value=1.0,
671+
learn_rounding_scheme=False,
672+
):
673+
super().__init__(
674+
missing_value_replacement=missing_value_replacement,
675+
missing_value_generation=missing_value_generation,
676+
learn_rounding_scheme=learn_rounding_scheme,
677+
)
678+
self.min_value = min_value
679+
self.max_value = max_value
680+
681+
def _validate_logit_inputs(self, data):
682+
out_of_range_vals = data[(data < self.min_value) | (data > self.max_value)]
683+
if len(out_of_range_vals) > 0:
684+
num_vals_to_print = 5
685+
out_of_range_vals = [str(x) for x in sorted(out_of_range_vals, key=lambda x: str(x))]
686+
if len(out_of_range_vals) > 5:
687+
extra_missing_vals = f'+ {len(out_of_range_vals) - num_vals_to_print} more'
688+
out_of_range_vals = (
689+
f'[{", ".join(out_of_range_vals[:num_vals_to_print])} {extra_missing_vals}]'
690+
)
691+
else:
692+
out_of_range_vals = f'[{", ".join(out_of_range_vals)}]'
693+
694+
raise InvalidDataError(
695+
f"Unable to apply logit function to column '{self.columns[0]}' due to out of "
696+
f'range values ({out_of_range_vals}).'
697+
)
698+
699+
def _fit(self, data):
700+
self._validate_logit_inputs(data)
701+
return super()._fit(data)
702+
703+
def _transform(self, data):
704+
transformed = super()._transform(data)
705+
self._validate_logit_inputs(transformed)
706+
return logit(transformed, self.min_value, self.max_value)
707+
708+
def _reverse_transform(self, data):
709+
reversed = sigmoid(data, self.min_value, self.max_value)
710+
return super()._reverse_transform(reversed)

rdt/transformers/utils.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import sys
77
import warnings
88
from collections import defaultdict
9+
from decimal import Decimal
910

1011
import numpy as np
1112
import pandas as pd
@@ -126,6 +127,17 @@ def _from_generators(generators, max_repeat):
126127
yield ''.join(reversed(generated))
127128

128129

130+
def _cast_to_type(data, dtype):
131+
if isinstance(data, pd.Series):
132+
data = data.apply(dtype)
133+
elif isinstance(data, (np.ndarray, list)):
134+
data = np.array([dtype(value) for value in data])
135+
else:
136+
data = dtype(data)
137+
138+
return data
139+
140+
129141
def strings_from_regex(regex, max_repeat=16):
130142
"""Generate strings that match the given regular expression.
131143
@@ -280,6 +292,50 @@ def learn_rounding_digits(data):
280292
return None
281293

282294

295+
def logit(data, low, high):
296+
"""Apply a logit function to the data using ``low`` and ``high``.
297+
298+
Args:
299+
data (pd.Series, pd.DataFrame, np.array, int, or float):
300+
Data to apply the logit function to.
301+
low (pd.Series, np.array, int, or float):
302+
Low value/s to use when scaling.
303+
high (pd.Series, np.array, int, or float):
304+
High value/s to use when scaling.
305+
306+
Returns:
307+
Logit scaled version of the input data.
308+
"""
309+
data = (data - low) / (high - low)
310+
data = _cast_to_type(data, Decimal)
311+
data = data * Decimal(0.95) + Decimal(0.025)
312+
data = _cast_to_type(data, float)
313+
return np.log(data / (1.0 - data))
314+
315+
316+
def sigmoid(data, low, high):
317+
"""Apply a sigmoid function to the data using ``low`` and ``high``.
318+
319+
Args:
320+
data (pd.Series, pd.DataFrame, np.array, int, float or datetime):
321+
Data to apply the logit function to.
322+
low (pd.Series, np.array, int, float or datetime):
323+
Low value/s to use when scaling.
324+
high (pd.Series, np.array, int, float or datetime):
325+
High value/s to use when scaling.
326+
327+
Returns:
328+
Sigmoid transform of the input data.
329+
"""
330+
data = 1 / (1 + np.exp(-data))
331+
data = _cast_to_type(data, Decimal)
332+
data = (data - Decimal(0.025)) / Decimal(0.95)
333+
data = _cast_to_type(data, float)
334+
data = data * (high - low) + low
335+
336+
return data
337+
338+
283339
class WarnDict(dict):
284340
"""Custom dictionary to raise a deprecation warning."""
285341

tests/integration/test_transformers.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from collections import defaultdict
22

3+
import numpy as np
34
import pandas as pd
45
import pytest
56

@@ -23,6 +24,12 @@
2324
'FloatFormatter': {'missing_value_generation': 'from_column'},
2425
'GaussianNormalizer': {'missing_value_generation': 'from_column'},
2526
'ClusterBasedNormalizer': {'missing_value_generation': 'from_column'},
27+
'LogitScaler': {
28+
'FROM_DATA': {
29+
'min_value': lambda x: np.nanmin(x) - 1,
30+
'max_value': lambda x: np.nanmax(x) + 1,
31+
}
32+
},
2633
}
2734

2835
# Mapping of rdt sdtype to dtype
@@ -149,6 +156,12 @@ def _test_transformer_with_dataset(transformer_class, input_data, steps):
149156
"""
150157

151158
transformer_args = TRANSFORMER_ARGS.get(transformer_class.__name__, {})
159+
if 'FROM_DATA' in transformer_args:
160+
transformer_args = {**transformer_args}
161+
args = transformer_args.pop('FROM_DATA')
162+
for arg, arg_func in args.items():
163+
transformer_args[arg] = arg_func(input_data[TEST_COL])
164+
152165
transformer = transformer_class(**transformer_args)
153166
# Fit
154167
transformer.fit(input_data, [TEST_COL])
@@ -203,6 +216,12 @@ def _test_transformer_with_hypertransformer(transformer_class, input_data, steps
203216
transformer_args = TRANSFORMER_ARGS.get(transformer_class.__name__, {})
204217
hypertransformer = HyperTransformer()
205218
if transformer_args:
219+
if 'FROM_DATA' in transformer_args:
220+
transformer_args = {**transformer_args}
221+
args = transformer_args.pop('FROM_DATA')
222+
for arg, arg_func in args.items():
223+
transformer_args[arg] = arg_func(input_data[TEST_COL])
224+
206225
field_transformers = {TEST_COL: transformer_class(**transformer_args)}
207226

208227
else:

tests/unit/transformers/test_numerical.py

Lines changed: 119 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@
99
from copulas import univariate
1010
from pandas.api.types import is_float_dtype
1111

12-
from rdt.errors import TransformerInputError
12+
from rdt.errors import InvalidDataError, TransformerInputError
1313
from rdt.transformers.null import NullTransformer
1414
from rdt.transformers.numerical import (
1515
ClusterBasedNormalizer,
1616
FloatFormatter,
1717
GaussianNormalizer,
18+
LogitScaler,
1819
)
1920

2021

@@ -1863,3 +1864,120 @@ def test__reverse_transform_missing_value_replacement_missing_value_replacement_
18631864
call_data,
18641865
rtol=1e-1,
18651866
)
1867+
1868+
1869+
class TestLogitScaler:
1870+
def test___init__super_attrs(self):
1871+
"""Test super() arguments are properly passed and set as attributes."""
1872+
# Run
1873+
ls = LogitScaler(
1874+
missing_value_generation='random',
1875+
learn_rounding_scheme=False,
1876+
)
1877+
1878+
# Assert
1879+
assert ls.missing_value_replacement == 'mean'
1880+
assert ls.missing_value_generation == 'random'
1881+
assert ls.learn_rounding_scheme is False
1882+
1883+
def test___init__(self):
1884+
"""Test super() arguments are properly passed and set as attributes."""
1885+
# Run
1886+
ls = LogitScaler(max_value=100.0, min_value=2.0)
1887+
1888+
# Assert
1889+
assert ls.max_value == 100.0
1890+
assert ls.min_value == 2.0
1891+
1892+
def test__validate_logit_inputs(self):
1893+
"""Test validating data against input arguments."""
1894+
# Setup
1895+
ls = LogitScaler()
1896+
data = pd.Series([0.0, 0.1, 0.2, 0.3, 1.0])
1897+
1898+
# Run and Assert
1899+
ls._validate_logit_inputs(data)
1900+
1901+
def test__validate_logit_inputs_errors_invalid_value(self):
1902+
"""Test error message contains invalid values."""
1903+
# Setup
1904+
ls = LogitScaler()
1905+
ls.columns = ['column']
1906+
data = pd.Series([0.0, 0.1, 0.2, 0.3, 1.0, 2.0])
1907+
1908+
# Run and Assert
1909+
expected_msg = re.escape(
1910+
"Unable to apply logit function to column 'column' due to out of range values ([2.0])."
1911+
)
1912+
with pytest.raises(InvalidDataError, match=expected_msg):
1913+
ls._validate_logit_inputs(data)
1914+
1915+
def test__validate_logit_inputs_errors_many_invalid_values(self):
1916+
"""Test error message clips many invalid values."""
1917+
# Setup
1918+
ls = LogitScaler()
1919+
ls.columns = ['column']
1920+
data = pd.Series([1.0, 1.1, 1.2, 1.3, 2.0, 3.0, 4.0])
1921+
1922+
# Run and Assert
1923+
expected_msg = re.escape(
1924+
"Unable to apply logit function to column 'column' due to out of range values "
1925+
'([1.1, 1.2, 1.3, 2.0, 3.0 + 1 more]).'
1926+
)
1927+
with pytest.raises(InvalidDataError, match=expected_msg):
1928+
ls._validate_logit_inputs(data)
1929+
1930+
def test__fit(self):
1931+
"""Test the ``_fit`` method validates the inputs."""
1932+
# Setup
1933+
ls = LogitScaler()
1934+
ls._validate_logit_inputs = Mock()
1935+
data = pd.Series([1.0, 1.1, 1.2, 1.3, 2.0, 3.0, 4.0])
1936+
1937+
# Run
1938+
ls._fit(data)
1939+
1940+
# Assert
1941+
ls._validate_logit_inputs.assert_called_once_with(data)
1942+
1943+
@patch('rdt.transformers.numerical.logit')
1944+
def test__transform(self, mock_logit):
1945+
"""Test the ``transform`` method."""
1946+
# Setup
1947+
min_value = (1.0,)
1948+
max_value = 50.0
1949+
ls = LogitScaler(min_value=min_value, max_value=max_value)
1950+
ls._validate_logit_inputs = Mock()
1951+
data = pd.Series([1.0, 1.1, 1.2, 1.3, 2.0, 3.0, 4.0])
1952+
null_transformer_mock = Mock()
1953+
null_transformer_mock.transform.return_value = data
1954+
ls.null_transformer = null_transformer_mock
1955+
1956+
# Run
1957+
transformed = ls._transform(data)
1958+
1959+
# Assert
1960+
ls._validate_logit_inputs.assert_called_once_with(data)
1961+
mock_logit.assert_called_once_with(data, ls.min_value, ls.max_value)
1962+
assert transformed == mock_logit.return_value
1963+
1964+
@patch('rdt.transformers.numerical.FloatFormatter._reverse_transform')
1965+
@patch('rdt.transformers.numerical.sigmoid')
1966+
def test__reverse_transform(self, mock_sigmoid, ff_reverse_transform_mock):
1967+
"""Test the ``transform`` method."""
1968+
# Setup
1969+
min_value = (1.0,)
1970+
max_value = 50.0
1971+
ls = LogitScaler(min_value=min_value, max_value=max_value)
1972+
data = pd.Series([1.0, 1.1, 1.2, 1.3, 2.0, 3.0, 4.0])
1973+
null_transformer_mock = Mock()
1974+
null_transformer_mock.reverse_transform.return_value = data
1975+
ls.null_transformer = null_transformer_mock
1976+
1977+
# Run
1978+
reversed = ls._reverse_transform(data)
1979+
1980+
# Assert
1981+
mock_sigmoid.assert_called_once_with(data, ls.min_value, ls.max_value)
1982+
ff_reverse_transform_mock.assert_called_once_with(mock_sigmoid.return_value)
1983+
assert reversed == ff_reverse_transform_mock.return_value

0 commit comments

Comments
 (0)