Skip to content

Commit 53f82e3

Browse files
committed
Handle import to different project versions #107
1 parent bcb3e7a commit 53f82e3

File tree

2 files changed

+78
-8
lines changed

2 files changed

+78
-8
lines changed

src/sasctl/_services/model_repository.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class ModelRepository(Service):
5454

5555
@classmethod
5656
def get_astore(cls, model):
57-
"""Get the ASTORE for a model registered int he model repository.
57+
"""Get the ASTORE for a model registered in the model repository.
5858
5959
Parameters
6060
----------
@@ -503,6 +503,8 @@ def import_model_from_zip(
503503
The ZIP file containing the model and contents.
504504
description : str
505505
The description of the model.
506+
version : str, optional
507+
Name of the project version. Default value is "latest".
506508
507509
Returns
508510
-------
@@ -769,3 +771,47 @@ def get_model_details(cls, model):
769771
id_ = model["id"]
770772

771773
return cls.get("/models/%s" % id_)
774+
775+
@classmethod
776+
def list_project_versions(cls, project):
777+
"""_summary_
778+
779+
Parameters
780+
----------
781+
project : str or dict
782+
The name or id of the model project, or a dictionary representation
783+
of the model project.
784+
785+
Returns
786+
-------
787+
list of dicts
788+
List of dicts representing different project versions. Dict key/value
789+
pairs are as follows.
790+
name : str
791+
id : str
792+
number : str
793+
modified : datetime
794+
795+
"""
796+
from datetime import datetime
797+
798+
project_info = cls.get_project(project)
799+
800+
if project_info is None:
801+
raise ValueError("Project `%s` could not be found." % str(project))
802+
803+
projectVersions = cls.get(
804+
"/projects/{}/projectVersions".format(project_info.id)
805+
)
806+
versionList = []
807+
for version in projectVersions:
808+
versionDict = {
809+
"name": version.name,
810+
"id": version.id,
811+
"number": version.versionNumber,
812+
"modified": datetime.strptime(
813+
version.modifiedTimeStamp, "%Y-%m-%dT%H:%M:%S.%fZ"
814+
),
815+
}
816+
versionList.append(versionDict)
817+
return versionList

src/sasctl/pzmm/importModel.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,21 @@ def project_exists(response, project):
4949
return response
5050

5151

52-
def model_exists(project, name, force):
53-
"""Checks if model already exists and either raises an error or deletes the redundant model.
52+
def model_exists(project, name, force, versionName="latest"):
53+
"""Checks if model already exists in the same project and either raises an error or deletes
54+
the redundant model. If no project version is provided, the version is assumed to be "latest".
5455
5556
Parameters
5657
----------
57-
project : string or dict
58+
project : str or dict
5859
The name or id of the model project, or a dictionary representation of the project.
5960
name : str or dict
6061
The name of the model.
6162
force : bool, optional
6263
Sets whether to overwrite models with the same name upon upload.
64+
versionName : str, optional
65+
Name of project version to check if a model of the same name already exists. Default
66+
value is "latest".
6367
6468
Raises
6569
------
@@ -69,7 +73,19 @@ def model_exists(project, name, force):
6973
"""
7074
project = mr.get_project(project)
7175
projectId = project["id"]
72-
projectModels = mr.get("/projects/{}/models".format(projectId))
76+
projectVersions = mr.list_project_versions(project)
77+
if versionName == "latest":
78+
modTime = [item["modified"] for item in projectVersions]
79+
latestVersion = modTime.index(max(modTime))
80+
versionId = projectVersions[latestVersion]["id"]
81+
else:
82+
for version in projectVersions:
83+
if versionName is version["name"]:
84+
versionId = version["id"]
85+
break
86+
projectModels = mr.get(
87+
"/projects/{}/projectVersions/{}/models".format(projectId, versionId)
88+
)
7389

7490
for model in projectModels:
7591
# Throws a TypeError if only one model is in the project
@@ -106,6 +122,7 @@ def pzmmImportModel(
106122
targetDF,
107123
predictmethod,
108124
metrics=["EM_EVENTPROBABILITY", "EM_CLASSIFICATION"],
125+
projectVersion="latest",
109126
modelFileName=None,
110127
pyPath=None,
111128
threshPrediction=None,
@@ -156,9 +173,12 @@ def pzmmImportModel(
156173
metrics : string list, optional
157174
The scoring metrics for the model. The default is a set of two
158175
metrics: EM_EVENTPROBABILITY and EM_CLASSIFICATION.
176+
projectVersion : str, optional
177+
Name of project version to check if a model of the same name already exists. Default
178+
value is "latest".
159179
modelFileName : string, optional
160180
Name of the model file that contains the model. By default None and assigned as
161-
model_prefix + '.pickle'.
181+
modelPrefix + '.pickle'.
162182
pyPath : string, optional
163183
The local path of the score code file. By default None and assigned as the zPath.
164184
threshPrediction : float, optional
@@ -265,7 +285,9 @@ def getFiles(extensions):
265285
# Check if model with same name already exists in project.
266286
model_exists(project, modelPrefix, force)
267287

268-
response = mr.import_model_from_zip(modelPrefix, project, zipIOFile)
288+
response = mr.import_model_from_zip(
289+
modelPrefix, project, zipIOFile, version=projectVersion
290+
)
269291
try:
270292
print(
271293
"Model was successfully imported into SAS Model Manager as {} with UUID: {}.".format(
@@ -286,7 +308,9 @@ def getFiles(extensions):
286308
# Check if model with same name already exists in project.
287309
model_exists(project, modelPrefix, force)
288310

289-
response = mr.import_model_from_zip(modelPrefix, project, zipIOFile, force)
311+
response = mr.import_model_from_zip(
312+
modelPrefix, project, zipIOFile, force, version=projectVersion
313+
)
290314
try:
291315
print(
292316
"Model was successfully imported into SAS Model Manager as {} with UUID: {}.".format(

0 commit comments

Comments
 (0)