Skip to content

Commit ed11efc

Browse files
committed
Add LogScaler transformer
1 parent 2317e28 commit ed11efc

File tree

5 files changed

+509
-0
lines changed

5 files changed

+509
-0
lines changed

rdt/transformers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
ClusterBasedNormalizer,
2929
FloatFormatter,
3030
GaussianNormalizer,
31+
LogScaler,
3132
LogitScaler,
3233
)
3334
from rdt.transformers.pii.anonymizer import (
@@ -47,6 +48,7 @@
4748
'FrequencyEncoder',
4849
'GaussianNormalizer',
4950
'LabelEncoder',
51+
'LogScaler',
5052
'NullTransformer',
5153
'OneHotEncoder',
5254
'OptimizedTimestampEncoder',

rdt/transformers/numerical.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,3 +731,120 @@ def _reverse_transform(self, data):
731731

732732
data[:, 0] = reversed_values
733733
return super()._reverse_transform(data)
734+
735+
736+
class LogScaler(FloatFormatter):
    """Transformer for numerical data using log.

    This transformer scales numerical values using log and an optional constant.
    Null values are replaced using a ``NullTransformer``.

    Args:
        missing_value_replacement (object):
            Indicate what to replace the null values with. If an integer or float is given,
            replace them with the given value. If the strings ``'mean'`` or ``'mode'``
            are given, replace them with the corresponding aggregation and if ``'random'``
            replace each null value with a random value in the data range. Defaults to ``mean``.
        missing_value_generation (str or None):
            The way missing values are being handled. There are three strategies:

                * ``random``: Randomly generates missing values based on the percentage of
                  missing values.
                * ``from_column``: Creates a binary column that describes whether the original
                  value was missing. Then use it to recreate missing values.
                * ``None``: Do nothing with the missing values on the reverse transform. Simply
                  pass whatever data we get through.
        constant (float):
            The constant to set as the 0-value for the log-based transform. Defaults to 0
            (do not modify the 0-value of the data).
        invert (bool):
            Whether to invert the data with respect to the constant value. If False, do not
            invert the data (all values will be greater than the constant value). If True,
            invert the data (all the values will be less than the constant value).
            Defaults to False.
        learn_rounding_scheme (bool):
            Whether or not to learn what place to round to based on the data seen during ``fit``.
            If ``True``, the data returned by ``reverse_transform`` will be rounded to that place.
            Defaults to ``False``.
    """

    def __init__(
        self,
        missing_value_replacement='mean',
        missing_value_generation='random',
        constant: float = 0.0,
        invert: bool = False,
        learn_rounding_scheme: bool = False,
    ):
        # NOTE: ``bool`` is a subclass of ``int``, so check for it explicitly —
        # otherwise ``LogScaler(constant=True)`` would be silently accepted as 1,
        # which is inconsistent with the strict type check applied to ``invert``.
        if isinstance(constant, (int, float)) and not isinstance(constant, bool):
            self.constant = constant
        else:
            raise ValueError('The constant parameter must be a float or int.')
        if isinstance(invert, bool):
            self.invert = invert
        else:
            raise ValueError('The invert parameter must be a bool.')

        super().__init__(
            missing_value_replacement=missing_value_replacement,
            missing_value_generation=missing_value_generation,
            learn_rounding_scheme=learn_rounding_scheme,
        )

    def _validate_data(self, data: pd.Series):
        """Ensure every value is on the correct side of ``constant``.

        The log transform is only defined for a strictly positive argument,
        so the data must be strictly greater than ``constant`` (or strictly
        less, when ``invert=True``). Raises ``InvalidDataError`` otherwise.
        """
        column_name = self.get_input_column()
        if self.invert:
            # All values must lie strictly below the constant so that
            # ``constant - data`` is strictly positive.
            if not all(data < self.constant):
                raise InvalidDataError(
                    f"Unable to apply a log transform to column '{column_name}' due to constant"
                    ' being too small.'
                )
        else:
            # All values must lie strictly above the constant so that
            # ``data - constant`` is strictly positive.
            if not all(data > self.constant):
                raise InvalidDataError(
                    f"Unable to apply a log transform to column '{column_name}' due to constant"
                    ' being too large.'
                )

    def _fit(self, data):
        """Fit the parent ``FloatFormatter`` and validate the data range.

        Validation runs on the *transformed* data (after missing-value
        replacement) since that is what ``_log_transform`` will receive.
        """
        super()._fit(data)
        data = super()._transform(data)

        # ``missing_value_generation='from_column'`` yields a 2D array where
        # column 0 holds the values; otherwise the data stays 1D.
        if data.ndim > 1:
            self._validate_data(data[:, 0])
        else:
            self._validate_data(data)

    def _log_transform(self, data):
        """Apply the log transform to already-validated numeric data."""
        # Re-validate at transform time: new data may fall outside the
        # range seen during ``fit``.
        self._validate_data(data)

        if self.invert:
            return np.log(self.constant - data)
        else:
            return np.log(data - self.constant)

    def _transform(self, data):
        """Replace nulls via the parent transformer, then log-scale the values."""
        data = super()._transform(data)

        if data.ndim > 1:
            data[:, 0] = self._log_transform(data[:, 0])
        else:
            data = self._log_transform(data)

        return data

    def _reverse_log(self, data):
        """Invert ``_log_transform``: exponentiate and undo the constant shift."""
        if self.invert:
            return self.constant - np.exp(data)
        else:
            return np.exp(data) + self.constant

    def _reverse_transform(self, data):
        """Undo the log scaling, then let ``FloatFormatter`` restore nulls/rounding."""
        if not isinstance(data, np.ndarray):
            data = data.to_numpy()

        if data.ndim > 1:
            data[:, 0] = self._reverse_log(data[:, 0])
        else:
            data = self._reverse_log(data)

        return super()._reverse_transform(data)

tests/integration/test_transformers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
TEST_COL = 'test_col'
1313

1414
PRIMARY_SDTYPES = ['boolean', 'categorical', 'datetime', 'numerical']
15+
INT64_MIN = np.iinfo(np.int64).min
1516

1617
# Additional arguments for transformers
1718
TRANSFORMER_ARGS = {
@@ -24,6 +25,7 @@
2425
'FloatFormatter': {'missing_value_generation': 'from_column'},
2526
'GaussianNormalizer': {'missing_value_generation': 'from_column'},
2627
'ClusterBasedNormalizer': {'missing_value_generation': 'from_column'},
28+
'LogScaler': {'constant': INT64_MIN, 'missing_value_generation': 'from_column'},
2729
'LogitScaler': {
2830
'missing_value_generation': 'from_column',
2931
'FROM_DATA': {

tests/integration/transformers/test_numerical.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
FloatFormatter,
99
GaussianNormalizer,
1010
LogitScaler,
11+
LogScaler,
1112
)
1213

1314

@@ -617,3 +618,61 @@ def test_missing_value_generation_random(self):
617618

618619
# Assert
619620
np.testing.assert_array_almost_equal(reversed_values, expected)
621+
622+
623+
class TestLogScaler:
    def test_learn_rounding(self):
        """Test that transformer learns rounding scheme from data."""
        # Setup
        input_data = pd.DataFrame({'test': [1.0, np.nan, 1.5]})
        scaler = LogScaler(
            missing_value_generation=None,
            missing_value_replacement='mean',
            learn_rounding_scheme=True,
        )
        expected_output = pd.DataFrame({'test': [1.0, 1.2, 1.5]})

        # Run
        scaler.fit(input_data, 'test')
        recovered = scaler.reverse_transform(scaler.transform(input_data))

        # Assert
        np.testing.assert_array_equal(recovered, expected_output)

    def test_missing_value_generation_from_column(self):
        """Test from_column missing value generation with nans present."""
        # Setup
        input_data = pd.DataFrame({'test': [1.0, np.nan, 1.5]})
        scaler = LogScaler(
            missing_value_generation='from_column',
            missing_value_replacement='mean',
        )

        # Run
        scaler.fit(input_data, 'test')
        recovered = scaler.reverse_transform(scaler.transform(input_data))

        # Assert
        np.testing.assert_array_equal(recovered, input_data)

    def test_missing_value_generation_random(self):
        """Test random missing_value_generation with nans present."""
        # Setup
        input_data = pd.DataFrame({'test': [1.0, np.nan, 1.5, 1.5]})
        scaler = LogScaler(
            missing_value_generation='random',
            missing_value_replacement='mode',
            invert=True,
            constant=3.0,
        )
        expected_output = pd.DataFrame({'test': [np.nan, 1.5, 1.5, 1.5]})

        # Run
        scaler.fit(input_data, 'test')
        recovered = scaler.reverse_transform(scaler.transform(input_data))

        # Assert
        np.testing.assert_array_equal(recovered, expected_output)

0 commit comments

Comments
 (0)