Skip to content

Commit 661bd94

Browse files
author
RJ Agrawal
committed
merged NumericalTransformer
1 parent eac4674 commit 661bd94

File tree

1 file changed

+47
-0
lines changed

1 file changed

+47
-0
lines changed

tests/test_transformers.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import tempfile
2+
import pytest
3+
import numpy as np
4+
from pandas import DataFrame
5+
import joblib
6+
7+
from sklearn_pandas import DataFrameMapper
8+
from sklearn_pandas.transformers import NumericalTransformer
9+
10+
11+
@pytest.fixture
12+
def simple_dataset():
13+
return DataFrame({
14+
'feat1': [1, 2, 1, 3, 1],
15+
'feat2': [1, 2, 2, 2, 3],
16+
'feat3': [1, 2, 3, 4, 5],
17+
})
18+
19+
20+
def test_common_numerical_transformer(simple_dataset):
21+
"""
22+
Test log transformation
23+
"""
24+
transfomer = DataFrameMapper([
25+
('feat1', NumericalTransformer('log'))
26+
], df_out=True)
27+
df = simple_dataset
28+
outDF = transfomer.fit_transform(df)
29+
assert list(outDF.columns) == ['feat1']
30+
assert np.array_equal(df['feat1'].apply(np.log).values, outDF.feat1.values)
31+
32+
33+
def test_numerical_transformer_serialization(simple_dataset):
34+
"""
35+
Test if you can serialize transformer
36+
"""
37+
transfomer = DataFrameMapper([
38+
('feat1', NumericalTransformer('log'))
39+
])
40+
41+
df = simple_dataset
42+
transfomer.fit(df)
43+
f = tempfile.NamedTemporaryFile(delete=True)
44+
joblib.dump(transfomer, f.name)
45+
transfomer2 = joblib.load(f.name)
46+
np.array_equal(transfomer.transform(df), transfomer2.transform(df))
47+
f.close()

0 commit comments

Comments
 (0)