Skip to content

Commit e16be65

Browse files
Added another unit test for more coverage
Signed-off-by: samyarpotlapalli <[email protected]>
1 parent fcaaf00 commit e16be65

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

src/sasctl/_services/score_definitions.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,6 @@ def check_model_version(cls, model_id: str, model_version: Union[str, dict]):
188188
):
189189
raise ValueError("Model version cannot be found.")
190190
elif isinstance(model_version, str) and cls.is_uuid(model_version):
191-
print("hello")
192191
model_version = cls._model_repository.get_model_or_version(
193192
model_id, model_version
194193
)["modelVersionName"]

tests/unit/test_score_definitions.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,30 @@ def test_create_score_definition():
198198
== "test_model (1.0)"
199199
)
200200

201+
# Model version as a dict with modelVersionName key
202+
get_model.return_value = get_model_mock
203+
get_table.return_value = get_table_mock
204+
is_uuid.return_value = False
205+
response = sd.create_score_definition(
206+
score_def_name="test_create_sd",
207+
model="12345",
208+
table_name="test_table",
209+
model_version={"modelVersionName": "1.0"},
210+
)
211+
assert response
212+
assert post.call_count == 5
213+
214+
data = post.call_args
215+
json_data = json.loads(data.kwargs["data"])
216+
assert (
217+
json_data["objectDescriptor"]["name"]
218+
== "test_model (1.0)"
219+
)
220+
assert (
221+
json_data["properties"]["versionedModel"]
222+
== "test_model (1.0)"
223+
)
224+
201225
# Model version as a dictionary with model version name key
202226
get_version_mock = {
203227
"id": "3456",
@@ -214,7 +238,7 @@ def test_create_score_definition():
214238
model_version="3456",
215239
)
216240
assert response
217-
assert post.call_count == 5
241+
assert post.call_count == 6
218242

219243
data = post.call_args
220244
json_data = json.loads(data.kwargs["data"])

0 commit comments

Comments
 (0)