Skip to content

Commit 89aa8c0

Browse files
committed
Astore model import refactor
1 parent e671bab commit 89aa8c0

File tree

1 file changed

+78
-35
lines changed

1 file changed

+78
-35
lines changed

src/sasctl/tasks.py

Lines changed: 78 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@
1515
import sys
1616
import warnings
1717

18+
try:
19+
import swat
20+
except ImportError:
21+
swat = None
22+
1823
from urllib.error import HTTPError
1924

2025
from . import utils
@@ -202,7 +207,7 @@ def register_model(
202207
column_name: type may be provided.
203208
version : {'new', 'latest', int}, optional
204209
Version number of the project in which the model should be created.
205-
Defaults to 'new'.
210+
Defaults to 'latest'.
206211
files : list
207212
A list of dictionaries of the form
208213
{'name': filename, 'file': filecontent}.
@@ -236,11 +241,14 @@ def register_model(
236241
.. versionchanged:: v1.4.5
237242
Added `record_packages` parameter.
238243
244+
.. versionchanged:: v1.7.4
245+
Update ASTORE handling for ease of use and removal of SAS Viya 4 score code errors
246+
239247
"""
240248
# TODO: Create new version if model already exists
241249

242250
# If version not specified, default to creating a new version
243-
version = version or "new"
251+
version = version or "latest"
244252

245253
files = files or []
246254

@@ -280,40 +288,75 @@ def register_model(
280288
# If model is a CASTable then assume it holds an ASTORE model.
281289
# Import these via a ZIP file.
282290
if "swat.cas.table.CASTable" in str(type(model)):
283-
zipfile = utils.create_package(model, input=input)
284-
285-
if create_project:
286-
outvar = []
287-
invar = []
288-
import zipfile as zp
289-
import copy
290-
291-
zipfilecopy = copy.deepcopy(zipfile)
292-
tmpzip = zp.ZipFile(zipfilecopy)
293-
if "outputVar.json" in tmpzip.namelist():
294-
outvar = json.loads(
295-
tmpzip.read("outputVar.json").decode("utf=8")
296-
) # added decode for 3.5 and older
297-
for tmp in outvar:
298-
tmp.update({"role": "output"})
299-
if "inputVar.json" in tmpzip.namelist():
300-
invar = json.loads(
301-
tmpzip.read("inputVar.json").decode("utf-8")
302-
) # added decode for 3.5 and older
303-
for tmp in invar:
304-
if tmp["role"] != "input":
305-
tmp["role"] = "input"
306-
307-
if "ModelProperties.json" in tmpzip.namelist():
308-
model_props = json.loads(
309-
tmpzip.read("ModelProperties.json").decode("utf-8")
310-
)
291+
if swat is None:
292+
raise RuntimeError("The 'swat' package is required to work with SAS models.")
293+
if not isinstance(model, swat.CASTable):
294+
raise ValueError(
295+
"Parameter 'table' should be an instance of '%r' but "
296+
"received '%r'." % (swat.CASTable, model)
297+
)
298+
if "DataStepSrc" in model.columns:
299+
zip_file = utils.create_package_from_datastep(model, input=input)
300+
if create_project:
301+
out_var = []
302+
in_var = []
303+
import zipfile as zp
304+
import copy
305+
306+
zip_file_copy = copy.deepcopy(zip_file)
307+
tmp_zip = zp.ZipFile(zip_file_copy)
308+
if "outputVar.json" in tmp_zip.namelist():
309+
out_var = json.loads(
310+
tmp_zip.read("outputVar.json").decode("utf=8")
311+
) # added decode for 3.5 and older
312+
for tmp in out_var:
313+
tmp.update({"role": "output"})
314+
if "inputVar.json" in tmp_zip.namelist():
315+
in_var = json.loads(
316+
tmp_zip.read("inputVar.json").decode("utf-8")
317+
) # added decode for 3.5 and older
318+
for tmp in in_var:
319+
if tmp["role"] != "input":
320+
tmp["role"] = "input"
321+
322+
if "ModelProperties.json" in tmp_zip.namelist():
323+
model_props = json.loads(
324+
tmp_zip.read("ModelProperties.json").decode("utf-8")
325+
)
326+
else:
327+
model_props = {}
328+
project = _create_project(project, model_props, repo_obj, in_var, out_var)
329+
model = mr.import_model_from_zip(name, project, zip_file, version=version)
330+
# Assume ASTORE model if not a DataStep model
331+
else:
332+
conn = model.session.get_connection()
333+
conn.loadactionset("astore")
334+
if create_project:
335+
result = conn.astore.describe(rstore=model, epcode=False)
336+
model_props = utils.astore._get_model_properties(result)
337+
in_var = [utils.astore.get_variable_properties(var) for var in result.InputVariables.itertuples()]
338+
for var in in_var:
339+
if not var.get("role"):
340+
var["role"] = "INPUT"
341+
out_var = [utils.astore.get_variable_properties(var) for var in result.OutputVariables.itertuples()]
342+
for var in out_var:
343+
if not var.get("role"):
344+
var["role"] = "OUTPUT"
345+
project = _create_project(project, model_props, repo_obj, in_var, out_var)
311346
else:
312-
model_props = {}
313-
314-
project = _create_project(project, model_props, repo_obj, invar, outvar)
315-
316-
model = mr.import_model_from_zip(name, project, zipfile, version=version)
347+
project = mr.get_project(project)
348+
astore = conn.astore.download(rstore=model)
349+
params = {
350+
"name": name,
351+
"projectId": project.id,
352+
"type": "ASTORE",
353+
"versionOption": version
354+
}
355+
model = mr.post(
356+
"/models",
357+
files={"files": ("{}.sasast".format(model.params["name"]), astore["blob"])},
358+
data=params
359+
)
317360
return model
318361

319362
# If the model is an scikit-learn model, generate the model dictionary

0 commit comments

Comments
 (0)