Skip to content

Commit 501f846

Browse files
committed
refactor: adjust score execution to check for valid cas gateway actionset before scoring
1 parent a841aac commit 501f846

File tree

1 file changed

+65
-37
lines changed

1 file changed

+65
-37
lines changed

src/sasctl/_services/score_execution.py

Lines changed: 65 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -139,10 +139,10 @@ def poll_score_execution_state(
139139
start_poll = time.time()
140140
while time.time() - start_poll < timeout:
141141
score_execution_state = cls.get(f"executions/{exec_id}/state")
142-
if score_execution_state.text == "complete":
143-
print("Score execution state is 'complete'")
144-
return "complete"
145-
elif score_execution_state.text == "failed":
142+
if score_execution_state == "completed":
143+
print("Score execution state is 'completed'")
144+
return "completed"
145+
elif score_execution_state == "failed":
146146
# TODO: Grab score execution logs and return those
147147
print("The score execution state is failed.")
148148
return "failed"
@@ -169,20 +169,10 @@ def get_score_execution_results(
169169

170170
# If swat is not available, then
171171
if not swat:
172-
if pd.__version__ >= StrictVersion("1.0.3"):
173-
from pandas import json_normalize
174-
else:
175-
from pandas.io.json import json_normalize
176-
177-
warnings.warn(
178-
"Without swat installed, the amount of rows from the output table that "
179-
"can be collected are memory limited by the CAS worker."
180-
)
181-
182-
output_columns = cls._cas_management.get(
183-
f"servers/{server_name}/"
184-
f"caslibs/{library_name}/"
185-
f"tables/{table_name}/columns?limit=10000"
172+
output_table = cls._no_gateway_get_results(
173+
server_name,
174+
library_name,
175+
table_name
186176
)
187177
columns = json_normalize(output_columns.json(), "items")
188178
column_names = columns["names"].to_list()
@@ -200,25 +190,63 @@ def get_score_execution_results(
200190
else:
201191
session = current_session()
202192
cas = session.as_swat()
203-
cas.loadActionSet("gateway")
204-
205-
gateway_code = f"""
206-
import pandas as pd
207-
import numpy as np
193+
response = cas.loadActionSet("gateway")
194+
if not response:
195+
output_table = cls._no_gateway_get_results(
196+
server_name,
197+
library_name,
198+
table_name
199+
)
200+
return output_table
201+
else:
202+
gateway_code = f"""
203+
import pandas as pd
204+
import numpy as np
208205
209-
table = gateway.read_table({{"caslib": {library_name}, "name": {table_name}}}
206+
table = gateway.read_table({{"caslib": "{library_name}", "name": "{table_name}"}})
210207
211-
gateway.return_table(
212-
"Execution Results",
213-
df = table,
214-
label = "label",
215-
title = "title"
216-
)
217-
"""
208+
gateway.return_table("Execution Results", df = table, label = "label", title = "title")"""
218209

219-
output_table = cas.gateway.runlang(
220-
code=gateway_code,
221-
single=True,
222-
timeout_millis=10000
223-
)
224-
return output_table
210+
output_table = cas.gateway.runlang(
211+
code=gateway_code,
212+
single=True,
213+
timeout_millis=10000
214+
)
215+
output_table = pd.DataFrame(output_table["Execution Results"])
216+
return output_table
217+
218+
@classmethod
219+
def _no_gateway_get_results(
220+
cls,
221+
server_name,
222+
library_name,
223+
table_name
224+
):
225+
if pd.__version__ >= StrictVersion("1.0.3"):
226+
from pandas import json_normalize
227+
else:
228+
from pandas.io.json import json_normalize
229+
230+
warnings.warn(
231+
"Without swat installed, the amount of rows from the output table that "
232+
"can be collected are memory limited by the CAS worker."
233+
)
234+
235+
output_columns = cls._cas_management.get(
236+
f"servers/{server_name}/"
237+
f"caslibs/{library_name}/"
238+
f"tables/{table_name}/columns?limit=10000"
239+
)
240+
columns = json_normalize(output_columns.json(), "items")
241+
column_names = columns["names"].to_list()
242+
243+
output_rows = cls._services.get(
244+
f"casRowSets/servers/{server_name}"
245+
f"caslibs/{library_name}"
246+
f"tables/{table_name}/rows?limit=10000"
247+
)
248+
output_table = pd.DataFrame(
249+
json_normalize(output_rows.json()["items"])["cells"].to_list(),
250+
columns=column_names
251+
)
252+
return output_table

0 commit comments

Comments
 (0)