Skip to content

Refactored more code #220

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
88 changes: 85 additions & 3 deletions src/sasctl/_services/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@ class ModelManagement(Service):
# TODO: set ds2MultiType
@classmethod
def publish_model(
cls, model, destination, name=None, force=False, reload_model_table=False
cls,
model,
destination,
model_version="latest",
name=None,
force=False,
reload_model_table=False,
):
"""

Expand All @@ -38,6 +44,8 @@ def publish_model(
The name or id of the model, or a dictionary representation of the model.
destination : str
Name of destination to publish the model to.
model_version : str or dict, optional
Provide the version id, name, or dict to publish. Defaults to 'latest'.
name : str, optional
Provide a custom name for the published model. Defaults to None.
force : bool, optional
Expand Down Expand Up @@ -68,6 +76,23 @@ def publish_model(

# TODO: Verify allowed formats by destination type.
# As of 19w04 MAS throws HTTP 500 if name is in invalid format.
if model_version != "latest":
if isinstance(model_version, dict) and "modelVersionName" in model_version:
model_version_name = model_version["modelVersionName"]
elif (
isinstance(model_version, dict)
and "modelVersionName" not in model_version
):
raise ValueError("Model version is not recognized.")
elif isinstance(model_version, str) and cls.is_uuid(model_version):
model_version_name = mr.get_model_or_version(model, model_version)[
"modelVersionName"
]
else:
model_version_name = model_version
else:
model_version_name = ""

model_name = name or "{}_{}".format(
model_obj["name"].replace(" ", ""), model_obj["id"]
).replace("-", "")
Expand All @@ -79,6 +104,7 @@ def publish_model(
{
"modelName": mp._publish_name(model_name),
"sourceUri": model_uri.get("uri"),
"modelVersionID": model_version_name,
"publishLevel": "model",
}
],
Expand All @@ -104,6 +130,7 @@ def create_performance_definition(
table_prefix,
project=None,
models=None,
modelVersions=None,
library_name="Public",
name=None,
description=None,
Expand Down Expand Up @@ -136,6 +163,8 @@ def create_performance_definition(
The name or id of the model(s), or a dictionary representation of the model(s). For
multiple models, input a list of model names, or a list of dictionaries. If no models are specified, all
models in the project specified will be used. Defaults to None.
modelVersions: str, list, optional
The name of the model version(s). Defaults to None, so all models are latest.
library_name : str
The library containing the input data, default is 'Public'.
name : str, optional
Expand Down Expand Up @@ -239,10 +268,13 @@ def create_performance_definition(
"property set." % project.name
)

# Creating the new array of modelIds with version names appended
updated_models = cls.check_model_versions(models, modelVersions)

request = {
"projectId": project.id,
"name": name or project.name + " Performance",
"modelIds": [model.id for model in models],
"modelIds": updated_models,
"championMonitored": monitor_champion,
"challengerMonitored": monitor_challenger,
"maxBins": max_bins,
Expand Down Expand Up @@ -279,7 +311,6 @@ def create_performance_definition(
for v in project.get("variables", [])
if v.get("role") == "output"
]

return cls.post(
"/performanceTasks",
json=request,
Expand All @@ -288,6 +319,57 @@ def create_performance_definition(
},
)

@classmethod
def check_model_versions(cls, models, modelVersions):
"""
Checking if the model version(s) are valid and append to model id accordingly.

Parameters
----------
models: list of str
List of models.
modelVersions : list of str
List of model versions associated with models.

Returns
-------
String list
"""
if not modelVersions:
return [model.id for model in models]

updated_models = []
if not isinstance(modelVersions, list):
modelVersions = [modelVersions]

if len(models) < len(modelVersions):
raise ValueError(
"There are too many versions for the amount of models specified."
)

modelVersions = modelVersions + [""] * (len(models) - len(modelVersions))
for model, modelVersionName in zip(models, modelVersions):

if (
isinstance(modelVersionName, dict)
and "modelVersionName" in modelVersionName
):

modelVersionName = modelVersionName["modelVersionName"]
elif (
isinstance(modelVersionName, dict)
and "modelVersionName" not in modelVersionName
):

raise ValueError("Model version is not recognized.")

if modelVersionName != "":
updated_models.append(model.id + ":" + modelVersionName)
else:
updated_models.append(model.id)

return updated_models

@classmethod
def execute_performance_definition(cls, definition):
"""Launches a job to run a performance definition.
Expand Down
3 changes: 2 additions & 1 deletion src/sasctl/_services/model_publish.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from .model_repository import ModelRepository
from .service import Service
from ..utils.decorators import deprecated


