1
1
import json
2
+ import time
3
+ import warnings
4
+ from distutils .version import StrictVersion
5
+ from typing import Union
2
6
7
+ import pandas as pd
3
8
from requests import HTTPError
4
9
10
+ from .cas_management import CASManagement
11
+ from ..core import current_session
5
12
from .score_definitions import ScoreDefinitions
6
13
from .service import Service
7
14
@@ -18,7 +25,9 @@ class ScoreExecution(Service):
18
25
"""
19
26
20
27
_SERVICE_ROOT = "/scoreExecution"
28
+ _cas_management = CASManagement ()
21
29
_score_definitions = ScoreDefinitions ()
30
+ _services = Service ()
22
31
23
32
(
24
33
list_executions ,
@@ -62,34 +71,17 @@ def create_score_execution(
62
71
if not score_definition :
63
72
raise HTTPError
64
73
score_exec_name = score_definition .get ("name" )
65
- model_uri = score_definition .get ("objectDescriptor" , "uri" )
66
- model_name = score_definition .get ("objectDescriptor" , "name" )
67
- model_input_library = score_definition .get ("inputData" , "libraryName" )
68
- model_table_name = score_definition .get ("inputData" , "tableName" )
74
+ # NEEDS modelManagement not modelRepository
75
+ model_uuid = score_definition .get ("objectDescriptor" ).get ("uri" ).split ('/' )[- 1 ]
76
+ model_uri = f"/modelManagement/models/{ model_uuid } "
77
+ model_name = score_definition .get ("objectDescriptor" ).get ("name" )
78
+ model_input_library = score_definition .get ("inputData" ).get ("libraryName" )
79
+ model_table_name = score_definition .get ("inputData" ).get ("tableName" )
69
80
70
81
# Defining a default output table name if none is provided
71
82
if not output_table_name :
72
83
output_table_name = f"{ model_name } _{ score_definition_id } "
73
84
74
- # Getting all score executions that are using the inputted score_definition_id
75
-
76
- # score_execution = cls.list_executions(
77
- # filter=f"eq(scoreDefinitionId, '{score_definition_id}')"
78
- # )
79
- score_execution = cls .get ("scoreExecution/executions" ,
80
- filter = f"filter=eq(scoreExecutionRequest.scoreDefinitionId,%{ score_definition_id } %27)"
81
- )
82
- if not score_execution :
83
- raise HTTPError (f"Something went wrong in the LIST_EXECUTIONS statement." )
84
-
85
- # Checking the count of the execution list to see if there are any score executions for this score_definition_id already running
86
- execution_count = score_execution .get ("count" ) # Exception catch location
87
- if execution_count == 1 :
88
- execution_id = score_execution .get ("items" , 0 , "id" )
89
- deleted_execution = cls .delete_execution (execution_id )
90
- if deleted_execution .status_code >= 400 :
91
- raise HTTPError (f"Something went wrong in the DELETE statement." )
92
-
93
85
headers_score_exec = {"Content-Type" : "application/json" }
94
86
95
87
create_score_exec = {
@@ -109,9 +101,124 @@ def create_score_execution(
109
101
}
110
102
111
103
# Creating the score execution
112
- new_score_execution = cls .post (
113
- "scoreExecution/ executions" ,
104
+ score_execution = cls .post (
105
+ "executions" ,
114
106
data = json .dumps (create_score_exec ),
115
107
headers = headers_score_exec ,
116
108
)
117
- return new_score_execution
109
+
110
+ return score_execution
111
+
112
+ @classmethod
113
+ def poll_score_execution_state (
114
+ cls ,
115
+ score_execution : Union [dict , str ],
116
+ timeout : int = 300
117
+ ):
118
+ if type (score_execution ) is str :
119
+ exec_id = score_execution
120
+ else :
121
+ exec_id = score_execution .get ("id" )
122
+
123
+ start_poll = time .time ()
124
+ while time .time () - start_poll < timeout :
125
+ score_execution_state = cls .get (f"executions/{ exec_id } /state" )
126
+ if score_execution_state == "completed" :
127
+ print ("Score execution state is 'completed'" )
128
+ return "completed"
129
+ elif score_execution_state == "failed" :
130
+ # TODO: Grab score execution logs and return those
131
+ print ("The score execution state is failed." )
132
+ return "failed"
133
+ elif time .time () - start_poll > timeout :
134
+ print ("The score execution is still running, but polling time ran out." )
135
+ return "timeout"
136
+
137
+ @classmethod
138
+ def get_score_execution_results (
139
+ cls ,
140
+ score_execution : Union [dict , str ],
141
+ ):
142
+ try :
143
+ import swat
144
+ except ImportError :
145
+ swat = None
146
+
147
+ if type (score_execution ) is str :
148
+ score_execution = cls .get_execution (score_execution )
149
+
150
+ server_name = score_execution .get ("outputTable" ).get ("serverName" )
151
+ library_name = score_execution .get ("outputTable" ).get ("libraryName" )
152
+ table_name = score_execution .get ("outputTable" ).get ("tableName" )
153
+
154
+ # If swat is not available, then
155
+ if not swat :
156
+ output_table = cls ._no_gateway_get_results (
157
+ server_name ,
158
+ library_name ,
159
+ table_name
160
+ )
161
+ return output_table
162
+ else :
163
+ session = current_session ()
164
+ cas = session .as_swat ()
165
+ response = cas .loadActionSet ("gateway" )
166
+ if not response :
167
+ output_table = cls ._no_gateway_get_results (
168
+ server_name ,
169
+ library_name ,
170
+ table_name
171
+ )
172
+ return output_table
173
+ else :
174
+ gateway_code = f"""
175
+ import pandas as pd
176
+ import numpy as np
177
+
178
+ table = gateway.read_table({{"caslib": "{ library_name } ", "name": "{ table_name } "}})
179
+
180
+ gateway.return_table("Execution Results", df = table, label = "label", title = "title")"""
181
+
182
+ output_table = cas .gateway .runlang (
183
+ code = gateway_code ,
184
+ single = True ,
185
+ timeout_millis = 10000
186
+ )
187
+ output_table = pd .DataFrame (output_table ["Execution Results" ])
188
+ return output_table
189
+
190
+ @classmethod
191
+ def _no_gateway_get_results (
192
+ cls ,
193
+ server_name ,
194
+ library_name ,
195
+ table_name
196
+ ):
197
+ if pd .__version__ >= StrictVersion ("1.0.3" ):
198
+ from pandas import json_normalize
199
+ else :
200
+ from pandas .io .json import json_normalize
201
+
202
+ warnings .warn (
203
+ "Without swat installed, the amount of rows from the output table that "
204
+ "can be collected are memory limited by the CAS worker."
205
+ )
206
+
207
+ output_columns = cls ._cas_management .get (
208
+ f"servers/{ server_name } /"
209
+ f"caslibs/{ library_name } /"
210
+ f"tables/{ table_name } /columns?limit=10000"
211
+ )
212
+ columns = json_normalize (output_columns .json (), "items" )
213
+ column_names = columns ["names" ].to_list ()
214
+
215
+ output_rows = cls ._services .get (
216
+ f"casRowSets/servers/{ server_name } "
217
+ f"caslibs/{ library_name } "
218
+ f"tables/{ table_name } /rows?limit=10000"
219
+ )
220
+ output_table = pd .DataFrame (
221
+ json_normalize (output_rows .json ()["items" ])["cells" ].to_list (),
222
+ columns = column_names
223
+ )
224
+ return output_table
0 commit comments