Skip to content

Commit bbbf4e7

Browse files
committed
Update tests to include more than one model in a create_performance_definition call
1 parent 19deb0c commit bbbf4e7

File tree

1 file changed

+130
-59
lines changed

1 file changed

+130
-59
lines changed

tests/unit/test_model_management.py

Lines changed: 130 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ def test_create_performance_definition():
1616
from sasctl import current_session
1717

1818
PROJECT = RestObj({'name': 'Test Project', 'id': '98765'})
19-
MODEL = RestObj({'name': 'Test Model', 'id': '12345', 'projectId': PROJECT['id']})
19+
MODELS = [RestObj({'name': 'Test Model 1', 'id': '12345', 'projectId': PROJECT['id']}),
20+
RestObj({'name': 'Test Model 2', 'id': '67890', 'projectId': PROJECT['id']})]
2021
USER = 'username'
2122

2223
with mock.patch('sasctl.core.Session.get_auth'):
@@ -29,76 +30,146 @@ def test_create_performance_definition():
2930
'sasctl._services.model_repository.ModelRepository' '.get_project'
3031
) as get_project:
3132
with mock.patch(
32-
'sasctl._services.model_management.ModelManagement' '.post'
33-
) as post:
34-
get_model.return_value = MODEL
33+
'sasctl._services.model_repository.ModelRepository' '.list_models'
34+
) as list_models:
35+
with mock.patch(
36+
'sasctl._services.model_management.ModelManagement' '.post'
37+
) as post_models:
38+
list_models.return_value = copy.deepcopy(MODELS)
39+
get_model.side_effect = copy.deepcopy(MODELS)
40+
41+
with pytest.raises(ValueError):
42+
# Function call missing both models and project arguments
43+
_ = mm.create_performance_definition(
44+
library_name='TestLibrary',
45+
table_prefix='TestData'
46+
)
47+
48+
with pytest.raises(ValueError):
49+
# Project missing all required properties & specified by model
50+
get_project.return_value = copy.deepcopy(PROJECT)
51+
_ = mm.create_performance_definition(
52+
models=['model1', 'model2'],
53+
library_name='TestLibrary',
54+
table_prefix='TestData'
55+
)
56+
57+
with pytest.raises(ValueError):
58+
# Project missing all required properties & specified by project
59+
get_model.reset_mock(side_effect=True)
60+
_ = mm.create_performance_definition(
61+
project='project',
62+
library_name='TestLibrary',
63+
table_prefix='TestData'
64+
)
65+
66+
with pytest.raises(ValueError):
67+
# Project missing some required properties
68+
get_model.reset_mock(side_effect=True)
69+
get_project.return_value = copy.deepcopy(PROJECT)
70+
get_project.return_value['targetVariable'] = 'target'
71+
_ = mm.create_performance_definition(
72+
models=['model1', 'model2'],
73+
library_name='TestLibrary',
74+
table_prefix='TestData'
75+
)
76+
77+
with pytest.raises(ValueError):
78+
# Project missing some required properties
79+
get_model.reset_mock()
80+
get_project.return_value = copy.deepcopy(PROJECT)
81+
get_project.return_value['targetLevel'] = 'interval'
82+
_ = mm.create_performance_definition(
83+
models=['model1', 'model2'],
84+
library_name='TestLibrary',
85+
table_prefix='TestData'
86+
)
87+
88+
with pytest.raises(ValueError):
89+
# Project missing some required properties
90+
get_model.reset_mock()
91+
get_project.return_value = copy.deepcopy(PROJECT)
92+
get_project.return_value['function'] = 'classification'
93+
_ = mm.create_performance_definition(
94+
models=['model1', 'model2'],
95+
library_name='TestLibrary',
96+
table_prefix='TestData'
97+
)
98+
99+
with pytest.raises(ValueError):
100+
# Project missing some required properties
101+
get_model.reset_mock()
102+
get_project.return_value = copy.deepcopy(PROJECT)
103+
get_project.return_value['function'] = 'prediction'
104+
_ = mm.create_performance_definition(
105+
models=['model1', 'model2'],
106+
library_name='TestLibrary',
107+
table_prefix='TestData'
108+
)
35109

