Skip to content

Commit f1dbfa5

Browse files
committed
LogScaler class and some tests
1 parent d6ea5d1 commit f1dbfa5

File tree

4 files changed

+315
-2
lines changed

4 files changed

+315
-2
lines changed

rdt/transformers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
ClusterBasedNormalizer,
2929
FloatFormatter,
3030
GaussianNormalizer,
31+
LogScaler,
3132
)
3233
from rdt.transformers.pii.anonymizer import (
3334
AnonymizedFaker,
@@ -46,6 +47,7 @@
4647
'FrequencyEncoder',
4748
'GaussianNormalizer',
4849
'LabelEncoder',
50+
'LogScaler',
4951
'NullTransformer',
5052
'OneHotEncoder',
5153
'OptimizedTimestampEncoder',

rdt/transformers/numerical.py

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pandas as pd
99
import scipy
1010

11-
from rdt.errors import TransformerInputError
11+
from rdt.errors import InvalidDataError, TransformerInputError
1212
from rdt.transformers.base import BaseTransformer
1313
from rdt.transformers.null import NullTransformer
1414
from rdt.transformers.utils import learn_rounding_digits
@@ -626,3 +626,114 @@ def _reverse_transform(self, data):
626626
recovered_data = np.stack([recovered_data, data[:, -1]], axis=1) # noqa: PD013
627627

628628
return super()._reverse_transform(recovered_data)
629+
630+
631+
class LogScaler(FloatFormatter):
632+
"""Transformer for numerical data using log.
633+
634+
This transformer replaces integer values with their float equivalent.
635+
Non null float values are not modified.
636+
637+
Null values are replaced using a ``NullTransformer``.
638+
639+
Args:
640+
missing_value_replacement (object):
641+
Indicate what to replace the null values with. If an integer or float is given,
642+
replace them with the given value. If the strings ``'mean'`` or ``'mode'``
643+
are given, replace them with the corresponding aggregation and if ``'random'``
644+
replace each null value with a random value in the data range. Defaults to ``mean``.
645+
missing_value_generation (str or None):
646+
The way missing values are being handled. There are three strategies:
647+
648+
* ``random``: Randomly generates missing values based on the percentage of
649+
missing values.
650+
* ``from_column``: Creates a binary column that describes whether the original
651+
value was missing. Then use it to recreate missing values.
652+
* ``None``: Do nothing with the missing values on the reverse transform. Simply
653+
pass whatever data we get through.
654+
constant (float):
655+
The constant to set as the 0-value for the log-based transform. Default to 0
656+
(do not modify the 0-value of the data).
657+
invert (bool):
658+
Whether to invert the data with respect to the constant value. If False, do not
659+
invert the data (all values will be greater than the constant value). If True,
660+
invert the data (all the values will be less than the constant value).
661+
Defaults to False.
662+
learn_rounding_scheme (bool):
663+
Whether or not to learn what place to round to based on the data seen during ``fit``.
664+
If ``True``, the data returned by ``reverse_transform`` will be rounded to that place.
665+
Defaults to ``False``.
666+
"""
667+
668+
def __init__(
669+
self,
670+
missing_value_replacement='mean',
671+
missing_value_generation='random',
672+
constant: float = 0,
673+
invert: bool = False,
674+
learn_rounding_scheme: bool = False,
675+
):
676+
self.constant = constant
677+
self.invert = invert
678+
super().__init__(
679+
missing_value_replacement=missing_value_replacement,
680+
missing_value_generation=missing_value_generation,
681+
learn_rounding_scheme=learn_rounding_scheme,
682+
)
683+
684+
def _validate_data(self, data: pd.Series):
685+
column_name = self.get_input_column()
686+
if self.invert:
687+
if not all(data < self.constant):
688+
raise InvalidDataError(
689+
f"Unable to apply a log transform to column '{column_name}' due to constant"
690+
' being too small.'
691+
)
692+
else:
693+
if not all(data > self.constant):
694+
raise InvalidDataError(
695+
f"Unable to apply a log transform to column '{column_name}' due to constant"
696+
' being too large.'
697+
)
698+
699+
def _fit(self, data):
700+
super()._fit(data)
701+
data = super()._transform(data)
702+
if data.ndim > 1:
703+
self._validate_data(data[:, 0])
704+
else:
705+
self._validate_data(data)
706+
707+
def _transform(self, data):
708+
data = super()._transform(data)
709+
710+
if data.ndim > 1:
711+
self._validate_data(data[:, 0])
712+
if self.invert:
713+
data[:, 0] = np.log(self.constant - data[:, 0])
714+
else:
715+
data[:, 0] = np.log(data[:, 0] - self.constant)
716+
else:
717+
self._validate_data(data)
718+
if self.invert:
719+
data = np.log(self.constant - data)
720+
else:
721+
data = np.log(data - self.constant)
722+
return data
723+
724+
def _reverse_transform(self, data):
725+
if not isinstance(data, np.ndarray):
726+
data = data.to_numpy()
727+
728+
if data.ndim > 1:
729+
if self.invert:
730+
data[:, 0] = self.constant - np.exp(data[:, 0])
731+
else:
732+
data[:, 0] = np.exp(data[:, 0]) + self.constant
733+
else:
734+
if self.invert:
735+
data = self.constant - np.exp(data)
736+
else:
737+
data = np.exp(data) + self.constant
738+
739+
return super()._reverse_transform(data)

tests/integration/test_transformers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
'FloatFormatter': {'missing_value_generation': 'from_column'},
2424
'GaussianNormalizer': {'missing_value_generation': 'from_column'},
2525
'ClusterBasedNormalizer': {'missing_value_generation': 'from_column'},
26+
'LogScaler': {'constant': -40000000000, 'missing_value_generation': 'from_column'},
2627
}
2728

