Skip to content

Commit b3c34e6

Browse files
authored
Merge pull request #42 from jameskochubasas/master
Fixing the MM performance capabilities for all models
2 parents a6e1bd7 + 5ca8491 commit b3c34e6

File tree

6 files changed

+137
-25
lines changed

6 files changed

+137
-25
lines changed

src/sasctl/_services/model_management.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,11 +174,16 @@ def create_performance_definition(cls,
174174

175175
# Performance data cannot be captured unless certain project properties
176176
# have been configured.
177-
for required in ['targetVariable', 'targetLevel',
178-
'predictionVariable']:
177+
for required in ['targetVariable', 'targetLevel']:
179178
if getattr(project, required, None) is None:
180179
raise ValueError("Project %s must have the '%s' property set."
181180
% (project.name, required))
181+
if project['function'] == 'classification' and project['eventProbabilityVariable'] == None:
182+
raise ValueError("Project %s must have the 'eventProbabilityVariable' property set."
183+
% (project.name))
184+
if project['function'] == 'prediction' and project['predictionVariable'] == None:
185+
raise ValueError("Project %s must have the 'predictionVariable' property set."
186+
% (project.name))
182187

183188
request = {'projectId': project.id,
184189
'name': name or model.name + ' Performance',

src/sasctl/tasks.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ def _sklearn_to_dict(model):
4949
'RandomForestClassifier': 'Forest',
5050
'DecisionTreeClassifier': 'Decision tree',
5151
'DecisionTreeRegressor': 'Decision tree',
52-
'classifier': 'Classification',
53-
'regressor': 'Prediction'}
52+
'classifier': 'classification',
53+
'regressor': 'prediction'}
5454

5555
if hasattr(model, '_final_estimator'):
5656
estimator = type(model._final_estimator)
@@ -207,10 +207,26 @@ def get_version(x):
207207
# If model is a CASTable then assume it holds an ASTORE model.
208208
# Import these via a ZIP file.
209209
if 'swat.cas.table.CASTable' in str(type(model)):
210-
zipfile = utils.create_package(model)
210+
zipfile = utils.create_package(model, input=input)
211211

212212
if create_project:
213-
project = mr.create_project(project, repo_obj)
213+
outvar=[]
214+
invar=[]
215+
import zipfile as zp
216+
import copy
217+
zipfilecopy = copy.deepcopy(zipfile)
218+
tmpzip=zp.ZipFile(zipfilecopy)
219+
if "outputVar.json" in tmpzip.namelist():
220+
outvar=json.loads(tmpzip.read("outputVar.json").decode('utf=8')) #added decode for 3.5 and older
221+
for tmp in outvar:
222+
tmp.update({'role':'output'})
223+
if "inputVar.json" in tmpzip.namelist():
224+
invar=json.loads(tmpzip.read("inputVar.json").decode('utf-8')) #added decode for 3.5 and older
225+
for tmp in invar:
226+
if tmp['role'] != 'input':
227+
tmp['role']='input'
228+
vars=invar + outvar
229+
project = mr.create_project(project, repo_obj, variables=vars)
214230

215231
model = mr.import_model_from_zip(name, project, zipfile,
216232
version=version)
@@ -302,17 +318,27 @@ def get_version(x):
302318
else:
303319
prediction_variable = None
304320

305-
project = mr.create_project(project, repo_obj,
321+
# As of Viya 3.4 the 'predictionVariable' parameter is not set during
322+
# project creation. Update the project if necessary.
323+
if function == 'prediction': #Predications require predictionVariable
324+
project = mr.create_project(project, repo_obj,
306325
variables=vars,
307326
function=model.get('function'),
308327
targetLevel=target_level,
309328
predictionVariable=prediction_variable)
310329

311-
# As of Viya 3.4 the 'predictionVariable' parameter is not set during
312-
# project creation. Update the project if necessary.
313-
if project.get('predictionVariable') != prediction_variable:
314-
project['predictionVariable'] = prediction_variable
315-
mr.update_project(project)
330+
if project.get('predictionVariable') != prediction_variable:
331+
project['predictionVariable'] = prediction_variable
332+
mr.update_project(project)
333+
else: #Classifications require eventProbabilityVariable
334+
project = mr.create_project(project, repo_obj,
335+
variables=vars,
336+
function=model.get('function'),
337+
targetLevel=target_level,
338+
eventProbabilityVariable=prediction_variable)
339+
if project.get('eventProbabilityVariable') != prediction_variable:
340+
project['eventProbabilityVariable'] = prediction_variable
341+
mr.update_project(project)
316342

317343
model = mr.create_model(model, project)
318344

@@ -506,9 +532,12 @@ def update_model_performance(data, model, label, refresh=True):
506532
"regression and binary classification projects. "
507533
"Received project with '%s' target level. Should be "
508534
"'Interval' or 'Binary'.", project.get('targetLevel'))
509-
elif project.get('predictionVariable', '') == '':
535+
elif project.get('predictionVariable', '') == '' and project.get('function', '').lower() == 'prediction':
510536
raise ValueError("Project '%s' does not have a prediction variable "
511537
"specified." % project)
538+
elif project.get('eventProbabilityVariable', '') == '' and project.get('function', '').lower() == 'classification':
539+
raise ValueError("Project '%s' does not have an Event Probability variable "
540+
"specified." % project)
512541

