Skip to content

Commit 4c8559c

Browse files
committed
feat: first draft for polling score execution and returning results
1 parent a3a618f commit 4c8559c

File tree

1 file changed

+130
-20
lines changed

1 file changed

+130
-20
lines changed

src/sasctl/_services/score_execution.py

Lines changed: 130 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
import json
2+
import time
3+
import warnings
4+
from distutils.version import StrictVersion
5+
from typing import Union
26

7+
import pandas as pd
38
from requests import HTTPError
49

10+
from .cas_management import CASManagement
11+
from ..core import current_session
512
from .score_definitions import ScoreDefinitions
613
from .service import Service
714

@@ -18,7 +25,9 @@ class ScoreExecution(Service):
1825
"""
1926

2027
_SERVICE_ROOT = "/scoreExecution"
28+
_cas_management = CASManagement()
2129
_score_definitions = ScoreDefinitions()
30+
_services = Service()
2231

2332
(
2433
list_executions,
@@ -62,30 +71,32 @@ def create_score_execution(
6271
if not score_definition:
6372
raise HTTPError
6473
score_exec_name = score_definition.get("name")
65-
model_uri = score_definition.get("objectDescriptor", "uri")
66-
model_name = score_definition.get("objectDescriptor", "name")
67-
model_input_library = score_definition.get("inputData", "libraryName")
68-
model_table_name = score_definition.get("inputData", "tableName")
74+
# The model URI must point at modelManagement rather than modelRepository, so it is rebuilt from the model's UUID
75+
model_uuid = score_definition.get("objectDescriptor").get("uri").split('/')[-1]
76+
model_uri = f"/modelManagement/models/{model_uuid}"
77+
model_name = score_definition.get("objectDescriptor").get("name")
78+
model_input_library = score_definition.get("inputData").get("libraryName")
79+
model_table_name = score_definition.get("inputData").get("tableName")
6980

7081
# Defining a default output table name if none is provided
7182
if not output_table_name:
7283
output_table_name = f"{model_name}_{score_definition_id}"
7384

7485
# Getting all score executions that are using the inputted score_definition_id
7586

76-
score_execution = cls.list_executions(
77-
filter=f"eq(scoreDefinitionId, '{score_definition_id}')"
78-
)
79-
if not score_execution:
80-
raise HTTPError(f"Something went wrong in the LIST_EXECUTIONS statement.")
81-
82-
# Checking the count of the execution list to see if there are any score executions for this score_definition_id already running
83-
execution_count = score_execution.get("count") # Exception catch location
84-
if execution_count == 1:
85-
execution_id = score_execution.get("items", 0, "id")
86-
deleted_execution = cls.delete_execution(execution_id)
87-
if deleted_execution.status_code >= 400:
88-
raise HTTPError(f"Something went wrong in the DELETE statement.")
87+
# score_execution = cls.list_executions(
88+
# filter=f"eq(scoreDefinitionId, '{score_definition_id}')"
89+
# )
90+
# if not score_execution:
91+
# raise HTTPError(f"Something went wrong in the LIST_EXECUTIONS statement.")
92+
#
93+
# # Checking the count of the execution list to see if there are any score executions for this score_definition_id already running
94+
# execution_count = score_execution.get("count") # Exception catch location
95+
# if execution_count == 1:
96+
# execution_id = score_execution.get("items", 0, "id")
97+
# deleted_execution = cls.delete_execution(execution_id)
98+
# if deleted_execution.status_code >= 400:
99+
# raise HTTPError(f"Something went wrong in the DELETE statement.")
89100

90101
headers_score_exec = {"Content-Type": "application/json"}
91102

@@ -106,9 +117,108 @@ def create_score_execution(
106117
}
107118

108119
# Creating the score execution
109-
new_score_execution = cls.post(
110-
"scoreExecution/executions",
120+
score_execution = cls.post(
121+
"executions",
111122
data=json.dumps(create_score_exec),
112123
headers=headers_score_exec,
113124
)
114-
return new_score_execution
125+
126+
return score_execution
127+
128+
@classmethod
def poll_score_execution_state(
    cls,
    score_execution: Union[dict, str],
    timeout: int = 300,
    poll_interval: float = 1.0,
):
    """Poll a score execution until it completes, fails, or time runs out.

    Parameters
    ----------
    score_execution : dict or str
        Either the score execution id, or a dictionary-like object whose
        "id" entry holds the score execution id.
    timeout : int
        Maximum number of seconds to keep polling. Defaults to 300.
    poll_interval : float
        Number of seconds to wait between two consecutive state requests.
        Defaults to 1.0.

    Returns
    -------
    str
        "complete" or "failed" when the execution reports that state, or
        "timeout" when neither state was reported before `timeout` elapsed.

    """
    if isinstance(score_execution, str):
        exec_id = score_execution
    else:
        exec_id = score_execution.get("id")

    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout
    while True:
        state = cls.get(f"executions/{exec_id}/state").text
        if state == "complete":
            print("Score execution state is 'complete'")
            return "complete"
        if state == "failed":
            # TODO: Grab score execution logs and return those
            print("The score execution state is failed.")
            return "failed"

        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        # Pause between requests so the service is not flooded, but never
        # sleep past the deadline.
        time.sleep(max(0.0, min(poll_interval, remaining)))

    print("The score execution is still running, but polling time ran out.")
    return "timeout"
153+
@classmethod
def get_score_execution_results(
    cls,
    score_execution: Union[dict, str],
):
    """Return the output table of a score execution.

    When the `swat` package cannot be imported, the output table is read
    through REST calls (column names from the CAS management service, rows
    from the casRowSets endpoint), each limited to 10000 entries. When
    `swat` is importable, the table is read through the CAS "gateway"
    action set using the current session.

    Parameters
    ----------
    score_execution : dict or str
        Either the score execution id, or a dictionary-like score execution
        whose "outputTable" entry contains "serverName", "libraryName" and
        "tableName".

    Returns
    -------
    pandas.DataFrame or object
        A DataFrame built from the REST responses when `swat` is not
        available; otherwise whatever `cas.gateway.runlang` returns.

    """
    try:
        import swat
    except ImportError:
        swat = None

    if isinstance(score_execution, str):
        score_execution = cls.get_execution(score_execution)

    output_info = score_execution.get("outputTable")
    server_name = output_info.get("serverName")
    library_name = output_info.get("libraryName")
    table_name = output_info.get("tableName")

    if not swat:
        # json_normalize moved to the top-level pandas namespace in 1.0;
        # trying the import avoids parsing version strings (and distutils).
        try:
            from pandas import json_normalize
        except ImportError:
            from pandas.io.json import json_normalize

        warnings.warn(
            "Without swat installed, the amount of rows from the output table that "
            "can be collected are memory limited by the CAS worker."
        )

        output_columns = cls._cas_management.get(
            f"servers/{server_name}/"
            f"caslibs/{library_name}/"
            f"tables/{table_name}/columns?limit=10000"
        )
        columns = json_normalize(output_columns.json(), "items")
        # Each column item carries its label under the "name" key.
        column_names = columns["name"].to_list()

        # Every path segment needs its own "/" separator.
        output_rows = cls._services.get(
            f"casRowSets/servers/{server_name}/"
            f"caslibs/{library_name}/"
            f"tables/{table_name}/rows?limit=10000"
        )
        output_table = pd.DataFrame(
            json_normalize(output_rows.json()["items"])["cells"].to_list(),
            columns=column_names,
        )
        return output_table

    session = current_session()
    cas = session.as_swat()
    cas.loadActionSet("gateway")

    # The names are interpolated with !r so they reach the gateway as quoted
    # Python string literals rather than bare identifiers.
    gateway_code = f"""
import pandas as pd
import numpy as np

table = gateway.read_table({{"caslib": {library_name!r}, "name": {table_name!r}}})

gateway.return_table(
    "Execution Results",
    df = table,
    label = "label",
    title = "title"
)
"""

    output_table = cas.gateway.runlang(
        code=gateway_code,
        single=True,
        timeout_millis=10000,
    )
    return output_table

0 commit comments

Comments
 (0)