@@ -411,15 +411,37 @@ def submit_request():
411
411
return module
412
412
413
413
414
- def save_performance (data , model , label ):
414
+ def update_performance (data , model , label , exec = True ):
415
+ """Upload data for calculating model performance metrics.
415
416
417
+ Parameters
418
+ ----------
419
+ data : DataFrame
420
+ model : str or dict
421
+ The name or id of the model, or a dictionary representation of
422
+ the model.
423
+ label : str
424
+ The time period the data is from. Should be unique and will be
425
+ displayed on performance charts. Examples: 'Q1', '2019', 'APR2019'.
426
+ exec : bool, optional
427
+ Whether to execute the performance definition with the new data. Defaults to True.
428
+
429
+ Returns
430
+ -------
431
+ CASTable
432
+ The CAS table containing the performance data.
433
+
434
+ """
416
435
from .services import model_management as mm
417
436
try :
418
437
import swat
419
438
except ImportError :
420
439
raise RuntimeError ("The 'swat' package is required to save model "
421
440
"performance data." )
422
441
442
+ # Treat an explicit None the same as the default (True)
443
+ exec = True if exec is None else exec
444
+
423
445
model_obj = mr .get_model (model )
424
446
425
447
if model_obj is None :
@@ -433,13 +455,17 @@ def save_performance(data, model, label):
433
455
"Received project with '%s' function. Should be "
434
456
"'Prediction' or 'Classification'." ,
435
457
project .get ('function' ))
436
- # elif project.get('targetLevel', '').lower() not in ():
437
- # raise ValueError()
438
- # elif project.get('targetLevel', ''):
439
- # raise ValueError()
440
- # elif project.get('predictionVariable', ''):
441
- # raise ValueError()
442
-
458
+ elif project .get ('targetLevel' , '' ).lower () not in ('interval' , 'binary' ):
459
+ raise ValueError ("Performance monitoring is currently supported for "
460
+ "regression and binary classification projects. "
461
+ "Received project with '%s' target level. Should be "
462
+ "'Interval' or 'Binary'." , project .get ('targetLevel' ))
463
+ elif project .get ('predictionVariable' , '' ) == '' :
464
+ raise ValueError ("Project '%s' does not have a prediction variable "
465
+ "specified." % project )
466
+
467
+ # Find the performance definition for the model
468
+ # As of Viya 3.4, there is no way to search by model or project
443
469
perf_def = None
444
470
for p in mm .list_performance_definitions ():
445
471
if model_obj .id in p .modelIds :
@@ -450,27 +476,57 @@ def save_performance(data, model, label):
450
476
raise ValueError ("Unable to find a performance definition for model "
451
477
"'%s'" % model )
452
478
479
+ # Check where performance datasets should be uploaded
453
480
cas_id = perf_def ['casServerId' ]
454
481
caslib = perf_def ['dataLibrary' ]
455
482
table_prefix = perf_def ['dataPrefix' ]
456
483
484
+ # All input variables must be present
485
+ missing_cols = [col for col in perf_def .inputVariables if
486
+ col not in data .columns ]
487
+ if len (missing_cols ):
488
+ raise ValueError ("The following columns were expected but not found in "
489
+ "the data set: %s" % ', ' .join (missing_cols ))
490
+
491
+ # If CAS is not executing the model then the output variables must also be
492
+ # provided
493
+ if not perf_def .scoreExecutionRequired :
494
+ missing_cols = [col for col in perf_def .outputVariables if
495
+ col not in data .columns ]
496
+ if len (missing_cols ):
497
+ raise ValueError (
498
+ "The following columns were expected but not found in the data "
499
+ "set: %s" % ', ' .join (missing_cols ))
500
+
457
501
sess = current_session ()
458
502
url = '{}://{}/{}-http/' .format (sess ._settings ['protocol' ],
459
503
sess .hostname ,
460
504
cas_id )
461
-
462
505
regex = r'{}_(\d)_*_{}' .format (table_prefix ,
463
506
model_obj .id )
507
+
508
+ # Upload the performance data to CAS
464
509
with swat .CAS (url ,
465
510
username = sess .username ,
466
511
password = sess ._settings ['password' ]) as s :
467
- all_tables = s .table .tableinfo (caslib = caslib ).TableInfo
468
- perf_tables = all_tables .Name .str .extract (regex ,
469
- flags = re .IGNORECASE ,
470
- expand = False )
471
512
472
- last_seq = perf_tables .dropna ().astype (int ).max ()
473
- next_seq = 1 if math .isnan (last_seq ) else last_seq + 1
513
+ s .setsessopt (messagelevel = 'warning' )
514
+
515
+ with swat .options (exception_on_severity = 2 ):
516
+ caslib_info = s .table .tableinfo (caslib = caslib )
517
+
518
+ all_tables = getattr (caslib_info , 'TableInfo' , None )
519
+ if all_tables is not None :
520
+ # Find tables with similar names
521
+ perf_tables = all_tables .Name .str .extract (regex ,
522
+ flags = re .IGNORECASE ,
523
+ expand = False )
524
+
525
+ # Get last-used sequence number
526
+ last_seq = perf_tables .dropna ().astype (int ).max ()
527
+ next_seq = 1 if math .isnan (last_seq ) else last_seq + 1
528
+ else :
529
+ next_seq = 1
474
530
475
531
table_name = '{prefix}_{sequence}_{label}_{model}' .format (
476
532
prefix = table_prefix ,
@@ -479,18 +535,17 @@ def save_performance(data, model, label):
479
535
model = model_obj .id
480
536
)
481
537
482
- s .upload (data , casout = dict (name = table_name , caslib = caslib ))
483
- print (s )
484
-
485
-
538
+ with swat .options (exception_on_severity = 2 ):
539
+ # Table must be promoted so performance jobs can access it.
540
+ tbl = s .upload (data , casout = dict (name = table_name ,
541
+ caslib = caslib ,
542
+ promote = True )).casTable
486
543
487
- # does perf definition already exist?
488
- # get def and determine naming convention
489
- # get CAS connection?
490
- # determine table name
491
- # upload data
492
- # optionally run definition
544
+ # Execute the definition if requested
545
+ if exec :
546
+ mm .execute_performance_definition (perf_def )
493
547
548
+ return tbl
494
549
495
550
"""
496
551
Use one of the following formats for the name of the data table that you use as a data source, or for the name of the data tables that are located in the selected library.
0 commit comments