class ModelPublish(Service):
Expand Down Expand Up @@ -90,7 +91,7 @@ def delete_destination(cls, item):

return cls.delete("/destinations/{name}".format(name=item))

@classmethod
@deprecated("Use publish_model in model_management.py instead.", "1.11.5")
def publish_model(cls, model, destination, name=None, code=None, notes=None):
"""Publish a model to an existing publishing destination.

Expand Down
51 changes: 44 additions & 7 deletions src/sasctl/_services/score_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def create_score_definition(
description: str = "",
server_name: str = "cas-shared-default",
library_name: str = "Public",
model_version: str = "latest",
model_version: Union[str, dict] = "latest",
):
"""Creates the score definition service.

Expand All @@ -69,7 +69,7 @@ def create_score_definition(
library_name: str, optional
The library within the CAS server the table exists in. Defaults to "Public".
model_version: str, optional
The user-chosen version of the model with the specified model_id. Defaults to "latest".
The user-chosen version of the model. Deafaults to "latest".

Returns
-------
Expand Down Expand Up @@ -116,7 +116,7 @@ def create_score_definition(
table = cls._cas_management.get_table(table_name, library_name, server_name)
if not table and not table_file:
raise HTTPError(
f"This table may not exist in CAS. Please include the `table_file` argument in the function call if it doesn't exist."
"This table may not exist in CAS. Include the `table_file` argument."
)
elif not table and table_file:
cls._cas_management.upload_file(
Expand All @@ -125,16 +125,19 @@ def create_score_definition(
table = cls._cas_management.get_table(table_name, library_name, server_name)
if not table:
raise HTTPError(
f"The file failed to upload properly or another error occurred."
"The file failed to upload properly or another error occurred."
)
# Checks if the inputted table exists, and if not, uploads a file to create a new table

object_uri, model_version = cls.check_model_version(model_id, model_version)
# Checks if the model version is valid and how to find the name

save_score_def = {
"name": model_name, # used to be score_def_name
"description": description,
"objectDescriptor": {
"uri": f"/modelManagement/models/{model_id}",
"name": f"{model_name}({model_version})",
"uri": object_uri,
"name": f"{model_name} ({model_version})",
"type": f"{object_descriptor_type}",
},
"inputData": {
Expand All @@ -149,7 +152,7 @@ def create_score_definition(
"projectUri": f"/modelRepository/projects/{model_project_id}",
"projectVersionUri": f"/modelRepository/projects/{model_project_id}/projectVersions/{model_project_version_id}",
"publishDestination": "",
"versionedModel": f"{model_name}({model_version})",
"versionedModel": f"{model_name} ({model_version})",
},
"mappings": inputMapping,
}
Expand All @@ -161,3 +164,37 @@ def create_score_definition(
"/definitions", data=json.dumps(save_score_def), headers=headers_score_def
)
# The response information of the score definition can be seen as a JSON as well as a RestOBJ

@classmethod
def check_model_version(cls, model_id: str, model_version: Union[str, dict]):
"""Checks if the model version is valid.

Parameters
----------
model_version : str or dict
The model version to check.

Returns
-------
String tuple
"""
if model_version != "latest":

if isinstance(model_version, dict) and "modelVersionName" in model_version:
model_version = model_version["modelVersionName"]
elif (
isinstance(model_version, dict)
and "modelVersionName" not in model_version
):
raise ValueError("Model version cannot be found.")
elif isinstance(model_version, str) and cls.is_uuid(model_version):
model_version = cls._model_repository.get_model_or_version(
model_id, model_version
)["modelVersionName"]

object_uri = f"/modelManagement/models/{model_id}/versions/@{model_version}"

else:
object_uri = f"/modelManagement/models/{model_id}"

return object_uri, model_version
Loading
Loading