Skip to content

Commit 1660da9

Browse files
Publish model versions - added another unit test (#221)
* adding model versioning functionality for score testing, publishing, and performance monitoring models * removed debug print statements after testing versioning logic * Added unit tests for model version functionality * Refactored code for more clarity Signed-off-by: samyarpotlapalli <[email protected]> * Refactored code and DCO Remediation Commit for samyarpotlapalli <[email protected]>I, samyarpotlapalli <[email protected]>, hereby add my Signed-off-by to this commit: c01034e23bd6bd9c3905648223b6d784f68e0693I, samyarpotlapalli <[email protected]>, hereby add my Signed-off-by to this commit: b6078ce3560dcbe18cff09737491861f82d37f78I, samyarpotlapalli <[email protected]>, hereby add my Signed-off-by to this commit: d32a9c4a26513d4a36d9ce45fa2d511ba5cc6f73Signed-off-by: samyarpotlapalli <[email protected]> Signed-off-by: samyarpotlapalli <[email protected]> * DCO Remediation Commit for samyarpotlapalli <[email protected]>I, samyarpotlapalli <[email protected]>, hereby add my Signed-off-by to this commit: c01034e23bd6bd9c3905648223b6d784f68e0693I, samyarpotlapalli <[email protected]>, hereby add my Signed-off-by to this commit: b6078ce3560dcbe18cff09737491861f82d37f78I, samyarpotlapalli <[email protected]>, hereby add my Signed-off-by to this commit: d32a9c4a26513d4a36d9ce45fa2d511ba5cc6f73Signed-off-by: samyarpotlapalli <[email protected]> * Added another unit test for more coverage Signed-off-by: samyarpotlapalli <[email protected]> * Additional refactoring Signed-off-by: samyarpotlapalli <[email protected]> --------- Signed-off-by: samyarpotlapalli <[email protected]>
1 parent ef47a70 commit 1660da9

File tree

5 files changed

+440
-125
lines changed

5 files changed

+440
-125
lines changed

src/sasctl/_services/model_management.py

Lines changed: 85 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,13 @@ class ModelManagement(Service):
2828
# TODO: set ds2MultiType
2929
@classmethod
3030
def publish_model(
31-
cls, model, destination, name=None, force=False, reload_model_table=False
31+
cls,
32+
model,
33+
destination,
34+
model_version="latest",
35+
name=None,
36+
force=False,
37+
reload_model_table=False,
3238
):
3339
"""
3440
@@ -38,6 +44,8 @@ def publish_model(
3844
The name or id of the model, or a dictionary representation of the model.
3945
destination : str
4046
Name of destination to publish the model to.
47+
model_version : str or dict, optional
48+
Provide the version id, name, or dict to publish. Defaults to 'latest'.
4149
name : str, optional
4250
Provide a custom name for the published model. Defaults to None.
4351
force : bool, optional
@@ -68,6 +76,23 @@ def publish_model(
6876

6977
# TODO: Verify allowed formats by destination type.
7078
# As of 19w04 MAS throws HTTP 500 if name is in invalid format.
79+
if model_version != "latest":
80+
if isinstance(model_version, dict) and "modelVersionName" in model_version:
81+
model_version_name = model_version["modelVersionName"]
82+
elif (
83+
isinstance(model_version, dict)
84+
and "modelVersionName" not in model_version
85+
):
86+
raise ValueError("Model version is not recognized.")
87+
elif isinstance(model_version, str) and cls.is_uuid(model_version):
88+
model_version_name = mr.get_model_or_version(model, model_version)[
89+
"modelVersionName"
90+
]
91+
else:
92+
model_version_name = model_version
93+
else:
94+
model_version_name = ""
95+
7196
model_name = name or "{}_{}".format(
7297
model_obj["name"].replace(" ", ""), model_obj["id"]
7398
).replace("-", "")
@@ -79,6 +104,7 @@ def publish_model(
79104
{
80105
"modelName": mp._publish_name(model_name),
81106
"sourceUri": model_uri.get("uri"),
107+
"modelVersionID": model_version_name,
82108
"publishLevel": "model",
83109
}
84110
],
@@ -104,6 +130,7 @@ def create_performance_definition(
104130
table_prefix,
105131
project=None,
106132
models=None,
133+
modelVersions=None,
107134
library_name="Public",
108135
name=None,
109136
description=None,
@@ -136,6 +163,8 @@ def create_performance_definition(
136163
The name or id of the model(s), or a dictionary representation of the model(s). For
137164
multiple models, input a list of model names, or a list of dictionaries. If no models are specified, all
138165
models in the project specified will be used. Defaults to None.
166+
modelVersions: str, list, optional
167+
The name of the model version(s). Defaults to None, so all models are latest.
139168
library_name : str
140169
The library containing the input data, default is 'Public'.
141170
name : str, optional
@@ -239,10 +268,13 @@ def create_performance_definition(
239268
"property set." % project.name
240269
)
241270

271+
# Creating the new array of modelIds with version names appended
272+
updated_models = cls.check_model_versions(models, modelVersions)
273+
242274
request = {
243275
"projectId": project.id,
244276
"name": name or project.name + " Performance",
245-
"modelIds": [model.id for model in models],
277+
"modelIds": updated_models,
246278
"championMonitored": monitor_champion,
247279
"challengerMonitored": monitor_challenger,
248280
"maxBins": max_bins,
@@ -279,7 +311,6 @@ def create_performance_definition(
279311
for v in project.get("variables", [])
280312
if v.get("role") == "output"
281313
]
282-
283314
return cls.post(
284315
"/performanceTasks",
285316
json=request,
@@ -288,6 +319,57 @@ def create_performance_definition(
288319
},
289320
)
290321

322+
@classmethod
323+
def check_model_versions(cls, models, modelVersions):
324+
"""
325+
Checking if the model version(s) are valid and append to model id accordingly.
326+
327+
Parameters
328+
----------
329+
models: list of str
330+
List of models.
331+
modelVersions : list of str
332+
List of model versions associated with models.
333+
334+
Returns
335+
-------
336+
String list
337+
"""
338+
if not modelVersions:
339+
return [model.id for model in models]
340+
341+
updated_models = []
342+
if not isinstance(modelVersions, list):
343+
modelVersions = [modelVersions]
344+
345+
if len(models) < len(modelVersions):
346+
raise ValueError(
347+
"There are too many versions for the amount of models specified."
348+
)
349+
350+
modelVersions = modelVersions + [""] * (len(models) - len(modelVersions))
351+
for model, modelVersionName in zip(models, modelVersions):
352+
353+
if (
354+
isinstance(modelVersionName, dict)
355+
and "modelVersionName" in modelVersionName
356+
):
357+
358+
modelVersionName = modelVersionName["modelVersionName"]
359+
elif (
360+
isinstance(modelVersionName, dict)
361+
and "modelVersionName" not in modelVersionName
362+
):
363+
364+
raise ValueError("Model version is not recognized.")
365+
366+
if modelVersionName != "":
367+
updated_models.append(model.id + ":" + modelVersionName)
368+
else:
369+
updated_models.append(model.id)
370+
371+
return updated_models
372+
291373
@classmethod
292374
def execute_performance_definition(cls, definition):
293375
"""Launches a job to run a performance definition.

src/sasctl/_services/model_publish.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from .model_repository import ModelRepository
1212
from .service import Service
13+
from ..utils.decorators import deprecated
1314

1415

1516
class ModelPublish(Service):
@@ -90,7 +91,7 @@ def delete_destination(cls, item):
9091

9192
return cls.delete("/destinations/{name}".format(name=item))
9293

93-
@classmethod
94+
@deprecated("Use publish_model in model_management.py instead.", "1.11.5")
9495
def publish_model(cls, model, destination, name=None, code=None, notes=None):
9596
"""Publish a model to an existing publishing destination.
9697

src/sasctl/_services/score_definitions.py

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def create_score_definition(
4646
description: str = "",
4747
server_name: str = "cas-shared-default",
4848
library_name: str = "Public",
49-
model_version: str = "latest",
49+
model_version: Union[str, dict] = "latest",
5050
):
5151
"""Creates the score definition service.
5252
@@ -69,7 +69,7 @@ def create_score_definition(
6969
library_name: str, optional
7070
The library within the CAS server the table exists in. Defaults to "Public".
7171
model_version: str, optional
72-
The user-chosen version of the model with the specified model_id. Defaults to "latest".
72+
The user-chosen version of the model. Deafaults to "latest".
7373
7474
Returns
7575
-------
@@ -116,7 +116,7 @@ def create_score_definition(
116116
table = cls._cas_management.get_table(table_name, library_name, server_name)
117117
if not table and not table_file:
118118
raise HTTPError(
119-
f"This table may not exist in CAS. Please include the `table_file` argument in the function call if it doesn't exist."
119+
"This table may not exist in CAS. Include the `table_file` argument."
120120
)
121121
elif not table and table_file:
122122
cls._cas_management.upload_file(
@@ -125,16 +125,19 @@ def create_score_definition(
125125
table = cls._cas_management.get_table(table_name, library_name, server_name)
126126
if not table:
127127
raise HTTPError(
128-
f"The file failed to upload properly or another error occurred."
128+
"The file failed to upload properly or another error occurred."
129129
)
130130
# Checks if the inputted table exists, and if not, uploads a file to create a new table
131131

132+
object_uri, model_version = cls.check_model_version(model_id, model_version)
133+
# Checks if the model version is valid and how to find the name
134+
132135
save_score_def = {
133136
"name": model_name, # used to be score_def_name
134137
"description": description,
135138
"objectDescriptor": {
136-
"uri": f"/modelManagement/models/{model_id}",
137-
"name": f"{model_name}({model_version})",
139+
"uri": object_uri,
140+
"name": f"{model_name} ({model_version})",
138141
"type": f"{object_descriptor_type}",
139142
},
140143
"inputData": {
@@ -149,7 +152,7 @@ def create_score_definition(
149152
"projectUri": f"/modelRepository/projects/{model_project_id}",
150153
"projectVersionUri": f"/modelRepository/projects/{model_project_id}/projectVersions/{model_project_version_id}",
151154
"publishDestination": "",
152-
"versionedModel": f"{model_name}({model_version})",
155+
"versionedModel": f"{model_name} ({model_version})",
153156
},
154157
"mappings": inputMapping,
155158
}
@@ -161,3 +164,37 @@ def create_score_definition(
161164
"/definitions", data=json.dumps(save_score_def), headers=headers_score_def
162165
)
163166
# The response information of the score definition can be seen as a JSON as well as a RestOBJ
167+
168+
@classmethod
169+
def check_model_version(cls, model_id: str, model_version: Union[str, dict]):
170+
"""Checks if the model version is valid.
171+
172+
Parameters
173+
----------
174+
model_version : str or dict
175+
The model version to check.
176+
177+
Returns
178+
-------
179+
String tuple
180+
"""
181+
if model_version != "latest":
182+
183+
if isinstance(model_version, dict) and "modelVersionName" in model_version:
184+
model_version = model_version["modelVersionName"]
185+
elif (
186+
isinstance(model_version, dict)
187+
and "modelVersionName" not in model_version
188+
):
189+
raise ValueError("Model version cannot be found.")
190+
elif isinstance(model_version, str) and cls.is_uuid(model_version):
191+
model_version = cls._model_repository.get_model_or_version(
192+
model_id, model_version
193+
)["modelVersionName"]
194+
195+
object_uri = f"/modelManagement/models/{model_id}/versions/@{model_version}"
196+
197+
else:
198+
object_uri = f"/modelManagement/models/{model_id}"
199+
200+
return object_uri, model_version

0 commit comments

Comments
 (0)