Skip to content

Commit c17e0a4

Browse files
committed
update_performance task
1 parent a2e58db commit c17e0a4

File tree

2 files changed

+83
-25
lines changed

2 files changed

+83
-25
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11

22
Unreleased
33
----------
4+
**Improvements**
5+
- Added `update_performance` task for easily uploading performance information for a model.
6+
47
**Changes**
58
- `register_model` task automatically captures installed Python packages.
69
- Improved API documentation

src/sasctl/tasks.py

Lines changed: 80 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -411,15 +411,37 @@ def submit_request():
411411
return module
412412

413413

414-
def save_performance(data, model, label):
414+
def update_performance(data, model, label, exec=True):
415+
"""Upload data for calculating model performance metrics
415416
417+
Parameters
418+
----------
419+
data : Dataframe
420+
model : str or dict
421+
The name or id of the model, or a dictionary representation of
422+
the model.
423+
label : str
424+
The time period the data is from. Should be unique and will be
425+
displayed on performance charts. Examples: 'Q1', '2019', 'APR2019'.
426+
exec : bool, optional
427+
Whether to execute the performance definition with the new data
428+
429+
Returns
430+
-------
431+
CASTable
432+
The CAS table containing the performance data.
433+
434+
"""
416435
from .services import model_management as mm
417436
try:
418437
import swat
419438
except ImportError:
420439
raise RuntimeError("The 'swat' package is required to save model "
421440
"performance data.")
422441

442+
# Default to true
443+
exec = True if exec is None else exec
444+
423445
model_obj = mr.get_model(model)
424446

425447
if model_obj is None:
@@ -433,13 +455,17 @@ def save_performance(data, model, label):
433455
"Received project with '%s' function. Should be "
434456
"'Prediction' or 'Classification'.",
435457
project.get('function'))
436-
# elif project.get('targetLevel', '').lower() not in ():
437-
# raise ValueError()
438-
# elif project.get('targetLevel', ''):
439-
# raise ValueError()
440-
# elif project.get('predictionVariable', ''):
441-
# raise ValueError()
442-
458+
elif project.get('targetLevel', '').lower() not in ('interval', 'binary'):
459+
raise ValueError("Performance monitoring is currently supported for "
460+
"regression and binary classification projects. "
461+
"Received project with '%s' target level. Should be "
462+
"'Interval' or 'Binary'.", project.get('targetLevel'))
463+
elif project.get('predictionVariable', '') == '':
464+
raise ValueError("Project '%s' does not have a prediction variable "
465+
"specified." % project)
466+
467+
# Find the performance definition for the model
468+
# As of Viya 3.4, no way to search by model or project
443469
perf_def = None
444470
for p in mm.list_performance_definitions():
445471
if model_obj.id in p.modelIds:
@@ -450,27 +476,57 @@ def save_performance(data, model, label):
450476
raise ValueError("Unable to find a performance definition for model "
451477
"'%s'" % model)
452478

479+
# Check where performance datasets should be uploaded
453480
cas_id = perf_def['casServerId']
454481
caslib = perf_def['dataLibrary']
455482
table_prefix = perf_def['dataPrefix']
456483

484+
# All input variables must be present
485+
missing_cols = [col for col in perf_def.inputVariables if
486+
col not in data.columns]
487+
if len(missing_cols):
488+
raise ValueError("The following columns were expected but not found in "
489+
"the data set: %s" % ', '.join(missing_cols))
490+
491+
# If CAS is not executing the model then the output variables must also be
492+
# provided
493+
if not perf_def.scoreExecutionRequired:
494+
missing_cols = [col for col in perf_def.outputVariables if
495+
col not in data.columns]
496+
if len(missing_cols):
497+
raise ValueError(
498+
"The following columns were expected but not found in the data "
499+
"set: %s" % ', '.join(missing_cols))
500+
457501
sess = current_session()
458502
url = '{}://{}/{}-http/'.format(sess._settings['protocol'],
459503
sess.hostname,
460504
cas_id)
461-
462505
regex = r'{}_(\d)_*_{}'.format(table_prefix,
463506
model_obj.id)
507+
508+
# Upload the performance data to CAS
464509
with swat.CAS(url,
465510
username=sess.username,
466511
password=sess._settings['password']) as s:
467-
all_tables = s.table.tableinfo(caslib=caslib).TableInfo
468-
perf_tables = all_tables.Name.str.extract(regex,
469-
flags=re.IGNORECASE,
470-
expand=False)
471512

472-
last_seq = perf_tables.dropna().astype(int).max()
473-
next_seq = 1 if math.isnan(last_seq) else last_seq + 1
513+
s.setsessopt(messagelevel='warning')
514+
515+
with swat.options(exception_on_severity=2):
516+
caslib_info = s.table.tableinfo(caslib=caslib)
517+
518+
all_tables = getattr(caslib_info, 'TableInfo', None)
519+
if all_tables is not None:
520+
# Find tables with similar names
521+
perf_tables = all_tables.Name.str.extract(regex,
522+
flags=re.IGNORECASE,
523+
expand=False)
524+
525+
# Get last-used sequence number
526+
last_seq = perf_tables.dropna().astype(int).max()
527+
next_seq = 1 if math.isnan(last_seq) else last_seq + 1
528+
else:
529+
next_seq = 1
474530

475531
table_name = '{prefix}_{sequence}_{label}_{model}'.format(
476532
prefix=table_prefix,
@@ -479,18 +535,17 @@ def save_performance(data, model, label):
479535
model=model_obj.id
480536
)
481537

482-
s.upload(data, casout=dict(name=table_name, caslib=caslib))
483-
print(s)
484-
485-
538+
with swat.options(exception_on_severity=2):
539+
# Table must be promoted so performance jobs can access.
540+
tbl = s.upload(data, casout=dict(name=table_name,
541+
caslib=caslib,
542+
promote=True)).casTable
486543

487-
# does perf definition already exist?
488-
# get def and determine naming convention
489-
# get CAS connection?
490-
# determine table name
491-
# upload data
492-
# optionally run definition
544+
# Execute the definition if requested
545+
if exec:
546+
mm.execute_performance_definition(perf_def)
493547

548+
return tbl
494549

495550
"""
496551
Use one of the following formats for the name of the data table that you use as a data source, or for the name of the data tables that are located in the selected library.

0 commit comments

Comments
 (0)