36-
with pytest.raises(ValueError):
37-
# Project missing all required properties
38-
get_project.return_value = copy.deepcopy(PROJECT)
39-
_ = mm.create_performance_definition(
40-
'model', 'TestLibrary', 'TestData'
41-
)
42-
43-
with pytest.raises(ValueError):
44-
# Project missing some required properties
45110
get_project.return_value = copy.deepcopy(PROJECT)
46111
get_project.return_value['targetVariable'] = 'target'
47-
_ = mm.create_performance_definition(
48-
'model', 'TestLibrary', 'TestData'
49-
)
50-
51-
with pytest.raises(ValueError):
52-
# Project missing some required properties
53-
get_project.return_value = copy.deepcopy(PROJECT)
54112
get_project.return_value['targetLevel'] = 'interval'
113+
get_project.return_value['predictionVariable'] = 'predicted'
114+
get_project.return_value['function'] = 'prediction'
115+
get_model.side_effect = copy.deepcopy(MODELS)
55116
_ = mm.create_performance_definition(
56-
'model', 'TestLibrary', 'TestData'
57-
)
58-
59-
with pytest.raises(ValueError):
60-
# Project missing some required properties
61-
get_project.return_value = copy.deepcopy(PROJECT)
62-
get_project.return_value['function'] = 'classification'
63-
_ = mm.create_performance_definition(
64-
'model', 'TestLibrary', 'TestData'
117+
models=['model1', 'model2'],
118+
library_name='TestLibrary',
119+
table_prefix='TestData',
120+
max_bins=3,
121+
monitor_challenger=True,
122+
monitor_champion=True,
65123
)
66124

67-
with pytest.raises(ValueError):
68-
# Project missing some required properties
125+
assert post_models.call_count == 1
126+
url, data = post_models.call_args
127+
128+
assert PROJECT['id'] == data['json']['projectId']
129+
assert MODELS[0]['id'] in data['json']['modelIds']
130+
assert MODELS[1]['id'] in data['json']['modelIds']
131+
assert 'TestLibrary' == data['json']['dataLibrary']
132+
assert 'TestData' == data['json']['dataPrefix']
133+
assert 'cas-shared-default' == data['json']['casServerId']
134+
assert data['json']['name']
135+
assert data['json']['description']
136+
assert data['json']['maxBins'] == 3
137+
assert data['json']['championMonitored'] is True
138+
assert data['json']['challengerMonitored'] is True
139+
140+
with mock.patch(
141+
'sasctl._services.model_management.ModelManagement' '.post'
142+
) as post_project:
143+
list_models.return_value = copy.deepcopy(MODELS)
69144
get_project.return_value = copy.deepcopy(PROJECT)
145+
get_project.return_value['targetVariable'] = 'target'
146+
get_project.return_value['targetLevel'] = 'interval'
147+
get_project.return_value['predictionVariable'] = 'predicted'
70148
get_project.return_value['function'] = 'prediction'
149+
get_model.side_effect = copy.deepcopy(MODELS)
71150
_ = mm.create_performance_definition(
72-
'model', 'TestLibrary', 'TestData'
151+
project='project',
152+
library_name='TestLibrary',
153+
table_prefix='TestData',
154+
max_bins=3,
155+
monitor_challenger=True,
156+
monitor_champion=True,
73157
)
74158

75-
get_project.return_value = copy.deepcopy(PROJECT)
76-
get_project.return_value['targetVariable'] = 'target'
77-
get_project.return_value['targetLevel'] = 'interval'
78-
get_project.return_value['predictionVariable'] = 'predicted'
79-
get_project.return_value['function'] = 'prediction'
80-
_ = mm.create_performance_definition(
81-
'model',
82-
'TestLibrary',
83-
'TestData',
84-
max_bins=3,
85-
monitor_challenger=True,
86-
monitor_champion=True,
87-
)
88-
89-
assert post.call_count == 1
90-
url, data = post.call_args
91-
92-
assert PROJECT['id'] == data['json']['projectId']
93-
assert MODEL['id'] in data['json']['modelIds']
94-
assert 'TestLibrary' == data['json']['dataLibrary']
95-
assert 'TestData' == data['json']['dataPrefix']
96-
assert 'cas-shared-default' == data['json']['casServerId']
97-
assert data['json']['name'] is not None
98-
assert data['json']['description'] is not None
99-
assert data['json']['maxBins'] == 3
100-
assert data['json']['championMonitored'] == True
101-
assert data['json']['challengerMonitored'] == True
159+
assert post_project.call_count == 1
160+
url, data = post_project.call_args
161+
162+
assert PROJECT['id'] == data['json']['projectId']
163+
assert MODELS[0]['id'] in data['json']['modelIds']
164+
assert MODELS[1]['id'] in data['json']['modelIds']
165+
assert 'TestLibrary' == data['json']['dataLibrary']
166+
assert 'TestData' == data['json']['dataPrefix']
167+
assert 'cas-shared-default' == data['json']['casServerId']
168+
assert data['json']['name']
169+
assert data['json']['description']
170+
assert data['json']['maxBins'] == 3
171+
assert data['json']['championMonitored'] is True
172+
assert data['json']['challengerMonitored'] is True
102173

103174
def test_table_prefix_format():
104175
with pytest.raises(ValueError):

0 commit comments

Comments
 (0)