Skip to content

Commit c01034e

Browse files
adding model versioning functionality for score testing, publishing, and performance monitoring models
1 parent ef47a70 commit c01034e

File tree

3 files changed

+78
-9
lines changed

3 files changed

+78
-9
lines changed

src/sasctl/_services/model_management.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,13 @@ class ModelManagement(Service):
2828
# TODO: set ds2MultiType
2929
@classmethod
3030
def publish_model(
31-
cls, model, destination, name=None, force=False, reload_model_table=False
31+
cls,
32+
model,
33+
destination,
34+
model_version="latest",
35+
name=None,
36+
force=False,
37+
reload_model_table=False,
3238
):
3339
"""
3440
@@ -38,6 +44,8 @@ def publish_model(
3844
The name or id of the model, or a dictionary representation of the model.
3945
destination : str
4046
Name of destination to publish the model to.
47+
model_version_id : str or dict, optional
48+
Provide the id, name, or dictionary representation of the version to publish. Defaults to 'latest'.
4149
name : str, optional
4250
Provide a custom name for the published model. Defaults to None.
4351
force : bool, optional
@@ -68,6 +76,18 @@ def publish_model(
6876

6977
# TODO: Verify allowed formats by destination type.
7078
# As of 19w04 MAS throws HTTP 500 if name is in invalid format.
79+
if model_version != "latest":
80+
if isinstance(model_version, dict) and "modelVersionName" in model_version:
81+
model_version_name = model_version["modelVersionName"]
82+
elif isinstance(model_version, str) and cls.is_uuid(model_version):
83+
model_version_name = mr.get_model_or_version(model, model_version)[
84+
"modelVersionName"
85+
]
86+
else:
87+
model_version_name = model_version
88+
else:
89+
model_version_name = ""
90+
7191
model_name = name or "{}_{}".format(
7292
model_obj["name"].replace(" ", ""), model_obj["id"]
7393
).replace("-", "")
@@ -79,6 +99,7 @@ def publish_model(
7999
{
80100
"modelName": mp._publish_name(model_name),
81101
"sourceUri": model_uri.get("uri"),
102+
"modelVersionID": model_version_name,
82103
"publishLevel": "model",
83104
}
84105
],
@@ -104,6 +125,7 @@ def create_performance_definition(
104125
table_prefix,
105126
project=None,
106127
models=None,
128+
modelVersions=None,
107129
library_name="Public",
108130
name=None,
109131
description=None,
@@ -136,6 +158,9 @@ def create_performance_definition(
136158
The name or id of the model(s), or a dictionary representation of the model(s). For
137159
multiple models, input a list of model names, or a list of dictionaries. If no models are specified, all
138160
models in the project specified will be used. Defaults to None.
161+
modelVersions: str, list, optional
162+
The name of the model version(s) for models used in the performance definition. If no model versions
163+
are specified, all models will use the latest version. Defaults to None.
139164
library_name : str
140165
The library containing the input data, default is 'Public'.
141166
name : str, optional
@@ -239,10 +264,37 @@ def create_performance_definition(
239264
"property set." % project.name
240265
)
241266

267+
if not modelVersions:
268+
updated_models = [model.id for model in models]
269+
else:
270+
updated_models = []
271+
if not isinstance(modelVersions, list):
272+
modelVersions = [modelVersions]
273+
274+
if len(models) < len(modelVersions):
275+
raise ValueError(
276+
"There are too many versions for the amount of models specified."
277+
)
278+
279+
modelVersions = modelVersions + [""] * (len(models) - len(modelVersions))
280+
for model, modelVersionName in zip(models, modelVersions):
281+
print(model.name)
282+
if (
283+
isinstance(modelVersionName, dict)
284+
and "modelVersionName" in modelVersionName
285+
):
286+
modelVersionName = modelVersionName["modelVersionName"]
287+
elif (
288+
isinstance(modelVersionName, dict)
289+
and "modelVersionName" not in modelVersionName
290+
):
291+
raise ValueError("Model version is not recognized.")
292+
updated_models.append(model.id + ":" + modelVersionName)
293+
242294
request = {
243295
"projectId": project.id,
244296
"name": name or project.name + " Performance",
245-
"modelIds": [model.id for model in models],
297+
"modelIds": [model for model in updated_models],
246298
"championMonitored": monitor_champion,
247299
"challengerMonitored": monitor_challenger,
248300
"maxBins": max_bins,

src/sasctl/_services/model_publish.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from .model_repository import ModelRepository
1212
from .service import Service
13+
from ..utils.decorators import deprecated
1314

1415

1516
class ModelPublish(Service):
@@ -90,7 +91,7 @@ def delete_destination(cls, item):
9091

9192
return cls.delete("/destinations/{name}".format(name=item))
9293

93-
@classmethod
94+
@deprecated("Use publish_model in model_management.py instead.", "1.11.5")
9495
def publish_model(cls, model, destination, name=None, code=None, notes=None):
9596
"""Publish a model to an existing publishing destination.
9697

