@@ -63,6 +63,35 @@ def sklearn_model(train_data):
63
63
return model
64
64
65
65
66
+ @pytest .fixture
67
+ def sklearn_pipeline (train_data ):
68
+ from sklearn .pipeline import Pipeline
69
+ from sklearn .ensemble import GradientBoostingClassifier
70
+ from sklearn .preprocessing import StandardScaler
71
+ from sklearn .impute import SimpleImputer
72
+ from sklearn .compose import ColumnTransformer
73
+
74
+ X , y = train_data
75
+
76
+ numeric_transformer = Pipeline ([
77
+ ('imputer' , SimpleImputer (strategy = 'median' )),
78
+ ('scaler' , StandardScaler ())
79
+ ])
80
+
81
+ preprocessor = ColumnTransformer ([
82
+ ('num' , numeric_transformer , X .columns )
83
+ ])
84
+
85
+ pipe = Pipeline ([
86
+ ('preprocess' , preprocessor ),
87
+ ('classifier' , GradientBoostingClassifier ())
88
+ ])
89
+
90
+ pipe .fit (X , y )
91
+
92
+ return pipe
93
+
94
+
66
95
@pytest .fixture
67
96
def pickle_file (tmpdir_factory , sklearn_model ):
68
97
"""Returns the path to a file containing a pickled Scikit-Learn model """
@@ -215,6 +244,17 @@ def test_from_python_file(python_file):
215
244
assert isinstance (p , PyMAS )
216
245
217
246
247
+ def test_with_sklearn_pipeline (train_data , sklearn_pipeline ):
248
+ from sasctl .utils .pymas import PyMAS , from_pickle
249
+
250
+ X , y = train_data
251
+ p = from_pickle (pickle .dumps (sklearn_pipeline ),
252
+ func_name = 'predict' ,
253
+ input_types = X )
254
+
255
+ assert isinstance (p , PyMAS )
256
+ assert len (p .variables ) > 4 # 4 input features in Iris data set
257
+
218
258
@pytest .mark .usefixtures ('session' )
219
259
def test_publish_and_execute (tmpdir ):
220
260
import pickle
0 commit comments