|
15 | 15 | import sys
|
16 | 16 | import warnings
|
17 | 17 |
|
| 18 | +try: |
| 19 | + import swat |
| 20 | +except ImportError: |
| 21 | + swat = None |
| 22 | + |
18 | 23 | from urllib.error import HTTPError
|
19 | 24 |
|
20 | 25 | from . import utils
|
@@ -202,7 +207,7 @@ def register_model(
|
202 | 207 | column_name: type may be provided.
|
203 | 208 | version : {'new', 'latest', int}, optional
|
204 | 209 | Version number of the project in which the model should be created.
|
205 |
| - Defaults to 'new'. |
| 210 | + Defaults to 'latest'. |
206 | 211 | files : list
|
207 | 212 | A list of dictionaries of the form
|
208 | 213 | {'name': filename, 'file': filecontent}.
|
@@ -236,11 +241,14 @@ def register_model(
|
236 | 241 | .. versionchanged:: v1.4.5
|
237 | 242 | Added `record_packages` parameter.
|
238 | 243 |
|
| 244 | + .. versionchanged:: v1.7.4 |
| 245 | + Update ASTORE handling for ease of use and removal of SAS Viya 4 score code errors |
| 246 | +
|
239 | 247 | """
|
240 | 248 | # TODO: Create new version if model already exists
|
241 | 249 |
|
242 | 250 | # If version not specified, default to creating a new version
|
243 |
| - version = version or "new" |
| 251 | + version = version or "latest" |
244 | 252 |
|
245 | 253 | files = files or []
|
246 | 254 |
|
@@ -280,40 +288,75 @@ def register_model(
|
280 | 288 | # If model is a CASTable then assume it holds an ASTORE model.
|
281 | 289 | # Import these via a ZIP file.
|
282 | 290 | if "swat.cas.table.CASTable" in str(type(model)):
|
283 |
| - zipfile = utils.create_package(model, input=input) |
284 |
| - |
285 |
| - if create_project: |
286 |
| - outvar = [] |
287 |
| - invar = [] |
288 |
| - import zipfile as zp |
289 |
| - import copy |
290 |
| - |
291 |
| - zipfilecopy = copy.deepcopy(zipfile) |
292 |
| - tmpzip = zp.ZipFile(zipfilecopy) |
293 |
| - if "outputVar.json" in tmpzip.namelist(): |
294 |
| - outvar = json.loads( |
295 |
| - tmpzip.read("outputVar.json").decode("utf=8") |
296 |
| - ) # added decode for 3.5 and older |
297 |
| - for tmp in outvar: |
298 |
| - tmp.update({"role": "output"}) |
299 |
| - if "inputVar.json" in tmpzip.namelist(): |
300 |
| - invar = json.loads( |
301 |
| - tmpzip.read("inputVar.json").decode("utf-8") |
302 |
| - ) # added decode for 3.5 and older |
303 |
| - for tmp in invar: |
304 |
| - if tmp["role"] != "input": |
305 |
| - tmp["role"] = "input" |
306 |
| - |
307 |
| - if "ModelProperties.json" in tmpzip.namelist(): |
308 |
| - model_props = json.loads( |
309 |
| - tmpzip.read("ModelProperties.json").decode("utf-8") |
310 |
| - ) |
| 291 | + if swat is None: |
| 292 | + raise RuntimeError("The 'swat' package is required to work with SAS models.") |
| 293 | + if not isinstance(model, swat.CASTable): |
| 294 | + raise ValueError( |
| 295 | + "Parameter 'table' should be an instance of '%r' but " |
| 296 | + "received '%r'." % (swat.CASTable, model) |
| 297 | + ) |
| 298 | + if "DataStepSrc" in model.columns: |
| 299 | + zip_file = utils.create_package_from_datastep(model, input=input) |
| 300 | + if create_project: |
| 301 | + out_var = [] |
| 302 | + in_var = [] |
| 303 | + import zipfile as zp |
| 304 | + import copy |
| 305 | + |
| 306 | + zip_file_copy = copy.deepcopy(zip_file) |
| 307 | + tmp_zip = zp.ZipFile(zip_file_copy) |
| 308 | + if "outputVar.json" in tmp_zip.namelist(): |
| 309 | + out_var = json.loads( |
| 310 | + tmp_zip.read("outputVar.json").decode("utf=8") |
| 311 | + ) # added decode for 3.5 and older |
| 312 | + for tmp in out_var: |
| 313 | + tmp.update({"role": "output"}) |
| 314 | + if "inputVar.json" in tmp_zip.namelist(): |
| 315 | + in_var = json.loads( |
| 316 | + tmp_zip.read("inputVar.json").decode("utf-8") |
| 317 | + ) # added decode for 3.5 and older |
| 318 | + for tmp in in_var: |
| 319 | + if tmp["role"] != "input": |
| 320 | + tmp["role"] = "input" |
| 321 | + |
| 322 | + if "ModelProperties.json" in tmp_zip.namelist(): |
| 323 | + model_props = json.loads( |
| 324 | + tmp_zip.read("ModelProperties.json").decode("utf-8") |
| 325 | + ) |
| 326 | + else: |
| 327 | + model_props = {} |
| 328 | + project = _create_project(project, model_props, repo_obj, in_var, out_var) |
| 329 | + model = mr.import_model_from_zip(name, project, zip_file, version=version) |
| 330 | + # Assume ASTORE model if not a DataStep model |
| 331 | + else: |
| 332 | + conn = model.session.get_connection() |
| 333 | + conn.loadactionset("astore") |
| 334 | + if create_project: |
| 335 | + result = conn.astore.describe(rstore=model, epcode=False) |
| 336 | + model_props = utils.astore._get_model_properties(result) |
| 337 | + in_var = [utils.astore.get_variable_properties(var) for var in result.InputVariables.itertuples()] |
| 338 | + for var in in_var: |
| 339 | + if not var.get("role"): |
| 340 | + var["role"] = "INPUT" |
| 341 | + out_var = [utils.astore.get_variable_properties(var) for var in result.OutputVariables.itertuples()] |
| 342 | + for var in out_var: |
| 343 | + if not var.get("role"): |
| 344 | + var["role"] = "OUTPUT" |
| 345 | + project = _create_project(project, model_props, repo_obj, in_var, out_var) |
311 | 346 | else:
|
312 |
| - model_props = {} |
313 |
| - |
314 |
| - project = _create_project(project, model_props, repo_obj, invar, outvar) |
315 |
| - |
316 |
| - model = mr.import_model_from_zip(name, project, zipfile, version=version) |
| 347 | + project = mr.get_project(project) |
| 348 | + astore = conn.astore.download(rstore=model) |
| 349 | + params = { |
| 350 | + "name": name, |
| 351 | + "projectId": project.id, |
| 352 | + "type": "ASTORE", |
| 353 | + "versionOption": version |
| 354 | + } |
| 355 | + model = mr.post( |
| 356 | + "/models", |
| 357 | + files={"files": ("{}.sasast".format(model.params["name"]), astore["blob"])}, |
| 358 | + data=params |
| 359 | + ) |
317 | 360 | return model
|
318 | 361 |
|
319 | 362 | # If the model is an scikit-learn model, generate the model dictionary
|
|
0 commit comments