|
| 1 | +import tempfile |
| 2 | +import pytest |
| 3 | +import numpy as np |
| 4 | +from pandas import DataFrame |
| 5 | +import joblib |
| 6 | + |
| 7 | +from sklearn_pandas import DataFrameMapper |
| 8 | +from sklearn_pandas.transformers import NumericalTransformer |
| 9 | + |
| 10 | + |
@pytest.fixture
def simple_dataset():
    """
    Build the small DataFrame shared by the tests in this module.

    The frame has five rows and three integer columns named
    'feat1', 'feat2' and 'feat3'.
    """
    columns = {
        'feat1': [1, 2, 1, 3, 1],
        'feat2': [1, 2, 2, 2, 3],
        'feat3': [1, 2, 3, 4, 5],
    }
    return DataFrame(columns)
| 18 | + |
| 19 | + |
def test_common_numerical_transformer(simple_dataset):
    """
    Check the 'log' NumericalTransformer against np.log on one column.

    The mapper is built with df_out=True, so the result is inspected through
    its column labels and its 'feat1' values.
    """
    log_step = ('feat1', NumericalTransformer('log'))
    mapper = DataFrameMapper([log_step], df_out=True)

    result = mapper.fit_transform(simple_dataset)

    # Only the mapped column should appear in the output.
    assert list(result.columns) == ['feat1']

    # Values must equal np.log applied element-wise to the input column.
    expected = simple_dataset['feat1'].apply(np.log).values
    assert np.array_equal(expected, result.feat1.values)
| 31 | + |
| 32 | + |
def test_numerical_transformer_serialization(simple_dataset):
    """
    Test that a fitted mapper survives a joblib dump/load round trip.

    The mapper is fitted on the fixture frame, dumped to a temporary file,
    loaded back, and both copies must produce identical transform output
    for the same input.
    """
    mapper = DataFrameMapper([
        ('feat1', NumericalTransformer('log'))
    ])
    mapper.fit(simple_dataset)

    # The context manager closes the temporary file even when dumping,
    # loading or the assertion below raises, so no handle is leaked.
    with tempfile.NamedTemporaryFile(delete=True) as tmp:
        joblib.dump(mapper, tmp.name)
        restored = joblib.load(tmp.name)

        # The comparison must be asserted: a bare np.array_equal(...) call
        # only returns a bool, which would let a mismatch pass silently.
        assert np.array_equal(
            mapper.transform(simple_dataset),
            restored.transform(simple_dataset),
        )
0 commit comments