2829
# Mapping of rdt sdtype to dtype

tests/unit/transformers/test_numerical.py

Lines changed: 200 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@
99
from copulas import univariate
1010
from pandas.api.types import is_float_dtype
1111

12-
from rdt.errors import TransformerInputError
12+
from rdt.errors import InvalidDataError, TransformerInputError
1313
from rdt.transformers.null import NullTransformer
1414
from rdt.transformers.numerical import (
1515
ClusterBasedNormalizer,
1616
FloatFormatter,
1717
GaussianNormalizer,
18+
LogScaler,
1819
)
1920

2021

@@ -1863,3 +1864,201 @@ def test__reverse_transform_missing_value_replacement_missing_value_replacement_
18631864
call_data,
18641865
rtol=1e-1,
18651866
)
1867+
1868+
1869+
class TestLogScaler:
1870+
def test___init__super_attrs(self):
1871+
"""super() arguments are properly passed and set as attributes."""
1872+
ls = LogScaler(
1873+
missing_value_generation='random',
1874+
learn_rounding_scheme=False,
1875+
)
1876+
1877+
assert ls.missing_value_replacement == 'mean'
1878+
assert ls.missing_value_generation == 'random'
1879+
assert ls.learn_rounding_scheme is False
1880+
1881+
def test___init__constant(self):
1882+
"""Test constant parameter is set as an attribute."""
1883+
# Setup
1884+
ls_set = LogScaler(constant=2.5)
1885+
ls_default = LogScaler()
1886+
1887+
# Test
1888+
assert ls_set.constant == 2.5
1889+
assert ls_default.constant == 0.0
1890+
1891+
def test___init__invert(self):
1892+
"""Test invert parameter is set as an attribute."""
1893+
# Setup
1894+
ls_set = LogScaler(invert=True)
1895+
ls_default = LogScaler()
1896+
1897+
# Test
1898+
assert ls_set.invert
1899+
assert not ls_default.invert
1900+
1901+
def test__validate_data(self):
1902+
"""Test the ``_validate_data`` method"""
1903+
# Setup
1904+
ls = LogScaler()
1905+
ls.columns = ['test_col']
1906+
valid_data = pd.Series([1, 2, 3])
1907+
invalid_data = pd.Series([-1, 2, 4])
1908+
message = (
1909+
"Unable to apply a log transform to column 'test_col' due to constant being too large."
1910+
)
1911+
# Run and Assert
1912+
ls._validate_data(valid_data)
1913+
1914+
with pytest.raises(InvalidDataError, match=message):
1915+
ls._validate_data(invalid_data)
1916+
1917+
def test__validate_data_invert(self):
1918+
"""Test the ``_validate_data`` method"""
1919+
# Setup
1920+
ls = LogScaler(invert=True)
1921+
ls.columns = ['test']
1922+
valid_data = pd.Series([-1, -2, -3])
1923+
invalid_data = pd.Series([-1, 2, 4])
1924+
message = (
1925+
"Unable to apply a log transform to column 'test' due to constant being too small."
1926+
)
1927+
1928+
# Run and Assert
1929+
ls._validate_data(valid_data)
1930+
1931+
with pytest.raises(InvalidDataError, match=message):
1932+
ls._validate_data(invalid_data)
1933+
1934+
@patch('rdt.transformers.LogScaler._validate_data')
1935+
def test__fit(self, mock_validate):
1936+
"""Test the ``_fit`` method."""
1937+
# Setup
1938+
data = pd.Series([0.5, np.nan, 1.0])
1939+
ls = LogScaler()
1940+
1941+
# Run
1942+
ls._fit(data)
1943+
1944+
# Assert
1945+
mock_validate.assert_called_once()
1946+
call_value = mock_validate.call_args_list[0]
1947+
np.testing.assert_array_equal(call_value[0][0], np.array([0.5, 0.75, 1.0]))
1948+
1949+
def test__transform(self):
1950+
"""Test the ``_transform`` method."""
1951+
# Setup
1952+
ls = LogScaler()
1953+
ls.fit(pd.DataFrame({'test': [0.25, 0.5, 0.75]}), 'test')
1954+
data = pd.DataFrame({'test': [0.1, 1.0, 2.0]})
1955+
1956+
# Run
1957+
transformed_data = ls.transform(data)
1958+
1959+
# Assert
1960+
expected = np.array([-2.30259, 0, 0.69314])
1961+
np.testing.assert_allclose(transformed_data['test'], expected, rtol=1e-3)
1962+
1963+
def test__transform_invert(self):
1964+
"""Test the ``_transform`` method with ``invert=True``"""
1965+
# Setup
1966+
ls = LogScaler(constant=3, invert=True)
1967+
ls.fit(pd.DataFrame({'test': [0.25, 0.5, 0.75]}), 'test')
1968+
data = pd.DataFrame({'test': [0.1, 1.0, 2.0]})
1969+
1970+
# Run
1971+
transformed_data = ls.transform(data)
1972+
1973+
# Assert
1974+
expected = np.array([1.06471, 0.69315, 0])
1975+
np.testing.assert_allclose(transformed_data['test'], expected, rtol=1e-3)
1976+
1977+
def test__transform_invalid_data(self):
1978+
# Setup
1979+
ls = LogScaler()
1980+
ls.fit(pd.DataFrame({'test': [0.25, 0.5, 0.75]}), 'test')
1981+
data = pd.DataFrame({'test': [-0.1, 1.0, 2.0]})
1982+
message = (
1983+
"Unable to apply a log transform to column 'test' due to constant being too large."
1984+
)
1985+
1986+
# Run and Assert
1987+
with pytest.raises(InvalidDataError, match=message):
1988+
ls.transform(data)
1989+
1990+
def test__transform_missing_value_generation_is_random(self):
1991+
"""Test the ``_transform`` method.
1992+
1993+
Validate that ``_transform`` produces the correct values when ``missing_value_generation``
1994+
is ``random``.
1995+
"""
1996+
# Setup
1997+
data = pd.Series([1.0, 2.0, 1.0])
1998+
ls = LogScaler()
1999+
ls.columns = ['test']
2000+
ls.null_transformer = NullTransformer('mean', missing_value_generation='random')
2001+
2002+
# Run
2003+
ls.null_transformer.fit(data)
2004+
transformed_data = ls._transform(data)
2005+
2006+
# Assert
2007+
expected = np.array([0, 0.69315, 0])
2008+
np.testing.assert_allclose(transformed_data, expected, rtol=1e-3)
2009+
2010+
def test__reverse_transform(self):
2011+
"""Test the ``_reverse_transform`` method.
2012+
2013+
Validate that ``_reverse_transform`` produces the correct values when
2014+
``missing_value_generation`` is 'from_column'.
2015+
"""
2016+
# Setup
2017+
data = np.array([
2018+
[0, 0.6931471805599453, 0],
2019+
[0, 0, 1.0],
2020+
]).T
2021+
expected = pd.Series([1.0, 2.0, np.nan])
2022+
ls = LogScaler()
2023+
ls.null_transformer = NullTransformer(
2024+
missing_value_replacement='mean',
2025+
missing_value_generation='from_column',
2026+
)
2027+
2028+
# Run
2029+
ls.null_transformer.fit(expected)
2030+
transformed_data = ls._reverse_transform(data)
2031+
2032+
# Assert
2033+
np.testing.assert_allclose(transformed_data, expected, rtol=1e-3)
2034+
2035+
def test__reverse_transform_missing_value_generation(self):
2036+
"""Test the ``_reverse_transform`` method.
2037+
2038+
Validate that ``_reverse_transform`` produces the correct values when
2039+
``missing_value_generation`` is 'random'.
2040+
"""
2041+
# Setup
2042+
data = np.array([0, 0.6931471805599453, 0])
2043+
expected = pd.Series([1.0, 2.0, 1.0])
2044+
ls = LogScaler()
2045+
ls.null_transformer = NullTransformer(None, missing_value_generation='random')
2046+
2047+
# Run
2048+
ls.null_transformer.fit(expected)
2049+
transformed_data = ls._reverse_transform(data)
2050+
2051+
# Assert
2052+
np.testing.assert_allclose(transformed_data, expected, rtol=1e-3)
2053+
2054+
def test_print(self, capsys):
2055+
"""Test the class can be printed. GH#883"""
2056+
# Setup
2057+
transformer = LogScaler()
2058+
2059+
# Run
2060+
print(transformer) # noqa: T201 `print` found
2061+
2062+
# Assert
2063+
captured = capsys.readouterr()
2064+
assert captured.out == 'LogScaler()\n'

0 commit comments

Comments
 (0)