Skip to content

Commit b956ad6

Browse files
Merge branch 'scoring_services' of https://github.com/sassoftware/python-sasctl into scoring_services
2 parents 2ff6112 + a3a1886 commit b956ad6

File tree

4 files changed

+174
-33
lines changed

4 files changed

+174
-33
lines changed

src/sasctl/_services/score_definitions.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class ScoreDefinitions(Service):
2626

2727
_SERVICE_ROOT = "/scoreDefinitions"
2828
_cas_management = CASManagement()
29-
_model_respository = ModelRepository()
29+
_model_repository = ModelRepository()
3030

3131
(
3232
list_definitions,
@@ -39,7 +39,7 @@ class ScoreDefinitions(Service):
3939
def create_score_definition(
4040
cls,
4141
score_def_name: str,
42-
model_id: str,
42+
model: Union[str, dict],
4343
table_name: str,
4444
table_file: Union[str, Path] = None,
4545
description: str = "",
@@ -53,8 +53,8 @@ def create_score_definition(
5353
--------
5454
score_def_name: str
5555
Name of score definition.
56-
model_id: str
57-
A user-inputted model if where the model exists in a project.
56+
model : str or dict
57+
The name or id of the model, or a dictionary representation of the model.
5858
table_name: str
5959
A user-inputted table name in CAS Management.
6060
table_file: str or Path, optional
@@ -74,7 +74,8 @@ def create_score_definition(
7474
7575
"""
7676

77-
model = cls._model_respository.get_model(model_id)
77+
model = cls._model_repository.get_model(model)
78+
model_id = model.id
7879

7980
if not model:
8081
raise HTTPError(
@@ -122,9 +123,9 @@ def create_score_definition(
122123
"name": score_def_name,
123124
"description": description,
124125
"objectDescriptor": {
125-
"uri": f"/modelRepository/models/{model_id}",
126+
"uri": f"/modelManagement/models/{model_id}",
126127
"name": f"{model_name}({model_version})",
127-
"type": "sas.models.model",
128+
"type": "sas.models.model.python",
128129
},
129130
"inputData": {
130131
"type": "CASTable",

src/sasctl/_services/score_execution.py

Lines changed: 133 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
import json
2+
import time
3+
import warnings
4+
from distutils.version import StrictVersion
5+
from typing import Union
26

7+
import pandas as pd
38
from requests import HTTPError
49

10+
from .cas_management import CASManagement
11+
from ..core import current_session
512
from .score_definitions import ScoreDefinitions
613
from .service import Service
714

@@ -18,7 +25,9 @@ class ScoreExecution(Service):
1825
"""
1926

2027
_SERVICE_ROOT = "/scoreExecution"
28+
_cas_management = CASManagement()
2129
_score_definitions = ScoreDefinitions()
30+
_services = Service()
2231

2332
(
2433
list_executions,
@@ -62,34 +71,17 @@ def create_score_execution(
6271
if not score_definition:
6372
raise HTTPError
6473
score_exec_name = score_definition.get("name")
65-
model_uri = score_definition.get("objectDescriptor", "uri")
66-
model_name = score_definition.get("objectDescriptor", "name")
67-
model_input_library = score_definition.get("inputData", "libraryName")
68-
model_table_name = score_definition.get("inputData", "tableName")
74+
# NEEDS modelManagement not modelRepository
75+
model_uuid = score_definition.get("objectDescriptor").get("uri").split('/')[-1]
76+
model_uri = f"/modelManagement/models/{model_uuid}"
77+
model_name = score_definition.get("objectDescriptor").get("name")
78+
model_input_library = score_definition.get("inputData").get("libraryName")
79+
model_table_name = score_definition.get("inputData").get("tableName")
6980

7081
# Defining a default output table name if none is provided
7182
if not output_table_name:
7283
output_table_name = f"{model_name}_{score_definition_id}"
7384

74-
# Getting all score executions that are using the inputted score_definition_id
75-
76-
# score_execution = cls.list_executions(
77-
# filter=f"eq(scoreDefinitionId, '{score_definition_id}')"
78-
# )
79-
score_execution = cls.get("scoreExecution/executions",
80-
filter=f"filter=eq(scoreExecutionRequest.scoreDefinitionId,%{score_definition_id}%27)"
81-
)
82-
if not score_execution:
83-
raise HTTPError(f"Something went wrong in the LIST_EXECUTIONS statement.")
84-
85-
# Checking the count of the execution list to see if there are any score executions for this score_definition_id already running
86-
execution_count = score_execution.get("count") # Exception catch location
87-
if execution_count == 1:
88-
execution_id = score_execution.get("items", 0, "id")
89-
deleted_execution = cls.delete_execution(execution_id)
90-
if deleted_execution.status_code >= 400:
91-
raise HTTPError(f"Something went wrong in the DELETE statement.")
92-
9385
headers_score_exec = {"Content-Type": "application/json"}
9486

9587
create_score_exec = {
@@ -109,9 +101,124 @@ def create_score_execution(
109101
}
110102

111103
# Creating the score execution
112-
new_score_execution = cls.post(
113-
"scoreExecution/executions",
104+
score_execution = cls.post(
105+
"executions",
114106
data=json.dumps(create_score_exec),
115107
headers=headers_score_exec,
116108
)
117-
return new_score_execution
109+
110+
return score_execution
111+
112+
@classmethod
113+
def poll_score_execution_state(
114+
cls,
115+
score_execution: Union[dict, str],
116+
timeout: int = 300
117+
):
118+
if type(score_execution) is str:
119+
exec_id = score_execution
120+
else:
121+
exec_id = score_execution.get("id")
122+
123+
start_poll = time.time()
124+
while time.time() - start_poll < timeout:
125+
score_execution_state = cls.get(f"executions/{exec_id}/state")
126+
if score_execution_state == "completed":
127+
print("Score execution state is 'completed'")
128+
return "completed"
129+
elif score_execution_state == "failed":
130+
# TODO: Grab score execution logs and return those
131+
print("The score execution state is failed.")
132+
return "failed"
133+
elif time.time() - start_poll > timeout:
134+
print("The score execution is still running, but polling time ran out.")
135+
return "timeout"
136+
137+
@classmethod
138+
def get_score_execution_results(
139+
cls,
140+
score_execution: Union[dict, str],
141+
):
142+
try:
143+
import swat
144+
except ImportError:
145+
swat = None
146+
147+
if type(score_execution) is str:
148+
score_execution = cls.get_execution(score_execution)
149+
150+
server_name = score_execution.get("outputTable").get("serverName")
151+
library_name = score_execution.get("outputTable").get("libraryName")
152+
table_name = score_execution.get("outputTable").get("tableName")
153+
154+
# If swat is not available, then
155+
if not swat:
156+
output_table = cls._no_gateway_get_results(
157+
server_name,
158+
library_name,
159+
table_name
160+
)
161+
return output_table
162+
else:
163+
session = current_session()
164+
cas = session.as_swat()
165+
response = cas.loadActionSet("gateway")
166+
if not response:
167+
output_table = cls._no_gateway_get_results(
168+
server_name,
169+
library_name,
170+
table_name
171+
)
172+
return output_table
173+
else:
174+
gateway_code = f"""
175+
import pandas as pd
176+
import numpy as np
177+
178+
table = gateway.read_table({{"caslib": "{library_name}", "name": "{table_name}"}})
179+
180+
gateway.return_table("Execution Results", df = table, label = "label", title = "title")"""
181+
182+
output_table = cas.gateway.runlang(
183+
code=gateway_code,
184+
single=True,
185+
timeout_millis=10000
186+
)
187+
output_table = pd.DataFrame(output_table["Execution Results"])
188+
return output_table
189+
190+
@classmethod
191+
def _no_gateway_get_results(
192+
cls,
193+
server_name,
194+
library_name,
195+
table_name
196+
):
197+
if pd.__version__ >= StrictVersion("1.0.3"):
198+
from pandas import json_normalize
199+
else:
200+
from pandas.io.json import json_normalize
201+
202+
warnings.warn(
203+
"Without swat installed, the amount of rows from the output table that "
204+
"can be collected are memory limited by the CAS worker."
205+
)
206+
207+
output_columns = cls._cas_management.get(
208+
f"servers/{server_name}/"
209+
f"caslibs/{library_name}/"
210+
f"tables/{table_name}/columns?limit=10000"
211+
)
212+
columns = json_normalize(output_columns.json(), "items")
213+
column_names = columns["names"].to_list()
214+
215+
output_rows = cls._services.get(
216+
f"casRowSets/servers/{server_name}"
217+
f"caslibs/{library_name}"
218+
f"tables/{table_name}/rows?limit=10000"
219+
)
220+
output_table = pd.DataFrame(
221+
json_normalize(output_rows.json()["items"])["cells"].to_list(),
222+
columns=column_names
223+
)
224+
return output_table

src/sasctl/services.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from ._services.report_images import ReportImages as report_images
2222
from ._services.reports import Reports as reports
2323
from ._services.saslogon import SASLogon as saslogon
24+
from ._services.score_definitions import ScoreDefinitions as score_definitions
25+
from ._services.score_execution import ScoreExecution as score_execution
2426
from ._services.sentiment_analysis import SentimentAnalysis as sentiment_analysis
2527
from ._services.text_categorization import TextCategorization as text_categorization
2628
from ._services.text_parsing import TextParsing as text_parsing

src/sasctl/tasks.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import pickle # skipcq BAN-B301
1414
import re
1515
import sys
16+
from pathlib import Path
17+
from typing import Union
1618
from warnings import warn
1719

1820
import pandas as pd
@@ -30,6 +32,8 @@
3032
from .services import model_management as mm
3133
from .services import model_publish as mp
3234
from .services import model_repository as mr
35+
from .services import score_definitions as sd
36+
from .services import score_execution as se
3337
from .utils.misc import installed_packages
3438
from .utils.pymas import from_pickle
3539

@@ -1008,3 +1012,30 @@ def get_project_kpis(
10081012
kpiTableDf = kpiTableDf.apply(lambda x: x.str.strip()).replace([".", ""], None)
10091013

10101014
return kpiTableDf
1015+
1016+
1017+
def score_model_with_cas(
1018+
score_def_name: str,
1019+
model: Union[str, dict],
1020+
table_name: str,
1021+
table_file: Union[str, Path] = None,
1022+
description: str = "",
1023+
server_name: str = "cas-shared-default",
1024+
library_name: str = "Public",
1025+
model_version: str = "latest"
1026+
):
1027+
score_definition = sd.create_score_definition(
1028+
score_def_name,
1029+
model,
1030+
table_name,
1031+
table_file=table_file,
1032+
description=description,
1033+
server_name=server_name,
1034+
library_name=library_name,
1035+
model_version=model_version
1036+
)
1037+
score_execution = se.create_score_execution(score_definition.id)
1038+
score_execution_poll = se.poll_score_execution_state(score_execution)
1039+
print(score_execution_poll)
1040+
score_results = se.get_score_execution_results(score_execution)
1041+
return score_results

0 commit comments

Comments
 (0)