513542
# Find the performance definition for the model
514543
# As of Viya 3.4, no way to search by model or project

src/sasctl/utils/astore.py

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,21 @@
2020
swat = None
2121

2222

23-
def create_package(table):
23+
def create_package(table, input=None):
2424
"""Create an importable model package from a CAS table.
2525
2626
Parameters
2727
----------
2828
table : swat.CASTable
2929
The CAS table containing an ASTORE or score code.
30+
input : DataFrame, type, list of type, or dict of str: type, optional
31+
The expected type for each input value of the target function.
32+
Can be omitted if target function includes type hints. If a DataFrame
33+
is provided, the columns will be inspected to determine type information.
34+
If a single type is provided, all columns will be assumed to be that type,
35+
otherwise a list of column types or a dictionary of column_name: type
36+
may be provided.
37+
3038
3139
Returns
3240
-------
@@ -45,18 +53,26 @@ def create_package(table):
4553
assert isinstance(table, swat.CASTable)
4654

4755
if 'DataStepSrc' in table.columns:
48-
return create_package_from_datastep(table)
56+
#Input only passed to datastep
57+
return create_package_from_datastep(table, input=input)
4958
else:
5059
return create_package_from_astore(table)
5160

5261

53-
def create_package_from_datastep(table):
62+
def create_package_from_datastep(table, input=None):
5463
"""Create an importable model package from a score code table.
5564
5665
Parameters
5766
----------
5867
table : swat.CASTable
5968
The CAS table containing the score code.
69+
input : DataFrame, type, list of type, or dict of str: type, optional
70+
The expected type for each input value of the target function.
71+
Can be omitted if target function includes type hints. If a DataFrame
72+
is provided, the columns will be inspected to determine type information.
73+
If a single type is provided, all columns will be assumed to be that type,
74+
otherwise a list of column types or a dictionary of column_name: type
75+
may be provided.
6076
6177
Returns
6278
-------
@@ -73,11 +89,59 @@ def create_package_from_datastep(table):
7389

7490
dscode = table.to_frame().loc[0, 'DataStepSrc']
7591

92+
# Extract inputs if provided
93+
input_vars = []
94+
# Workaround because a SASDataFrame does not support a direct existence check
95+
if str(input) != "None":
96+
from .pymas.python import ds2_variables
97+
vars=None
98+
if hasattr(input, 'columns'):
99+
# Assuming input is a DataFrame representing model inputs. Use to
100+
# get input variables
101+
vars = ds2_variables(input)
102+
elif isinstance(input, type):
103+
params = OrderedDict([(k, input)
104+
for k in target_func.__code__.co_varnames])
105+
vars = ds2_variables(params)
106+
elif isinstance(input, dict):
107+
vars = ds2_variables(input)
108+
if vars:
109+
input_vars = [var.as_model_metadata() for var in vars if not var.out]
110+
111+
#Find outputs from ds code
112+
output_vars=[]
113+
for sasline in dscode.split('\n'):
114+
if sasline.strip().startswith('label'):
115+
output_var=dict()
116+
for tmp in sasline.split('='):
117+
if 'label' in tmp:
118+
ovarname=tmp.split('label')[1].strip()
119+
output_var.update({"name":ovarname})
120+
#Determine whether the variable's type is decimal or string
121+
if "length " + ovarname in dscode:
122+
sastype=dscode.split("length " + ovarname)[1].split(';')[0].strip()
123+
if "$" in sastype:
124+
output_var.update({"type":"string"})
125+
output_var.update({"length":sastype.split("$")[1]})
126+
else:
127+
output_var.update({"type":"decimal"})
128+
output_var.update({"length":sastype})
129+
else:
130+
#If no length is declared for the variable, default to decimal, 8
131+
output_var.update({"type":"decimal"})
132+
output_var.update({"length":8})
133+
else:
134+
output_var.update({"description":tmp.split(';')[0].strip().strip("'")})
135+
output_vars.append(output_var)
136+
76137
file_metadata = [{'role': 'score', 'name': 'dmcas_scorecode.sas'}]
77138

