8
8
9
9
import json
10
10
import logging
11
+ import math
11
12
import pickle
12
13
import re
13
14
import sys
14
15
import warnings
15
16
16
17
from . import utils
17
- from .core import RestObj , get , get_link , request_link
18
+ from .core import RestObj , current_session , get , get_link , request_link
18
19
from .services import model_management as mm
19
20
from .services import model_publish as mp
20
21
from .services import model_repository as mr
@@ -410,18 +411,78 @@ def submit_request():
410
411
return module
411
412
412
413
413
- def save_performance (model ):
414
- # TODO: Implement & document
414
+ def save_performance (data , model , label ):
415
+
416
+ from .services import model_management as mm
417
+ try :
418
+ import swat
419
+ except ImportError :
420
+ raise RuntimeError ("The 'swat' package is required to save model "
421
+ "performance data." )
415
422
416
423
model_obj = mr .get_model (model )
417
424
418
- from .services import model_management as mm
425
+ if model_obj is None :
426
+ raise ValueError ('Model %s was not found.' , model )
427
+
428
+ project = mr .get_project (model_obj .projectId )
429
+
430
+ if project .get ('function' , '' ).lower () not in ('prediction' , 'classification' ):
431
+ raise ValueError ("Performance monitoring is currently supported for "
432
+ "regression and binary classification projects. "
433
+ "Received project with '%s' function. Should be "
434
+ "'Prediction' or 'Classification'." ,
435
+ project .get ('function' ))
436
+ # elif project.get('targetLevel', '').lower() not in ():
437
+ # raise ValueError()
438
+ # elif project.get('targetLevel', ''):
439
+ # raise ValueError()
440
+ # elif project.get('predictionVariable', ''):
441
+ # raise ValueError()
442
+
443
+ perf_def = None
444
+ for p in mm .list_performance_definitions ():
445
+ if model_obj .id in p .modelIds :
446
+ perf_def = p
447
+ break
448
+
449
+ if perf_def is None :
450
+ raise ValueError ("Unable to find a performance definition for model "
451
+ "'%s'" % model )
452
+
453
+ cas_id = perf_def ['casServerId' ]
454
+ caslib = perf_def ['dataLibrary' ]
455
+ table_prefix = perf_def ['dataPrefix' ]
456
+
457
+ sess = current_session ()
458
+ url = '{}://{}/{}-http/' .format (sess ._settings ['protocol' ],
459
+ sess .hostname ,
460
+ cas_id )
461
+
462
+ regex = r'{}_(\d)_*_{}' .format (table_prefix ,
463
+ model_obj .id )
464
+ with swat .CAS (url ,
465
+ username = sess .username ,
466
+ password = sess ._settings ['password' ]) as s :
467
+ all_tables = s .table .tableinfo (caslib = caslib ).TableInfo
468
+ perf_tables = all_tables .Name .str .extract (regex ,
469
+ flags = re .IGNORECASE ,
470
+ expand = False )
471
+
472
+ last_seq = perf_tables .dropna ().astype (int ).max ()
473
+ next_seq = 1 if math .isnan (last_seq ) else last_seq + 1
474
+
475
+ table_name = '{prefix}_{sequence}_{label}_{model}' .format (
476
+ prefix = table_prefix ,
477
+ sequence = next_seq ,
478
+ label = label ,
479
+ model = model_obj .id
480
+ )
481
+
482
+ s .upload (data , casout = dict (name = table_name , caslib = caslib ))
483
+ print (s )
484
+
419
485
420
- perf_def = mm .get_performance_definition (model_obj )
421
- # model,
422
- # table
423
- # score?
424
- # determine name
425
486
426
487
# does perf definition already exist?
427
488
# get def and determine naming convention
@@ -430,6 +491,7 @@ def save_performance(model):
430
491
# upload data
431
492
# optionally run definition
432
493
494
+
433
495
"""
434
496
Use one of the following formats for the name of the data table that you use as a data source, or for the name of the data tables that are located in the selected library.
435
497
0 commit comments