Skip to content

Commit a2e58db

Browse files
committed
construct perf table names
1 parent 839b472 commit a2e58db

File tree

1 file changed

+71
-9
lines changed

1 file changed

+71
-9
lines changed

src/sasctl/tasks.py

Lines changed: 71 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -8,13 +8,14 @@
88

99
import json
1010
import logging
11+
import math
1112
import pickle
1213
import re
1314
import sys
1415
import warnings
1516

1617
from . import utils
17-
from .core import RestObj, get, get_link, request_link
18+
from .core import RestObj, current_session, get, get_link, request_link
1819
from .services import model_management as mm
1920
from .services import model_publish as mp
2021
from .services import model_repository as mr
@@ -410,18 +411,78 @@ def submit_request():
410411
return module
411412

412413

413-
def save_performance(model):
414-
# TODO: Implement & document
414+
def save_performance(data, model, label):
415+
416+
from .services import model_management as mm
417+
try:
418+
import swat
419+
except ImportError:
420+
raise RuntimeError("The 'swat' package is required to save model "
421+
"performance data.")
415422

416423
model_obj = mr.get_model(model)
417424

418-
from .services import model_management as mm
425+
if model_obj is None:
426+
raise ValueError('Model %s was not found.', model)
427+
428+
project = mr.get_project(model_obj.projectId)
429+
430+
if project.get('function', '').lower() not in ('prediction', 'classification'):
431+
raise ValueError("Performance monitoring is currently supported for "
432+
"regression and binary classification projects. "
433+
"Received project with '%s' function. Should be "
434+
"'Prediction' or 'Classification'.",
435+
project.get('function'))
436+
# elif project.get('targetLevel', '').lower() not in ():
437+
# raise ValueError()
438+
# elif project.get('targetLevel', ''):
439+
# raise ValueError()
440+
# elif project.get('predictionVariable', ''):
441+
# raise ValueError()
442+
443+
perf_def = None
444+
for p in mm.list_performance_definitions():
445+
if model_obj.id in p.modelIds:
446+
perf_def = p
447+
break
448+
449+
if perf_def is None:
450+
raise ValueError("Unable to find a performance definition for model "
451+
"'%s'" % model)
452+
453+
cas_id = perf_def['casServerId']
454+
caslib = perf_def['dataLibrary']
455+
table_prefix = perf_def['dataPrefix']
456+
457+
sess = current_session()
458+
url = '{}://{}/{}-http/'.format(sess._settings['protocol'],
459+
sess.hostname,
460+
cas_id)
461+
462+
regex = r'{}_(\d)_*_{}'.format(table_prefix,
463+
model_obj.id)
464+
with swat.CAS(url,
465+
username=sess.username,
466+
password=sess._settings['password']) as s:
467+
all_tables = s.table.tableinfo(caslib=caslib).TableInfo
468+
perf_tables = all_tables.Name.str.extract(regex,
469+
flags=re.IGNORECASE,
470+
expand=False)
471+
472+
last_seq = perf_tables.dropna().astype(int).max()
473+
next_seq = 1 if math.isnan(last_seq) else last_seq + 1
474+
475+
table_name = '{prefix}_{sequence}_{label}_{model}'.format(
476+
prefix=table_prefix,
477+
sequence=next_seq,
478+
label=label,
479+
model=model_obj.id
480+
)
481+
482+
s.upload(data, casout=dict(name=table_name, caslib=caslib))
483+
print(s)
484+
419485

420-
perf_def = mm.get_performance_definition(model_obj)
421-
# model,
422-
# table
423-
# score?
424-
# determine name
425486

426487
# does perf definition already exist?
427488
# get def and determine naming convention
@@ -430,6 +491,7 @@ def save_performance(model):
430491
# upload data
431492
# optionally run definition
432493

494+
433495
"""
434496
Use one of the following formats for the name of the data table that you use as a data source, or for the name of the data tables that are located in the selected library.
435497

0 commit comments

Comments
 (0)