From c01034e23bd6bd9c3905648223b6d784f68e0693 Mon Sep 17 00:00:00 2001 From: samyarpotlapalli Date: Mon, 28 Jul 2025 17:23:24 -0400 Subject: [PATCH 1/8] adding model versioning functionality for score testing, publishing, and performance monitoring models --- src/sasctl/_services/model_management.py | 56 ++++++++++++++++++++++- src/sasctl/_services/model_publish.py | 3 +- src/sasctl/_services/score_definitions.py | 28 +++++++++--- 3 files changed, 78 insertions(+), 9 deletions(-) diff --git a/src/sasctl/_services/model_management.py b/src/sasctl/_services/model_management.py index e91b1aa6..966b5e88 100644 --- a/src/sasctl/_services/model_management.py +++ b/src/sasctl/_services/model_management.py @@ -28,7 +28,13 @@ class ModelManagement(Service): # TODO: set ds2MultiType @classmethod def publish_model( - cls, model, destination, name=None, force=False, reload_model_table=False + cls, + model, + destination, + model_version="latest", + name=None, + force=False, + reload_model_table=False, ): """ @@ -38,6 +44,8 @@ def publish_model( The name or id of the model, or a dictionary representation of the model. destination : str Name of destination to publish the model to. + model_version_id : str or dict, optional + Provide the id, name, or dictionary representation of the version to publish. Defaults to 'latest'. name : str, optional Provide a custom name for the published model. Defaults to None. force : bool, optional @@ -68,6 +76,18 @@ def publish_model( # TODO: Verify allowed formats by destination type. # As of 19w04 MAS throws HTTP 500 if name is in invalid format. + if model_version != "latest": + if isinstance(model_version, dict) and "modelVersionName" in model_version: + model_version_name = model_version["modelVersionName"] + elif isinstance(model_version, str) and cls.is_uuid(model_version): + model_version_name = mr.get_model_or_version(model, model_version)[ + "modelVersionName" + ] + else: + model_version_name = model_version + else: + model_version_name = "" + model_name = name or "{}_{}".format( model_obj["name"].replace(" ", ""), model_obj["id"] ).replace("-", "") @@ -79,6 +99,7 @@ def publish_model( { "modelName": mp._publish_name(model_name), "sourceUri": model_uri.get("uri"), + "modelVersionID": model_version_name, "publishLevel": "model", } ], @@ -104,6 +125,7 @@ def create_performance_definition( table_prefix, project=None, models=None, + modelVersions=None, library_name="Public", name=None, description=None, @@ -136,6 +158,9 @@ def create_performance_definition( The name or id of the model(s), or a dictionary representation of the model(s). For multiple models, input a list of model names, or a list of dictionaries. If no models are specified, all models in the project specified will be used. Defaults to None. + modelVersions: str, list, optional + The name of the model version(s) for models used in the performance definition. If no model versions + are specified, all models will use the latest version. Defaults to None. library_name : str The library containing the input data, default is 'Public'. name : str, optional @@ -239,10 +264,37 @@ def create_performance_definition( "property set." % project.name ) + if not modelVersions: + updated_models = [model.id for model in models] + else: + updated_models = [] + if not isinstance(modelVersions, list): + modelVersions = [modelVersions] + + if len(models) < len(modelVersions): + raise ValueError( + "There are too many versions for the amount of models specified." + ) + + modelVersions = modelVersions + [""] * (len(models) - len(modelVersions)) + for model, modelVersionName in zip(models, modelVersions): + print(model.name) + if ( + isinstance(modelVersionName, dict) + and "modelVersionName" in modelVersionName + ): + modelVersionName = modelVersionName["modelVersionName"] + elif ( + isinstance(modelVersionName, dict) + and "modelVersionName" not in modelVersionName + ): + raise ValueError("Model version is not recognized.") + updated_models.append(model.id + ":" + modelVersionName) + request = { "projectId": project.id, "name": name or project.name + " Performance", - "modelIds": [model.id for model in models], + "modelIds": [model for model in updated_models], "championMonitored": monitor_champion, "challengerMonitored": monitor_challenger, "maxBins": max_bins, diff --git a/src/sasctl/_services/model_publish.py b/src/sasctl/_services/model_publish.py index c3fa225f..90f665ad 100644 --- a/src/sasctl/_services/model_publish.py +++ b/src/sasctl/_services/model_publish.py @@ -10,6 +10,7 @@ from .model_repository import ModelRepository from .service import Service +from ..utils.decorators import deprecated class ModelPublish(Service): @@ -90,7 +91,7 @@ def delete_destination(cls, item): return cls.delete("/destinations/{name}".format(name=item)) - @classmethod + @deprecated("Use publish_model in model_management.py instead.", "1.11.5") def publish_model(cls, model, destination, name=None, code=None, notes=None): """Publish a model to an existing publishing destination. diff --git a/src/sasctl/_services/score_definitions.py b/src/sasctl/_services/score_definitions.py index 2c05611f..ac0c7a5d 100644 --- a/src/sasctl/_services/score_definitions.py +++ b/src/sasctl/_services/score_definitions.py @@ -69,7 +69,7 @@ def create_score_definition( library_name: str, optional The library within the CAS server the table exists in. Defaults to "Public". model_version: str, optional - The user-chosen version of the model with the specified model_id. Defaults to "latest". + The user-chosen version of the model with the specified model version name. Defaults to latest version. Returns ------- @@ -116,7 +116,7 @@ def create_score_definition( table = cls._cas_management.get_table(table_name, library_name, server_name) if not table and not table_file: raise HTTPError( - f"This table may not exist in CAS. Please include the `table_file` argument in the function call if it doesn't exist." + "This table may not exist in CAS. Please include the `table_file` argument in the function call if it doesn't exist." ) elif not table and table_file: cls._cas_management.upload_file( @@ -125,16 +125,32 @@ def create_score_definition( table = cls._cas_management.get_table(table_name, library_name, server_name) if not table: raise HTTPError( - f"The file failed to upload properly or another error occurred." + "The file failed to upload properly or another error occurred." ) # Checks if the inputted table exists, and if not, uploads a file to create a new table + if model_version != "latest": + + if isinstance(model_version, dict) and "modelVersionName" in model_version: + model_version = model_version["modelVersionName"] + elif isinstance(model_version, str) and cls.is_uuid(model_version): + model_version = cls._model_repository.get_model_or_version( + model_id, model_version + )["modelVersionName"] + else: + model_version = model_version + + object_uri = f"/modelManagement/models/{model_id}/versions/@{model_version}" + + else: + object_uri = f"/modelManagement/models/{model_id}" + save_score_def = { "name": model_name, # used to be score_def_name "description": description, "objectDescriptor": { - "uri": f"/modelManagement/models/{model_id}", - "name": f"{model_name}({model_version})", + "uri": object_uri, + "name": f"{model_name} ({model_version})", "type": f"{object_descriptor_type}", }, "inputData": { @@ -149,7 +165,7 @@ def create_score_definition( "projectUri": f"/modelRepository/projects/{model_project_id}", "projectVersionUri": f"/modelRepository/projects/{model_project_id}/projectVersions/{model_project_version_id}", "publishDestination": "", - "versionedModel": f"{model_name}({model_version})", + "versionedModel": f"{model_name} ({model_version})", }, "mappings": inputMapping, } From b6078ce3560dcbe18cff09737491861f82d37f78 Mon Sep 17 00:00:00 2001 From: samyarpotlapalli Date: Tue, 29 Jul 2025 09:25:54 -0400 Subject: [PATCH 2/8] removed debug print statements after testing versioning logic --- src/sasctl/_services/model_management.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/sasctl/_services/model_management.py b/src/sasctl/_services/model_management.py index 966b5e88..b73c66a9 100644 --- a/src/sasctl/_services/model_management.py +++ b/src/sasctl/_services/model_management.py @@ -278,7 +278,6 @@ def create_performance_definition( modelVersions = modelVersions + [""] * (len(models) - len(modelVersions)) for model, modelVersionName in zip(models, modelVersions): - print(model.name) if ( isinstance(modelVersionName, dict) and "modelVersionName" in modelVersionName From d32a9c4a26513d4a36d9ce45fa2d511ba5cc6f73 Mon Sep 17 00:00:00 2001 From: samyarpotlapalli Date: Mon, 4 Aug 2025 18:11:16 -0400 Subject: [PATCH 3/8] Added unit tests for model version functionality --- src/sasctl/_services/model_management.py | 19 +- src/sasctl/_services/score_definitions.py | 8 + tests/unit/test_model_management.py | 150 ++++++++++--- tests/unit/test_score_definitions.py | 249 ++++++++++++++-------- 4 files changed, 308 insertions(+), 118 deletions(-) diff --git a/src/sasctl/_services/model_management.py b/src/sasctl/_services/model_management.py index b73c66a9..e78d2549 100644 --- a/src/sasctl/_services/model_management.py +++ b/src/sasctl/_services/model_management.py @@ -79,6 +79,11 @@ def publish_model( if model_version != "latest": if isinstance(model_version, dict) and "modelVersionName" in model_version: model_version_name = model_version["modelVersionName"] + elif ( + isinstance(model_version, dict) + and "modelVersionName" not in model_version + ): + raise ValueError("Model version is not recognized.") elif isinstance(model_version, str) and cls.is_uuid(model_version): model_version_name = mr.get_model_or_version(model, model_version)[ "modelVersionName" @@ -263,7 +268,7 @@ def create_performance_definition( "Project %s must have the 'predictionVariable' " "property set." % project.name ) - + print("sup") if not modelVersions: updated_models = [model.id for model in models] else: @@ -278,22 +283,29 @@ def create_performance_definition( modelVersions = modelVersions + [""] * (len(models) - len(modelVersions)) for model, modelVersionName in zip(models, modelVersions): + if ( isinstance(modelVersionName, dict) and "modelVersionName" in modelVersionName ): + modelVersionName = modelVersionName["modelVersionName"] elif ( isinstance(modelVersionName, dict) and "modelVersionName" not in modelVersionName ): + raise ValueError("Model version is not recognized.") - updated_models.append(model.id + ":" + modelVersionName) + + if modelVersionName != "": + updated_models.append(model.id + ":" + modelVersionName) + else: + updated_models.append(model.id) request = { "projectId": project.id, "name": name or project.name + " Performance", - "modelIds": [model for model in updated_models], + "modelIds": updated_models, "championMonitored": monitor_champion, "challengerMonitored": monitor_challenger, "maxBins": max_bins, @@ -330,7 +342,6 @@ def create_performance_definition( for v in project.get("variables", []) if v.get("role") == "output" ] - return cls.post( "/performanceTasks", json=request, diff --git a/src/sasctl/_services/score_definitions.py b/src/sasctl/_services/score_definitions.py index ac0c7a5d..f37cfb6b 100644 --- a/src/sasctl/_services/score_definitions.py +++ b/src/sasctl/_services/score_definitions.py @@ -133,7 +133,15 @@ def create_score_definition( if isinstance(model_version, dict) and "modelVersionName" in model_version: model_version = model_version["modelVersionName"] + elif ( + isinstance(model_version, dict) + and "modelVersionName" not in model_version + ): + raise ValueError( + "Model version cannot be found. Please check the inputted model version." + ) elif isinstance(model_version, str) and cls.is_uuid(model_version): + print("hello") model_version = cls._model_repository.get_model_or_version( model_id, model_version )["modelVersionName"] diff --git a/tests/unit/test_model_management.py b/tests/unit/test_model_management.py index fbd4fc36..834b0ecc 100644 --- a/tests/unit/test_model_management.py +++ b/tests/unit/test_model_management.py @@ -23,6 +23,8 @@ def test_create_performance_definition(): RestObj({"name": "Test Model 2", "id": "67890", "projectId": PROJECT["id"]}), ] USER = "username" + VERSION_MOCK = {"modelVersionName": "1.0"} + VERSION_MOCK_NONAME = {} with mock.patch("sasctl.core.Session._get_authorization_token"): current_session("example.com", USER, "password") @@ -111,6 +113,32 @@ def test_create_performance_definition(): table_prefix="TestData", ) + with pytest.raises(ValueError): + # Model verions exceeds models + get_model.side_effect = copy.deepcopy(MODELS) + _ = mm.create_performance_definition( + models=["model1", "model2"], + modelVersions=["1.0", "2.0", "3.0"], + library_name="TestLibrary", + table_prefix="TestData", + max_bins=3, + monitor_challenger=True, + monitor_champion=True, + ) + + with pytest.raises(ValueError): + # Model version dictionary missing modelVersionName + get_model.side_effect = copy.deepcopy(MODELS) + _ = mm.create_performance_definition( + models=["model1", "model2"], + modelVersions=VERSION_MOCK_NONAME, + library_name="TestLibrary", + table_prefix="TestData", + max_bins=3, + monitor_challenger=True, + monitor_champion=True, + ) + get_project.return_value = copy.deepcopy(PROJECT) get_project.return_value["targetVariable"] = "target" get_project.return_value["targetLevel"] = "interval" @@ -125,21 +153,68 @@ def test_create_performance_definition(): monitor_challenger=True, monitor_champion=True, ) + url, data = post_models.call_args + assert post_models.call_count == 1 + assert PROJECT["id"] == data["json"]["projectId"] + assert MODELS[0]["id"] in data["json"]["modelIds"] + assert MODELS[1]["id"] in data["json"]["modelIds"] + assert "TestLibrary" == data["json"]["dataLibrary"] + assert "TestData" == data["json"]["dataPrefix"] + assert "cas-shared-default" == data["json"]["casServerId"] + assert data["json"]["name"] + assert data["json"]["description"] + assert data["json"]["maxBins"] == 3 + assert data["json"]["championMonitored"] is True + assert data["json"]["challengerMonitored"] is True - assert post_models.call_count == 1 - url, data = post_models.call_args - - assert PROJECT["id"] == data["json"]["projectId"] - assert MODELS[0]["id"] in data["json"]["modelIds"] - assert MODELS[1]["id"] in data["json"]["modelIds"] - assert "TestLibrary" == data["json"]["dataLibrary"] - assert "TestData" == data["json"]["dataPrefix"] - assert "cas-shared-default" == data["json"]["casServerId"] - assert data["json"]["name"] - assert data["json"]["description"] - assert data["json"]["maxBins"] == 3 - assert data["json"]["championMonitored"] is True - assert data["json"]["challengerMonitored"] is True + get_model.side_effect = copy.deepcopy(MODELS) + _ = mm.create_performance_definition( + # One model version as a string name + models=["model1", "model2"], + modelVersions="1.0", + library_name="TestLibrary", + table_prefix="TestData", + max_bins=3, + monitor_challenger=True, + monitor_champion=True, + ) + + assert post_models.call_count == 2 + url, data = post_models.call_args + assert f"{MODELS[0]['id']}:1.0" in data["json"]["modelIds"] + assert MODELS[1]["id"] in data["json"]["modelIds"] + + get_model.side_effect = copy.deepcopy(MODELS) + # List of string type model versions + _ = mm.create_performance_definition( + models=["model1", "model2"], + modelVersions=["1.0", "2.0"], + library_name="TestLibrary", + table_prefix="TestData", + max_bins=3, + monitor_challenger=True, + monitor_champion=True, + ) + assert post_models.call_count == 3 + url, data = post_models.call_args + assert f"{MODELS[0]['id']}:1.0" in data["json"]["modelIds"] + assert f"{MODELS[1]['id']}:2.0" in data["json"]["modelIds"] + + get_model.side_effect = copy.deepcopy(MODELS) + # List of dictionary type and string type model versions + _ = mm.create_performance_definition( + models=["model1", "model2"], + modelVersions=[VERSION_MOCK, "2.0"], + library_name="TestLibrary", + table_prefix="TestData", + max_bins=3, + monitor_challenger=True, + monitor_champion=True, + ) + assert post_models.call_count == 4 + url, data = post_models.call_args + assert f"{MODELS[0]['id']}:1.0" in data["json"]["modelIds"] + assert f"{MODELS[1]['id']}:2.0" in data["json"]["modelIds"] with mock.patch( "sasctl._services.model_management.ModelManagement" ".post" @@ -160,20 +235,39 @@ def test_create_performance_definition(): monitor_champion=True, ) - assert post_project.call_count == 1 - url, data = post_project.call_args - - assert PROJECT["id"] == data["json"]["projectId"] - assert MODELS[0]["id"] in data["json"]["modelIds"] - assert MODELS[1]["id"] in data["json"]["modelIds"] - assert "TestLibrary" == data["json"]["dataLibrary"] - assert "TestData" == data["json"]["dataPrefix"] - assert "cas-shared-default" == data["json"]["casServerId"] - assert data["json"]["name"] - assert data["json"]["description"] - assert data["json"]["maxBins"] == 3 - assert data["json"]["championMonitored"] is True - assert data["json"]["challengerMonitored"] is True + # one extra test for project with version id + + assert post_project.call_count == 1 + url, data = post_project.call_args + + assert PROJECT["id"] == data["json"]["projectId"] + assert MODELS[0]["id"] in data["json"]["modelIds"] + assert MODELS[1]["id"] in data["json"]["modelIds"] + assert "TestLibrary" == data["json"]["dataLibrary"] + assert "TestData" == data["json"]["dataPrefix"] + assert "cas-shared-default" == data["json"]["casServerId"] + assert data["json"]["name"] + assert data["json"]["description"] + assert data["json"]["maxBins"] == 3 + assert data["json"]["championMonitored"] is True + assert data["json"]["challengerMonitored"] is True + + get_model.side_effect = copy.deepcopy(MODELS) + # Project with model version + _ = mm.create_performance_definition( + project="project", + modelVersions="2.0", + library_name="TestLibrary", + table_prefix="TestData", + max_bins=3, + monitor_challenger=True, + monitor_champion=True, + ) + + assert post_project.call_count == 2 + url, data = post_project.call_args + assert f"{MODELS[0]['id']}:2.0" in data["json"]["modelIds"] + assert MODELS[1]["id"] in data["json"]["modelIds"] def test_table_prefix_format(): with pytest.raises(ValueError): diff --git a/tests/unit/test_score_definitions.py b/tests/unit/test_score_definitions.py index d1210866..075b316c 100644 --- a/tests/unit/test_score_definitions.py +++ b/tests/unit/test_score_definitions.py @@ -63,89 +63,166 @@ def test_create_score_definition(): "sasctl._services.cas_management.CASManagement.upload_file" ) as upload_file: with mock.patch( - "sasctl._services.score_definitions.ScoreDefinitions.post" - ) as post: - # Invalid model id test case - get_model.return_value = None - with pytest.raises(HTTPError): - sd.create_score_definition( - score_def_name="test_create_sd", - model="12345", - table_name="test_table", - ) - # Valid model id but invalid table name with no table_file argument test case - get_model_mock = { - "id": "12345", - "projectId": "54321", - "projectVersionId": "67890", - "name": "test_model", - } - get_model.return_value = get_model_mock - get_table.return_value = None - with pytest.raises(HTTPError): - sd.create_score_definition( - score_def_name="test_create_sd", - model="12345", - table_name="test_table", - ) - - # Invalid table name with a table_file argument that doesn't work test case - get_table.return_value = None - upload_file.return_value = None - get_table.return_value = None - with pytest.raises(HTTPError): - sd.create_score_definition( - score_def_name="test_create_sd", - model="12345", - table_name="test_table", - table_file="test_path", - ) - - # Valid table_file argument that successfully creates a table test case - get_table.return_value = None - upload_file.return_value = RestObj - get_table_mock = {"tableName": "test_table"} - get_table.return_value = get_table_mock - response = sd.create_score_definition( - score_def_name="test_create_sd", - model="12345", - table_name="test_table", - table_file="test_path", - ) - assert response - - # Valid table_name argument test case - get_table.return_value = get_table_mock - response = sd.create_score_definition( - score_def_name="test_create_sd", - model="12345", - table_name="test_table", - table_file="test_path", - ) - assert response - - # Checking response with inputVariables in model elements - get_model_mock = { - "id": "12345", - "projectId": "54321", - "projectVersionId": "67890", - "name": "test_model", - "inputVariables": [ - {"name": "first"}, - {"name": "second"}, - {"name": "third"}, - ], - } - get_model.return_value = get_model_mock - get_table.return_value = get_table_mock - response = sd.create_score_definition( - score_def_name="test_create_sd", - model="12345", - table_name="test_table", - ) - assert response - assert post.call_count == 3 - - data = post.call_args - json_data = json.loads(data.kwargs["data"]) - assert json_data["mappings"] != [] + "sasctl._services.model_repository.ModelRepository.get_model_or_version" + ) as get_model_or_version: + with mock.patch( + "sasctl._services.score_definitions.ScoreDefinitions.is_uuid" + ) as is_uuid: + with mock.patch( + "sasctl._services.score_definitions.ScoreDefinitions.post" + ) as post: + + # Invalid model id test case + get_model.return_value = None + with pytest.raises(HTTPError): + sd.create_score_definition( + score_def_name="test_create_sd", + model="12345", + table_name="test_table", + ) + # Valid model id but invalid table name with no table_file argument test case + get_model_mock = { + "id": "12345", + "projectId": "54321", + "projectVersionId": "67890", + "name": "test_model", + } + get_model.return_value = get_model_mock + get_table.return_value = None + with pytest.raises(HTTPError): + sd.create_score_definition( + score_def_name="test_create_sd", + model="12345", + table_name="test_table", + ) + + # Invalid table name with a table_file argument that doesn't work test case + get_table.return_value = None + upload_file.return_value = None + get_table.return_value = None + with pytest.raises(HTTPError): + sd.create_score_definition( + score_def_name="test_create_sd", + model="12345", + table_name="test_table", + table_file="test_path", + ) + + # Valid table_file argument that successfully creates a table test case + get_table.return_value = None + upload_file.return_value = RestObj + get_table_mock = {"tableName": "test_table"} + get_table.return_value = get_table_mock + response = sd.create_score_definition( + score_def_name="test_create_sd", + model="12345", + table_name="test_table", + table_file="test_path", + ) + assert response + + # Valid table_name argument test case + get_table.return_value = get_table_mock + response = sd.create_score_definition( + score_def_name="test_create_sd", + model="12345", + table_name="test_table", + table_file="test_path", + ) + assert response + + # Checking response with inputVariables in model elements + get_model_mock = { + "id": "12345", + "projectId": "54321", + "projectVersionId": "67890", + "name": "test_model", + "inputVariables": [ + {"name": "first"}, + {"name": "second"}, + {"name": "third"}, + ], + } + get_model.return_value = get_model_mock + get_table.return_value = get_table_mock + response = sd.create_score_definition( + score_def_name="test_create_sd", + model="12345", + table_name="test_table", + ) + assert response + assert post.call_count == 3 + + data = post.call_args + json_data = json.loads(data.kwargs["data"]) + assert json_data["mappings"] != [] + assert ( + json_data["objectDescriptor"]["name"] + == "test_model (latest)" + ) + assert ( + json_data["properties"]["versionedModel"] + == "test_model (latest)" + ) + + # Model version dictionary with no model version name + with pytest.raises(ValueError): + response = sd.create_score_definition( + score_def_name="test_create_sd", + model="12345", + table_name="test_table", + model_version={}, + ) + + # Model version as a model version name string, not UUID + get_model.return_value = get_model_mock + get_table.return_value = get_table_mock + is_uuid.return_value = False + response = sd.create_score_definition( + score_def_name="test_create_sd", + model="12345", + table_name="test_table", + model_version="1.0", + ) + assert response + assert post.call_count == 4 + + data = post.call_args + json_data = json.loads(data.kwargs["data"]) + assert ( + json_data["objectDescriptor"]["name"] + == "test_model (1.0)" + ) + assert ( + json_data["properties"]["versionedModel"] + == "test_model (1.0)" + ) + + # Model version as a dictionary with model version name key + get_version_mock = { + "id": "3456", + "modelVersionName": "1.0", + } + get_model.return_value = get_model_mock + get_table.return_value = get_table_mock + is_uuid.return_value = True + get_model_or_version.return_value = get_version_mock + response = sd.create_score_definition( + score_def_name="test_create_sd", + model="12345", + table_name="test_table", + model_version="3456", + ) + assert response + assert post.call_count == 5 + + data = post.call_args + json_data = json.loads(data.kwargs["data"]) + assert ( + json_data["objectDescriptor"]["name"] + == "test_model (1.0)" + ) + assert ( + json_data["properties"]["versionedModel"] + == "test_model (1.0)" + ) From 204789f30ec249da4c320177d0da62ffef27e6ce Mon Sep 17 00:00:00 2001 From: samyarpotlapalli Date: Tue, 5 Aug 2025 10:28:34 -0400 Subject: [PATCH 4/8] Refactored code for more clarity Signed-off-by: samyarpotlapalli --- src/sasctl/_services/score_definitions.py | 66 ++++++++++++++--------- 1 file changed, 40 insertions(+), 26 deletions(-) diff --git a/src/sasctl/_services/score_definitions.py b/src/sasctl/_services/score_definitions.py index f37cfb6b..b6b396d2 100644 --- a/src/sasctl/_services/score_definitions.py +++ b/src/sasctl/_services/score_definitions.py @@ -46,7 +46,7 @@ def create_score_definition( description: str = "", server_name: str = "cas-shared-default", library_name: str = "Public", - model_version: str = "latest", + model_version: Union[str, dict] = "latest", ): """Creates the score definition service. @@ -69,7 +69,7 @@ def create_score_definition( library_name: str, optional The library within the CAS server the table exists in. Defaults to "Public". model_version: str, optional - The user-chosen version of the model with the specified model version name. Defaults to latest version. + The user-chosen version of the model. Deafaults to "latest". Returns ------- @@ -116,7 +116,7 @@ def create_score_definition( table = cls._cas_management.get_table(table_name, library_name, server_name) if not table and not table_file: raise HTTPError( - "This table may not exist in CAS. Please include the `table_file` argument in the function call if it doesn't exist." + "This table may not exist in CAS. Please include the `table_file` argument." ) elif not table and table_file: cls._cas_management.upload_file( @@ -129,29 +129,8 @@ def create_score_definition( ) # Checks if the inputted table exists, and if not, uploads a file to create a new table - if model_version != "latest": - - if isinstance(model_version, dict) and "modelVersionName" in model_version: - model_version = model_version["modelVersionName"] - elif ( - isinstance(model_version, dict) - and "modelVersionName" not in model_version - ): - raise ValueError( - "Model version cannot be found. Please check the inputted model version." - ) - elif isinstance(model_version, str) and cls.is_uuid(model_version): - print("hello") - model_version = cls._model_repository.get_model_or_version( - model_id, model_version - )["modelVersionName"] - else: - model_version = model_version - - object_uri = f"/modelManagement/models/{model_id}/versions/@{model_version}" - - else: - object_uri = f"/modelManagement/models/{model_id}" + object_uri, model_version = cls.check_model_version(model_id, model_version) + # Checks if the model version is valid and how to find the name save_score_def = { "name": model_name, # used to be score_def_name @@ -185,3 +164,38 @@ def create_score_definition( "/definitions", data=json.dumps(save_score_def), headers=headers_score_def ) # The response information of the score definition can be seen as a JSON as well as a RestOBJ + + @classmethod + def check_model_version(cls, model_id: str, model_version: Union[str, dict]): + """Checks if the model version is valid. + + Parameters + ---------- + model_version : str or dict + The model version to check. + + Returns + ------- + String tuple + """ + if model_version != "latest": + + if isinstance(model_version, dict) and "modelVersionName" in model_version: + model_version = model_version["modelVersionName"] + elif ( + isinstance(model_version, dict) + and "modelVersionName" not in model_version + ): + raise ValueError("Model version cannot be found.") + elif isinstance(model_version, str) and cls.is_uuid(model_version): + print("hello") + model_version = cls._model_repository.get_model_or_version( + model_id, model_version + )["modelVersionName"] + + object_uri = f"/modelManagement/models/{model_id}/versions/@{model_version}" + + else: + object_uri = f"/modelManagement/models/{model_id}" + + return object_uri, model_version From 47ad7b375dcdcea8b9000d458e4050f7ab831ab7 Mon Sep 17 00:00:00 2001 From: samyarpotlapalli Date: Tue, 5 Aug 2025 11:43:40 -0400 Subject: [PATCH 5/8] Refactored code and DCO Remediation Commit for samyarpotlapalli I, samyarpotlapalli , hereby add my Signed-off-by to this commit: c01034e23bd6bd9c3905648223b6d784f68e0693I, samyarpotlapalli , hereby add my Signed-off-by to this commit: b6078ce3560dcbe18cff09737491861f82d37f78I, samyarpotlapalli , hereby add my Signed-off-by to this commit: d32a9c4a26513d4a36d9ce45fa2d511ba5cc6f73Signed-off-by: samyarpotlapalli Signed-off-by: samyarpotlapalli --- src/sasctl/_services/model_management.py | 88 ++++++++++++++--------- src/sasctl/_services/score_definitions.py | 2 +- 2 files changed, 55 insertions(+), 35 deletions(-) diff --git a/src/sasctl/_services/model_management.py b/src/sasctl/_services/model_management.py index e78d2549..854c755b 100644 --- a/src/sasctl/_services/model_management.py +++ b/src/sasctl/_services/model_management.py @@ -164,8 +164,7 @@ def create_performance_definition( multiple models, input a list of model names, or a list of dictionaries. If no models are specified, all models in the project specified will be used. Defaults to None. modelVersions: str, list, optional - The name of the model version(s) for models used in the performance definition. If no model versions - are specified, all models will use the latest version. Defaults to None. + The name of the model version(s). Defaults to None, where all models are latest. library_name : str The library containing the input data, default is 'Public'. name : str, optional @@ -268,39 +267,9 @@ def create_performance_definition( "Project %s must have the 'predictionVariable' " "property set." % project.name ) - print("sup") - if not modelVersions: - updated_models = [model.id for model in models] - else: - updated_models = [] - if not isinstance(modelVersions, list): - modelVersions = [modelVersions] - - if len(models) < len(modelVersions): - raise ValueError( - "There are too many versions for the amount of models specified." - ) - - modelVersions = modelVersions + [""] * (len(models) - len(modelVersions)) - for model, modelVersionName in zip(models, modelVersions): - - if ( - isinstance(modelVersionName, dict) - and "modelVersionName" in modelVersionName - ): - modelVersionName = modelVersionName["modelVersionName"] - elif ( - isinstance(modelVersionName, dict) - and "modelVersionName" not in modelVersionName - ): - - raise ValueError("Model version is not recognized.") - - if modelVersionName != "": - updated_models.append(model.id + ":" + modelVersionName) - else: - updated_models.append(model.id) + # Creating the new array of modelIds with version names appended + updated_models = cls.check_model_versions(models, modelVersions) request = { "projectId": project.id, @@ -350,6 +319,57 @@ def create_performance_definition( }, ) + @classmethod + def check_model_versions(cls, models, modelVersions): + """ + Checking if the model version(s) are valid. Appending them to the model_id accordingly. + + Parameters + ---------- + models: list of str + List of models. + modelVersions : list of str + List of model versions associated with models. + + Returns + ------- + String list + """ + if not modelVersions: + return [model.id for model in models] + else: + updated_models = [] + if not isinstance(modelVersions, list): + modelVersions = [modelVersions] + + if len(models) < len(modelVersions): + raise ValueError( + "There are too many versions for the amount of models specified." + ) + + modelVersions = modelVersions + [""] * (len(models) - len(modelVersions)) + for model, modelVersionName in zip(models, modelVersions): + + if ( + isinstance(modelVersionName, dict) + and "modelVersionName" in modelVersionName + ): + + modelVersionName = modelVersionName["modelVersionName"] + elif ( + isinstance(modelVersionName, dict) + and "modelVersionName" not in modelVersionName + ): + + raise ValueError("Model version is not recognized.") + + if modelVersionName != "": + updated_models.append(model.id + ":" + modelVersionName) + else: + updated_models.append(model.id) + + return updated_models + @classmethod def execute_performance_definition(cls, definition): """Launches a job to run a performance definition. diff --git a/src/sasctl/_services/score_definitions.py b/src/sasctl/_services/score_definitions.py index b6b396d2..69808650 100644 --- a/src/sasctl/_services/score_definitions.py +++ b/src/sasctl/_services/score_definitions.py @@ -116,7 +116,7 @@ def create_score_definition( table = cls._cas_management.get_table(table_name, library_name, server_name) if not table and not table_file: raise HTTPError( - "This table may not exist in CAS. Please include the `table_file` argument." + "This table may not exist in CAS. Include the `table_file` argument." ) elif not table and table_file: cls._cas_management.upload_file( From fcaaf003debbb97d789b4eefb76db591980ed2d9 Mon Sep 17 00:00:00 2001 From: samyarpotlapalli Date: Tue, 5 Aug 2025 12:41:45 -0400 Subject: [PATCH 6/8] DCO Remediation Commit for samyarpotlapalli I, samyarpotlapalli , hereby add my Signed-off-by to this commit: c01034e23bd6bd9c3905648223b6d784f68e0693I, samyarpotlapalli , hereby add my Signed-off-by to this commit: b6078ce3560dcbe18cff09737491861f82d37f78I, samyarpotlapalli , hereby add my Signed-off-by to this commit: d32a9c4a26513d4a36d9ce45fa2d511ba5cc6f73Signed-off-by: samyarpotlapalli --- src/sasctl/_services/model_management.py | 58 ++++++++++++------------ 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/src/sasctl/_services/model_management.py b/src/sasctl/_services/model_management.py index 854c755b..3f13ed41 100644 --- a/src/sasctl/_services/model_management.py +++ b/src/sasctl/_services/model_management.py @@ -44,8 +44,8 @@ def publish_model( The name or id of the model, or a dictionary representation of the model. destination : str Name of destination to publish the model to. - model_version_id : str or dict, optional - Provide the id, name, or dictionary representation of the version to publish. Defaults to 'latest'. + model_version : str or dict, optional + Provide the version id, name, or dict to publish. Defaults to 'latest'. name : str, optional Provide a custom name for the published model. Defaults to None. force : bool, optional @@ -164,7 +164,7 @@ def create_performance_definition( multiple models, input a list of model names, or a list of dictionaries. If no models are specified, all models in the project specified will be used. Defaults to None. modelVersions: str, list, optional - The name of the model version(s). Defaults to None, where all models are latest. + The name of the model version(s). Defaults to None, i.e. all models are latest. library_name : str The library containing the input data, default is 'Public'. name : str, optional @@ -322,7 +322,7 @@ def create_performance_definition( @classmethod def check_model_versions(cls, models, modelVersions): """ - Checking if the model version(s) are valid. Appending them to the model_id accordingly. + Checking if the model version(s) are valid and append to model id accordingly. Parameters ---------- @@ -337,38 +337,38 @@ def check_model_versions(cls, models, modelVersions): """ if not modelVersions: return [model.id for model in models] - else: - updated_models = [] - if not isinstance(modelVersions, list): - modelVersions = [modelVersions] - if len(models) < len(modelVersions): - raise ValueError( - "There are too many versions for the amount of models specified." - ) + updated_models = [] + if not isinstance(modelVersions, list): + modelVersions = [modelVersions] - modelVersions = modelVersions + [""] * (len(models) - len(modelVersions)) - for model, modelVersionName in zip(models, modelVersions): + if len(models) < len(modelVersions): + raise ValueError( + "There are too many versions for the amount of models specified." + ) - if ( - isinstance(modelVersionName, dict) - and "modelVersionName" in modelVersionName - ): + modelVersions = modelVersions + [""] * (len(models) - len(modelVersions)) + for model, modelVersionName in zip(models, modelVersions): - modelVersionName = modelVersionName["modelVersionName"] - elif ( - isinstance(modelVersionName, dict) - and "modelVersionName" not in modelVersionName - ): + if ( + isinstance(modelVersionName, dict) + and "modelVersionName" in modelVersionName + ): - raise ValueError("Model version is not recognized.") + modelVersionName = modelVersionName["modelVersionName"] + elif ( + isinstance(modelVersionName, dict) + and "modelVersionName" not in modelVersionName + ): - if modelVersionName != "": - updated_models.append(model.id + ":" + modelVersionName) - else: - updated_models.append(model.id) + raise ValueError("Model version is not recognized.") + + if modelVersionName != "": + updated_models.append(model.id + ":" + modelVersionName) + else: + updated_models.append(model.id) - return updated_models + return updated_models @classmethod def execute_performance_definition(cls, definition): From e16be6517658cb4884dc91fc57f019370c5bfbbb Mon Sep 17 00:00:00 2001 From: samyarpotlapalli Date: Tue, 5 Aug 2025 14:07:50 -0400 Subject: [PATCH 7/8] Added another unit test for more coverage Signed-off-by: samyarpotlapalli --- src/sasctl/_services/score_definitions.py | 1 - tests/unit/test_score_definitions.py | 26 ++++++++++++++++++++++- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/src/sasctl/_services/score_definitions.py b/src/sasctl/_services/score_definitions.py index 69808650..05733d2b 100644 --- a/src/sasctl/_services/score_definitions.py +++ b/src/sasctl/_services/score_definitions.py @@ -188,7 +188,6 @@ def check_model_version(cls, model_id: str, model_version: Union[str, dict]): ): raise ValueError("Model version cannot be found.") elif isinstance(model_version, str) and cls.is_uuid(model_version): - print("hello") model_version = cls._model_repository.get_model_or_version( model_id, model_version )["modelVersionName"] diff --git a/tests/unit/test_score_definitions.py b/tests/unit/test_score_definitions.py index 075b316c..1ebdc462 100644 --- a/tests/unit/test_score_definitions.py +++ b/tests/unit/test_score_definitions.py @@ -198,6 +198,30 @@ def test_create_score_definition(): == "test_model (1.0)" ) + # Model version as a dict with modelVersionName key + get_model.return_value = get_model_mock + get_table.return_value = get_table_mock + is_uuid.return_value = False + response = sd.create_score_definition( + score_def_name="test_create_sd", + model="12345", + table_name="test_table", + model_version={"modelVersionName": "1.0"}, + ) + assert response + assert post.call_count == 5 + + data = post.call_args + json_data = json.loads(data.kwargs["data"]) + assert ( + json_data["objectDescriptor"]["name"] + == "test_model (1.0)" + ) + assert ( + json_data["properties"]["versionedModel"] + == "test_model (1.0)" + ) + # Model version as a dictionary with model version name key get_version_mock = { "id": "3456", @@ -214,7 +238,7 @@ def test_create_score_definition(): model_version="3456", ) assert response - assert post.call_count == 5 + assert post.call_count == 6 data = post.call_args json_data = json.loads(data.kwargs["data"]) From 30f7ed6031e42c3377c5c771cf9ee79b66a79e08 Mon Sep 17 00:00:00 2001 From: samyarpotlapalli Date: Tue, 5 Aug 2025 14:10:30 -0400 Subject: [PATCH 8/8] Additional refactoring Signed-off-by: samyarpotlapalli --- src/sasctl/_services/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sasctl/_services/model_management.py b/src/sasctl/_services/model_management.py index 3f13ed41..7950d95e 100644 --- a/src/sasctl/_services/model_management.py +++ b/src/sasctl/_services/model_management.py @@ -164,7 +164,7 @@ def create_performance_definition( multiple models, input a list of model names, or a list of dictionaries. If no models are specified, all models in the project specified will be used. Defaults to None. modelVersions: str, list, optional - The name of the model version(s). Defaults to None, i.e. all models are latest. + The name of the model version(s). Defaults to None, so all models are latest. library_name : str The library containing the input data, default is 'Public'. name : str, optional