78139
zip_file = _build_zip_from_files({
79140
'fileMetadata.json': file_metadata,
80-
'dmcas_scorecode.sas': dscode
141+
'dmcas_scorecode.sas': dscode,
142+
'ModelProperties.json': {"scoreCodeType":"dataStep"},
143+
'outputVar.json': output_vars,
144+
'inputVar.json': input_vars
81145
})
82146

83147
return zip_file

src/sasctl/utils/pymas/ds2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ def _map_type(cls, mapping, t):
339339

340340
def as_model_metadata(self):
341341
viya_type = self._map_type(self.DS2_TYPE_TO_VIYA, self.type)
342-
role = 'Output' if self.out else 'Input'
342+
role = 'Output' if self.out else 'input'
343343

344344
return OrderedDict(
345345
[('name', self.name), ('role', role), ('type', viya_type)])

tests/unit/test_model_management.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,20 @@ def test_create_performance_definition():
5050
with pytest.raises(ValueError):
5151
# Project missing some required properties
5252
get_project.return_value = copy.deepcopy(PROJECT)
53-
get_project.return_value['predictionVariable'] = 'predicted'
53+
get_project.return_value['function'] = 'classification'
54+
_ = mm.create_performance_definition('model', 'TestLibrary', 'TestData')
55+
56+
with pytest.raises(ValueError):
57+
# Project missing some required properties
58+
get_project.return_value = copy.deepcopy(PROJECT)
59+
get_project.return_value['function'] = 'prediction'
5460
_ = mm.create_performance_definition('model', 'TestLibrary', 'TestData')
5561

5662
get_project.return_value = copy.deepcopy(PROJECT)
5763
get_project.return_value['targetVariable'] = 'target'
5864
get_project.return_value['targetLevel'] = 'interval'
5965
get_project.return_value['predictionVariable'] = 'predicted'
66+
get_project.return_value['function'] = 'prediction'
6067
_ = mm.create_performance_definition('model', 'TestLibrary',
6168
'TestData',
6269
max_bins=3,

tests/unit/test_tasks.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,27 +22,27 @@ def test_sklearn_metadata():
2222

2323
info = _sklearn_to_dict(LinearRegression())
2424
assert info['algorithm'] == 'Linear regression'
25-
assert info['function'] == 'Prediction'
25+
assert info['function'] == 'prediction'
2626

2727
info = _sklearn_to_dict(LogisticRegression())
2828
assert info['algorithm'] == 'Logistic regression'
29-
assert info['function'] == 'Classification'
29+
assert info['function'] == 'classification'
3030

3131
info = _sklearn_to_dict(SVC())
3232
assert info['algorithm'] == 'Support vector machine'
33-
assert info['function'] == 'Classification'
33+
assert info['function'] == 'classification'
3434

3535
info = _sklearn_to_dict(GradientBoostingClassifier())
3636
assert info['algorithm'] == 'Gradient boosting'
37-
assert info['function'] == 'Classification'
37+
assert info['function'] == 'classification'
3838

3939
info = _sklearn_to_dict(DecisionTreeClassifier())
4040
assert info['algorithm'] == 'Decision tree'
41-
assert info['function'] == 'Classification'
41+
assert info['function'] == 'classification'
4242

4343
info = _sklearn_to_dict(RandomForestClassifier())
4444
assert info['algorithm'] == 'Forest'
45-
assert info['function'] == 'Classification'
45+
assert info['function'] == 'classification'
4646

4747

4848
def test_parse_module_url():
@@ -96,6 +96,13 @@ def test_save_performance_project_types():
9696
project.return_value = {'function': 'Prediction',
9797
'targetLevel': 'Binary'}
9898
update_model_performance(None, None, None)
99+
100+
# Classification variable required
101+
with pytest.raises(ValueError):
102+
project.return_value = {'function': 'classification',
103+
'targetLevel': 'Binary'}
104+
update_model_performance(None, None, None)
105+
99106

100107
# Check projects w/ invalid properties
101108

0 commit comments

Comments
 (0)