src/sasctl/_services/score_definitions.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def create_score_definition(
6969
library_name: str, optional
7070
The library within the CAS server the table exists in. Defaults to "Public".
7171
model_version: str, optional
72-
The user-chosen version of the model with the specified model_id. Defaults to "latest".
72+
The user-chosen version of the model with the specified model version name. Defaults to latest version.
7373
7474
Returns
7575
-------
@@ -116,7 +116,7 @@ def create_score_definition(
116116
table = cls._cas_management.get_table(table_name, library_name, server_name)
117117
if not table and not table_file:
118118
raise HTTPError(
119-
f"This table may not exist in CAS. Please include the `table_file` argument in the function call if it doesn't exist."
119+
"This table may not exist in CAS. Please include the `table_file` argument in the function call if it doesn't exist."
120120
)
121121
elif not table and table_file:
122122
cls._cas_management.upload_file(
@@ -125,16 +125,32 @@ def create_score_definition(
125125
table = cls._cas_management.get_table(table_name, library_name, server_name)
126126
if not table:
127127
raise HTTPError(
128-
f"The file failed to upload properly or another error occurred."
128+
"The file failed to upload properly or another error occurred."
129129
)
130130
# Checks if the inputted table exists, and if not, uploads a file to create a new table
131131

132+
if model_version != "latest":
133+
134+
if isinstance(model_version, dict) and "modelVersionName" in model_version:
135+
model_version = model_version["modelVersionName"]
136+
elif isinstance(model_version, str) and cls.is_uuid(model_version):
137+
model_version = cls._model_repository.get_model_or_version(
138+
model_id, model_version
139+
)["modelVersionName"]
140+
else:
141+
model_version = model_version
142+
143+
object_uri = f"/modelManagement/models/{model_id}/versions/@{model_version}"
144+
145+
else:
146+
object_uri = f"/modelManagement/models/{model_id}"
147+
132148
save_score_def = {
133149
"name": model_name, # used to be score_def_name
134150
"description": description,
135151
"objectDescriptor": {
136-
"uri": f"/modelManagement/models/{model_id}",
137-
"name": f"{model_name}({model_version})",
152+
"uri": object_uri,
153+
"name": f"{model_name} ({model_version})",
138154
"type": f"{object_descriptor_type}",
139155
},
140156
"inputData": {
@@ -149,7 +165,7 @@ def create_score_definition(
149165
"projectUri": f"/modelRepository/projects/{model_project_id}",
150166
"projectVersionUri": f"/modelRepository/projects/{model_project_id}/projectVersions/{model_project_version_id}",
151167
"publishDestination": "",
152-
"versionedModel": f"{model_name}({model_version})",
168+
"versionedModel": f"{model_name} ({model_version})",
153169
},
154170
"mappings": inputMapping,
155171
}

0 commit comments

Comments
 (0)