Skip to content

Commit d32a9c4

Browse files
Added unit tests for model version functionality
1 parent b6078ce commit d32a9c4

File tree

4 files changed

+308
-118
lines changed

4 files changed

+308
-118
lines changed

src/sasctl/_services/model_management.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,11 @@ def publish_model(
7979
if model_version != "latest":
8080
if isinstance(model_version, dict) and "modelVersionName" in model_version:
8181
model_version_name = model_version["modelVersionName"]
82+
elif (
83+
isinstance(model_version, dict)
84+
and "modelVersionName" not in model_version
85+
):
86+
raise ValueError("Model version is not recognized.")
8287
elif isinstance(model_version, str) and cls.is_uuid(model_version):
8388
model_version_name = mr.get_model_or_version(model, model_version)[
8489
"modelVersionName"
@@ -263,7 +268,7 @@ def create_performance_definition(
263268
"Project %s must have the 'predictionVariable' "
264269
"property set." % project.name
265270
)
266-
271+
print("sup")
267272
if not modelVersions:
268273
updated_models = [model.id for model in models]
269274
else:
@@ -278,22 +283,29 @@ def create_performance_definition(
278283

279284
modelVersions = modelVersions + [""] * (len(models) - len(modelVersions))
280285
for model, modelVersionName in zip(models, modelVersions):
286+
281287
if (
282288
isinstance(modelVersionName, dict)
283289
and "modelVersionName" in modelVersionName
284290
):
291+
285292
modelVersionName = modelVersionName["modelVersionName"]
286293
elif (
287294
isinstance(modelVersionName, dict)
288295
and "modelVersionName" not in modelVersionName
289296
):
297+
290298
raise ValueError("Model version is not recognized.")
291-
updated_models.append(model.id + ":" + modelVersionName)
299+
300+
if modelVersionName != "":
301+
updated_models.append(model.id + ":" + modelVersionName)
302+
else:
303+
updated_models.append(model.id)
292304

293305
request = {
294306
"projectId": project.id,
295307
"name": name or project.name + " Performance",
296-
"modelIds": [model for model in updated_models],
308+
"modelIds": updated_models,
297309
"championMonitored": monitor_champion,
298310
"challengerMonitored": monitor_challenger,
299311
"maxBins": max_bins,
@@ -330,7 +342,6 @@ def create_performance_definition(
330342
for v in project.get("variables", [])
331343
if v.get("role") == "output"
332344
]
333-
334345
return cls.post(
335346
"/performanceTasks",
336347
json=request,

src/sasctl/_services/score_definitions.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,15 @@ def create_score_definition(
133133

134134
if isinstance(model_version, dict) and "modelVersionName" in model_version:
135135
model_version = model_version["modelVersionName"]
136+
elif (
137+
isinstance(model_version, dict)
138+
and "modelVersionName" not in model_version
139+
):
140+
raise ValueError(
141+
"Model version cannot be found. Please check the inputted model version."
142+
)
136143
elif isinstance(model_version, str) and cls.is_uuid(model_version):
144+
print("hello")
137145
model_version = cls._model_repository.get_model_or_version(
138146
model_id, model_version
139147
)["modelVersionName"]

tests/unit/test_model_management.py

Lines changed: 122 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ def test_create_performance_definition():
2323
RestObj({"name": "Test Model 2", "id": "67890", "projectId": PROJECT["id"]}),
2424
]
2525
USER = "username"
26+
VERSION_MOCK = {"modelVersionName": "1.0"}
27+
VERSION_MOCK_NONAME = {}
2628

2729
with mock.patch("sasctl.core.Session._get_authorization_token"):
2830
current_session("example.com", USER, "password")
@@ -111,6 +113,32 @@ def test_create_performance_definition():
111113
table_prefix="TestData",
112114
)
113115

116+
with pytest.raises(ValueError):
117+
# Model versions exceed models
118+
get_model.side_effect = copy.deepcopy(MODELS)
119+
_ = mm.create_performance_definition(
120+
models=["model1", "model2"],
121+
modelVersions=["1.0", "2.0", "3.0"],
122+
library_name="TestLibrary",
123+
table_prefix="TestData",
124+
max_bins=3,
125+
monitor_challenger=True,
126+
monitor_champion=True,
127+
)
128+
129+
with pytest.raises(ValueError):
130+
# Model version dictionary missing modelVersionName
131+
get_model.side_effect = copy.deepcopy(MODELS)
132+
_ = mm.create_performance_definition(
133+
models=["model1", "model2"],
134+
modelVersions=VERSION_MOCK_NONAME,
135+
library_name="TestLibrary",
136+
table_prefix="TestData",
137+
max_bins=3,
138+
monitor_challenger=True,
139+
monitor_champion=True,
140+
)
141+
114142
get_project.return_value = copy.deepcopy(PROJECT)
115143
get_project.return_value["targetVariable"] = "target"
116144
get_project.return_value["targetLevel"] = "interval"
@@ -125,21 +153,68 @@ def test_create_performance_definition():
125153
monitor_challenger=True,
126154
monitor_champion=True,
127155
)
156+
url, data = post_models.call_args
157+
assert post_models.call_count == 1
158+
assert PROJECT["id"] == data["json"]["projectId"]
159+
assert MODELS[0]["id"] in data["json"]["modelIds"]
160+
assert MODELS[1]["id"] in data["json"]["modelIds"]
161+
assert "TestLibrary" == data["json"]["dataLibrary"]
162+
assert "TestData" == data["json"]["dataPrefix"]
163+
assert "cas-shared-default" == data["json"]["casServerId"]
164+
assert data["json"]["name"]
165+
assert data["json"]["description"]
166+
assert data["json"]["maxBins"] == 3
167+
assert data["json"]["championMonitored"] is True
168+
assert data["json"]["challengerMonitored"] is True
128169

129-
assert post_models.call_count == 1
130-
url, data = post_models.call_args
131-
132-
assert PROJECT["id"] == data["json"]["projectId"]
133-
assert MODELS[0]["id"] in data["json"]["modelIds"]
134-
assert MODELS[1]["id"] in data["json"]["modelIds"]
135-
assert "TestLibrary" == data["json"]["dataLibrary"]
136-
assert "TestData" == data["json"]["dataPrefix"]
137-
assert "cas-shared-default" == data["json"]["casServerId"]
138-
assert data["json"]["name"]
139-
assert data["json"]["description"]
140-
assert data["json"]["maxBins"] == 3
141-
assert data["json"]["championMonitored"] is True
142-
assert data["json"]["challengerMonitored"] is True
170+
get_model.side_effect = copy.deepcopy(MODELS)
171+
_ = mm.create_performance_definition(
172+
# One model version as a string name
173+
models=["model1", "model2"],
174+
modelVersions="1.0",
175+
library_name="TestLibrary",
176+
table_prefix="TestData",
177+
max_bins=3,
178+
monitor_challenger=True,
179+
monitor_champion=True,
180+
)
181+
182+
assert post_models.call_count == 2
183+
url, data = post_models.call_args
184+
assert f"{MODELS[0]['id']}:1.0" in data["json"]["modelIds"]
185+
assert MODELS[1]["id"] in data["json"]["modelIds"]
186+
187+
get_model.side_effect = copy.deepcopy(MODELS)
188+
# List of string type model versions
189+
_ = mm.create_performance_definition(
190+
models=["model1", "model2"],
191+
modelVersions=["1.0", "2.0"],
192+
library_name="TestLibrary",
193+
table_prefix="TestData",
194+
max_bins=3,
195+
monitor_challenger=True,
196+
monitor_champion=True,
197+
)
198+
assert post_models.call_count == 3
199+
url, data = post_models.call_args
200+
assert f"{MODELS[0]['id']}:1.0" in data["json"]["modelIds"]
201+
assert f"{MODELS[1]['id']}:2.0" in data["json"]["modelIds"]
202+
203+
get_model.side_effect = copy.deepcopy(MODELS)
204+
# List of dictionary type and string type model versions
205+
_ = mm.create_performance_definition(
206+
models=["model1", "model2"],
207+
modelVersions=[VERSION_MOCK, "2.0"],
208+
library_name="TestLibrary",
209+
table_prefix="TestData",
210+
max_bins=3,
211+
monitor_challenger=True,
212+
monitor_champion=True,
213+
)
214+
assert post_models.call_count == 4
215+
url, data = post_models.call_args
216+
assert f"{MODELS[0]['id']}:1.0" in data["json"]["modelIds"]
217+
assert f"{MODELS[1]['id']}:2.0" in data["json"]["modelIds"]
143218

144219
with mock.patch(
145220
"sasctl._services.model_management.ModelManagement" ".post"
@@ -160,20 +235,39 @@ def test_create_performance_definition():
160235
monitor_champion=True,
161236
)
162237

163-
assert post_project.call_count == 1
164-
url, data = post_project.call_args
165-
166-
assert PROJECT["id"] == data["json"]["projectId"]
167-
assert MODELS[0]["id"] in data["json"]["modelIds"]
168-
assert MODELS[1]["id"] in data["json"]["modelIds"]
169-
assert "TestLibrary" == data["json"]["dataLibrary"]
170-
assert "TestData" == data["json"]["dataPrefix"]
171-
assert "cas-shared-default" == data["json"]["casServerId"]
172-
assert data["json"]["name"]
173-
assert data["json"]["description"]
174-
assert data["json"]["maxBins"] == 3
175-
assert data["json"]["championMonitored"] is True
176-
assert data["json"]["challengerMonitored"] is True
238+
# one extra test for project with version id
239+
240+
assert post_project.call_count == 1
241+
url, data = post_project.call_args
242+
243+
assert PROJECT["id"] == data["json"]["projectId"]
244+
assert MODELS[0]["id"] in data["json"]["modelIds"]
245+
assert MODELS[1]["id"] in data["json"]["modelIds"]
246+
assert "TestLibrary" == data["json"]["dataLibrary"]
247+
assert "TestData" == data["json"]["dataPrefix"]
248+
assert "cas-shared-default" == data["json"]["casServerId"]
249+
assert data["json"]["name"]
250+
assert data["json"]["description"]
251+
assert data["json"]["maxBins"] == 3
252+
assert data["json"]["championMonitored"] is True
253+
assert data["json"]["challengerMonitored"] is True
254+
255+
get_model.side_effect = copy.deepcopy(MODELS)
256+
# Project with model version
257+
_ = mm.create_performance_definition(
258+
project="project",
259+
modelVersions="2.0",
260+
library_name="TestLibrary",
261+
table_prefix="TestData",
262+
max_bins=3,
263+
monitor_challenger=True,
264+
monitor_champion=True,
265+
)
266+
267+
assert post_project.call_count == 2
268+
url, data = post_project.call_args
269+
assert f"{MODELS[0]['id']}:2.0" in data["json"]["modelIds"]
270+
assert MODELS[1]["id"] in data["json"]["modelIds"]
177271

178272
def test_table_prefix_format():
179273
with pytest.raises(ValueError):

0 commit comments

Comments
 (0)