@@ -49,8 +49,8 @@ def _sklearn_to_dict(model):
49
49
'RandomForestClassifier' : 'Forest' ,
50
50
'DecisionTreeClassifier' : 'Decision tree' ,
51
51
'DecisionTreeRegressor' : 'Decision tree' ,
52
- 'classifier' : 'Classification ' ,
53
- 'regressor' : 'Prediction ' }
52
+ 'classifier' : 'classification ' ,
53
+ 'regressor' : 'prediction ' }
54
54
55
55
if hasattr (model , '_final_estimator' ):
56
56
estimator = type (model ._final_estimator )
@@ -207,10 +207,26 @@ def get_version(x):
207
207
# If model is a CASTable then assume it holds an ASTORE model.
208
208
# Import these via a ZIP file.
209
209
if 'swat.cas.table.CASTable' in str (type (model )):
210
- zipfile = utils .create_package (model )
210
+ zipfile = utils .create_package (model , input = input )
211
211
212
212
if create_project :
213
- project = mr .create_project (project , repo_obj )
213
+ outvar = []
214
+ invar = []
215
+ import zipfile as zp
216
+ import copy
217
+ zipfilecopy = copy .deepcopy (zipfile )
218
+ tmpzip = zp .ZipFile (zipfilecopy )
219
+ if "outputVar.json" in tmpzip .namelist ():
220
+ outvar = json .loads (tmpzip .read ("outputVar.json" ).decode ('utf=8' )) #added decode for 3.5 and older
221
+ for tmp in outvar :
222
+ tmp .update ({'role' :'output' })
223
+ if "inputVar.json" in tmpzip .namelist ():
224
+ invar = json .loads (tmpzip .read ("inputVar.json" ).decode ('utf-8' )) #added decode for 3.5 and older
225
+ for tmp in invar :
226
+ if tmp ['role' ] != 'input' :
227
+ tmp ['role' ]= 'input'
228
+ vars = invar + outvar
229
+ project = mr .create_project (project , repo_obj , variables = vars )
214
230
215
231
model = mr .import_model_from_zip (name , project , zipfile ,
216
232
version = version )
@@ -302,17 +318,27 @@ def get_version(x):
302
318
else :
303
319
prediction_variable = None
304
320
305
- project = mr .create_project (project , repo_obj ,
321
+ # As of Viya 3.4 the 'predictionVariable' parameter is not set during
322
+ # project creation. Update the project if necessary.
323
+ if function == 'prediction' : #Predications require predictionVariable
324
+ project = mr .create_project (project , repo_obj ,
306
325
variables = vars ,
307
326
function = model .get ('function' ),
308
327
targetLevel = target_level ,
309
328
predictionVariable = prediction_variable )
310
329
311
- # As of Viya 3.4 the 'predictionVariable' parameter is not set during
312
- # project creation. Update the project if necessary.
313
- if project .get ('predictionVariable' ) != prediction_variable :
314
- project ['predictionVariable' ] = prediction_variable
315
- mr .update_project (project )
330
+ if project .get ('predictionVariable' ) != prediction_variable :
331
+ project ['predictionVariable' ] = prediction_variable
332
+ mr .update_project (project )
333
+ else : #Classifications require eventProbabilityVariable
334
+ project = mr .create_project (project , repo_obj ,
335
+ variables = vars ,
336
+ function = model .get ('function' ),
337
+ targetLevel = target_level ,
338
+ eventProbabilityVariable = prediction_variable )
339
+ if project .get ('eventProbabilityVariable' ) != prediction_variable :
340
+ project ['eventProbabilityVariable' ] = prediction_variable
341
+ mr .update_project (project )
316
342
317
343
model = mr .create_model (model , project )
318
344
@@ -506,9 +532,12 @@ def update_model_performance(data, model, label, refresh=True):
506
532
"regression and binary classification projects. "
507
533
"Received project with '%s' target level. Should be "
508
534
"'Interval' or 'Binary'." , project .get ('targetLevel' ))
509
- elif project .get ('predictionVariable' , '' ) == '' :
535
+ elif project .get ('predictionVariable' , '' ) == '' and project . get ( 'function' , '' ). lower () == 'prediction' :
510
536
raise ValueError ("Project '%s' does not have a prediction variable "
511
537
"specified." % project )
538
+ elif project .get ('eventProbabilityVariable' , '' ) == '' and project .get ('function' , '' ).lower () == 'classification' :
539
+ raise ValueError ("Project '%s' does not have an Event Probability variable "
540
+ "specified." % project )
512
541
513
542
# Find the performance definition for the model
514
543
# As of Viya 3.4, no way to search by model or project
0 commit comments