|
9 | 9 | from copulas import univariate |
10 | 10 | from pandas.api.types import is_float_dtype |
11 | 11 |
|
12 | | -from rdt.errors import TransformerInputError |
| 12 | +from rdt.errors import InvalidDataError, TransformerInputError |
13 | 13 | from rdt.transformers.null import NullTransformer |
14 | 14 | from rdt.transformers.numerical import ( |
15 | 15 | ClusterBasedNormalizer, |
16 | 16 | FloatFormatter, |
17 | 17 | GaussianNormalizer, |
| 18 | + LogScaler, |
18 | 19 | ) |
19 | 20 |
|
20 | 21 |
|
@@ -1863,3 +1864,201 @@ def test__reverse_transform_missing_value_replacement_missing_value_replacement_ |
1863 | 1864 | call_data, |
1864 | 1865 | rtol=1e-1, |
1865 | 1866 | ) |
| 1867 | + |
| 1868 | + |
class TestLogScaler:
    """Unit tests for the ``LogScaler`` transformer."""

    def test___init__super_attrs(self):
        """Constructor keyword arguments and the default replacement are exposed as attributes."""
        # Setup
        init_kwargs = {
            'missing_value_generation': 'random',
            'learn_rounding_scheme': False,
        }

        # Run
        scaler = LogScaler(**init_kwargs)

        # Assert
        assert scaler.missing_value_replacement == 'mean'
        assert scaler.missing_value_generation == 'random'
        assert scaler.learn_rounding_scheme is False

    def test___init__constant(self):
        """``constant`` is stored as given and equals ``0.0`` when omitted."""
        # Setup
        custom_scaler = LogScaler(constant=2.5)
        default_scaler = LogScaler()

        # Assert
        assert custom_scaler.constant == 2.5
        assert default_scaler.constant == 0.0

    def test___init__invert(self):
        """``invert`` is truthy when requested and falsy when omitted."""
        # Setup
        inverted_scaler = LogScaler(invert=True)
        default_scaler = LogScaler()

        # Assert
        assert inverted_scaler.invert
        assert not default_scaler.invert

    def test__validate_data(self):
        """``_validate_data`` accepts positive data and raises when a value is negative."""
        # Setup
        scaler = LogScaler()
        scaler.columns = ['test_col']
        expected_error = (
            "Unable to apply a log transform to column 'test_col' due to constant being too large."
        )

        # Run and Assert
        # Positive-only data must pass without raising.
        scaler._validate_data(pd.Series([1, 2, 3]))

        with pytest.raises(InvalidDataError, match=expected_error):
            scaler._validate_data(pd.Series([-1, 2, 4]))

    def test__validate_data_invert(self):
        """With ``invert=True``, negative data is accepted and positive values raise."""
        # Setup
        scaler = LogScaler(invert=True)
        scaler.columns = ['test']
        expected_error = (
            "Unable to apply a log transform to column 'test' due to constant being too small."
        )

        # Run and Assert
        # Negative-only data must pass without raising.
        scaler._validate_data(pd.Series([-1, -2, -3]))

        with pytest.raises(InvalidDataError, match=expected_error):
            scaler._validate_data(pd.Series([-1, 2, 4]))

    @patch('rdt.transformers.LogScaler._validate_data')
    def test__fit(self, mock_validate):
        """``_fit`` validates the data once, after the NaN was replaced by the mean (0.75)."""
        # Setup
        scaler = LogScaler()
        data_with_nan = pd.Series([0.5, np.nan, 1.0])

        # Run
        scaler._fit(data_with_nan)

        # Assert
        mock_validate.assert_called_once()
        positional_args, _keyword_args = mock_validate.call_args_list[0]
        np.testing.assert_array_equal(positional_args[0], np.array([0.5, 0.75, 1.0]))

    def test__transform(self):
        """``transform`` returns the natural log of each value with the default settings."""
        # Setup
        scaler = LogScaler()
        fit_frame = pd.DataFrame({'test': [0.25, 0.5, 0.75]})
        scaler.fit(fit_frame, 'test')
        frame = pd.DataFrame({'test': [0.1, 1.0, 2.0]})

        # Run
        result = scaler.transform(frame)

        # Assert
        expected = np.array([-2.30259, 0, 0.69314])
        np.testing.assert_allclose(result['test'], expected, rtol=1e-3)

    def test__transform_invert(self):
        """With ``constant=3`` and ``invert=True`` the output matches ``log(3 - x)``."""
        # Setup
        scaler = LogScaler(constant=3, invert=True)
        fit_frame = pd.DataFrame({'test': [0.25, 0.5, 0.75]})
        scaler.fit(fit_frame, 'test')
        frame = pd.DataFrame({'test': [0.1, 1.0, 2.0]})

        # Run
        result = scaler.transform(frame)

        # Assert
        expected = np.array([1.06471, 0.69315, 0])
        np.testing.assert_allclose(result['test'], expected, rtol=1e-3)

    def test__transform_invalid_data(self):
        """``transform`` raises ``InvalidDataError`` when the new data has a negative value."""
        # Setup
        scaler = LogScaler()
        fit_frame = pd.DataFrame({'test': [0.25, 0.5, 0.75]})
        scaler.fit(fit_frame, 'test')
        bad_frame = pd.DataFrame({'test': [-0.1, 1.0, 2.0]})
        expected_error = (
            "Unable to apply a log transform to column 'test' due to constant being too large."
        )

        # Run and Assert
        with pytest.raises(InvalidDataError, match=expected_error):
            scaler.transform(bad_frame)

    def test__transform_missing_value_generation_is_random(self):
        """Test the ``_transform`` method with a ``random`` null transformer.

        The input has no missing values, so the output is expected to be the
        plain natural log of the input when ``missing_value_generation`` is ``random``.
        """
        # Setup
        values = pd.Series([1.0, 2.0, 1.0])
        null_transformer = NullTransformer('mean', missing_value_generation='random')
        null_transformer.fit(values)

        scaler = LogScaler()
        scaler.columns = ['test']
        scaler.null_transformer = null_transformer

        # Run
        result = scaler._transform(values)

        # Assert
        expected = np.array([0, 0.69315, 0])
        np.testing.assert_allclose(result, expected, rtol=1e-3)

    def test__reverse_transform(self):
        """Test the ``_reverse_transform`` method with ``from_column`` generation.

        The input carries the log values in its first column and a null
        indicator in its second one; the row flagged as null is expected to
        come back as ``NaN``.
        """
        # Setup
        log_values = [0, 0.6931471805599453, 0]
        null_flags = [0, 0, 1.0]
        data = np.array([log_values, null_flags]).T
        expected = pd.Series([1.0, 2.0, np.nan])

        null_transformer = NullTransformer(
            missing_value_replacement='mean',
            missing_value_generation='from_column',
        )
        null_transformer.fit(expected)

        scaler = LogScaler()
        scaler.null_transformer = null_transformer

        # Run
        result = scaler._reverse_transform(data)

        # Assert
        np.testing.assert_allclose(result, expected, rtol=1e-3)

    def test__reverse_transform_missing_value_generation(self):
        """Test the ``_reverse_transform`` method with ``random`` generation.

        The fitted data has no missing values, so the output is expected to be
        the exponential of the one-dimensional input.
        """
        # Setup
        data = np.array([0, 0.6931471805599453, 0])
        expected = pd.Series([1.0, 2.0, 1.0])

        null_transformer = NullTransformer(None, missing_value_generation='random')
        null_transformer.fit(expected)

        scaler = LogScaler()
        scaler.null_transformer = null_transformer

        # Run
        result = scaler._reverse_transform(data)

        # Assert
        np.testing.assert_allclose(result, expected, rtol=1e-3)

    def test_print(self, capsys):
        """Printing a default instance writes ``LogScaler()`` to stdout. GH#883"""
        # Setup
        scaler = LogScaler()

        # Run
        print(scaler)  # noqa: T201 `print` found

        # Assert
        printed = capsys.readouterr().out
        assert printed == 'LogScaler()\n'
0 commit comments