19
19
20
20
21
21
@pytest .fixture
22
- def sklearn_model ():
23
- """Returns a simple Scikit-Learn model """
22
+ def sklearn_logistic_model ():
23
+ """A Scikit-Learn logistic regression fit to Iris data set. """
24
24
25
- try :
26
- import pandas as pd
27
- except ImportError :
28
- pytest .skip ('Package `pandas` not found.' )
25
+ pd = pytest .importorskip ('pandas' )
29
26
30
27
try :
31
28
from sklearn import datasets
@@ -48,6 +45,27 @@ def sklearn_model():
48
45
return model , iris .iloc [:, 0 :4 ]
49
46
50
47
48
+ @pytest .fixture
49
+ def sklearn_linear_model ():
50
+ """A Scikit-Learn linear regression fit to Boston housing data."""
51
+
52
+ pd = pytest .importorskip ('pandas' )
53
+ datasets = pytest .importorskip ('sklearn.datasets' )
54
+ linear_model = pytest .importorskip ('sklearn.linear_model' )
55
+
56
+ data = datasets .load_boston ()
57
+ X = pd .DataFrame (data .data , columns = data .feature_names )
58
+ y = pd .DataFrame (data .target , columns = ['Price' ])
59
+
60
+ with warnings .catch_warnings ():
61
+ warnings .simplefilter ('ignore' )
62
+ lm = linear_model .LinearRegression ()
63
+ lm .fit (X , y )
64
+
65
+ return lm , X , y
66
+
67
+
68
+
51
69
@pytest .mark .incremental
52
70
class TestModels :
53
71
def test_register_astore (self , astore ):
@@ -61,11 +79,11 @@ def test_register_astore(self, astore):
61
79
assert isinstance (model , RestObj )
62
80
assert ASTORE_MODEL_NAME == model .name
63
81
64
- def test_register_sklearn (self , sklearn_model ):
82
+ def test_register_sklearn (self , sklearn_logistic_model ):
65
83
from sasctl .tasks import register_model
66
84
from sasctl import RestObj
67
85
68
- sk_model , train_df = sklearn_model
86
+ sk_model , train_df = sklearn_logistic_model
69
87
70
88
# Register model and ensure attributes are set correctly
71
89
model = register_model (sk_model , SCIKIT_MODEL_NAME ,
@@ -132,7 +150,6 @@ def test_publish_sklearn_again(self):
132
150
# MAS module should automatically have methods bound
133
151
assert callable (p .score )
134
152
135
-
136
153
def test_score_sklearn (self ):
137
154
from sasctl .services import microanalytic_score as mas
138
155
@@ -141,3 +158,82 @@ def test_score_sklearn(self):
141
158
r = m .score (sepalwidth = 1 , sepallength = 2 , petallength = 3 , petalwidth = 4 )
142
159
assert r == 'virginica'
143
160
161
+
162
+ @pytest .mark .incremental
163
+ class TestSklearnLinearModel :
164
+ MODEL_NAME = 'Scikit Linear Model'
165
+ PROJECT_NAME = 'Boston Housing'
166
+
167
+ def test_register_model (self , sklearn_linear_model ):
168
+ from sasctl .tasks import register_model
169
+ from sasctl import RestObj
170
+
171
+ sk_model , X , y = sklearn_linear_model
172
+
173
+ # Register model and ensure attributes are set correctly
174
+ model = register_model (sk_model ,
175
+ self .MODEL_NAME ,
176
+ project = self .PROJECT_NAME ,
177
+ input = X ,
178
+ force = True )
179
+
180
+ assert isinstance (model , RestObj )
181
+ assert self .MODEL_NAME == model .name
182
+ assert 'Prediction' == model .function
183
+ assert 'Linear regression' == model .algorithm
184
+ assert 'Python' == model .trainCodeType
185
+ assert 'ds2MultiType' == model .scoreCodeType
186
+
187
+ assert len (model .inputVariables ) == 13
188
+ assert len (model .outputVariables ) == 1
189
+
190
+ # Don't compare to sys.version since cassettes used may have been
191
+ # created by a different version
192
+ assert re .match ('Python \d\.\d' , model .tool )
193
+
194
+ # Ensure input & output metadata was set
195
+ for col in X .columns :
196
+ assert 1 == len ([v for v in model .inputVariables
197
+ + model .outputVariables if v .get ('name' ) == col ])
198
+
199
+ # Ensure model files were created
200
+ from sasctl .services import model_repository as mr
201
+ files = mr .get_model_contents (model )
202
+ filenames = [f .name for f in files ]
203
+ assert 'model.pkl' in filenames
204
+ assert 'dmcas_epscorecode.sas' in filenames
205
+ assert 'dmcas_packagescorecode.sas' in filenames
206
+
207
+ def test_create_performance_definition (self , sklearn_linear_model ):
208
+ from sasctl .services import model_repository as mr
209
+ from sasctl .services import model_management as mm
210
+
211
+ lm , X , y = sklearn_linear_model
212
+
213
+ project = mr .get_project (self .PROJECT_NAME )
214
+ # Update project properties
215
+ project ['function' ] = 'prediction'
216
+ project ['targetLevel' ] = 'interval'
217
+ project ['targetVariable' ] = 'Price'
218
+ project ['predictionVariable' ] = 'var1'
219
+ project = mr .update_project (project )
220
+
221
+ mm .create_performance_definition (self .MODEL_NAME , 'Public' , 'boston' )
222
+
223
+ def test_update_model_performance (self , sklearn_linear_model , cas_session ):
224
+ from six .moves import mock
225
+ from sasctl .tasks import update_model_performance
226
+
227
+ lm , X , y = sklearn_linear_model
228
+
229
+ # Score & set output var
230
+ train_df = X .copy ()
231
+ train_df ['var1' ] = lm .predict (X )
232
+ train_df ['Price' ] = y
233
+
234
+ with mock .patch ('swat.CAS' ) as CAS :
235
+ for period in ('q12019' , 'q22019' , 'q32019' , 'q42019' ):
236
+ sample = train_df .sample (frac = 0.1 )
237
+ update_model_performance (sample , self .MODEL_NAME , period )
238
+
239
+
0 commit comments