Skip to content

Commit 47ad7b3

Browse files
Refactored code and DCO Remediation Commit for samyarpotlapalli <[email protected]>I, samyarpotlapalli <[email protected]>, hereby add my Signed-off-by to this commit: c01034e23bd6bd9c3905648223b6d784f68e0693I, samyarpotlapalli <[email protected]>, hereby add my Signed-off-by to this commit: b6078ce3560dcbe18cff09737491861f82d37f78I, samyarpotlapalli <[email protected]>, hereby add my Signed-off-by to this commit: d32a9c4a26513d4a36d9ce45fa2d511ba5cc6f73Signed-off-by: samyarpotlapalli <[email protected]>
Signed-off-by: samyarpotlapalli <[email protected]>
1 parent 204789f commit 47ad7b3

File tree

2 files changed

+55
-35
lines changed

2 files changed

+55
-35
lines changed

src/sasctl/_services/model_management.py

Lines changed: 54 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,7 @@ def create_performance_definition(
164164
multiple models, input a list of model names, or a list of dictionaries. If no models are specified, all
165165
models in the project specified will be used. Defaults to None.
166166
modelVersions: str, list, optional
167-
The name of the model version(s) for models used in the performance definition. If no model versions
168-
are specified, all models will use the latest version. Defaults to None.
167+
The name of the model version(s). Defaults to None, where all models are latest.
169168
library_name : str
170169
The library containing the input data, default is 'Public'.
171170
name : str, optional
@@ -268,39 +267,9 @@ def create_performance_definition(
268267
"Project %s must have the 'predictionVariable' "
269268
"property set." % project.name
270269
)
271-
print("sup")
272-
if not modelVersions:
273-
updated_models = [model.id for model in models]
274-
else:
275-
updated_models = []
276-
if not isinstance(modelVersions, list):
277-
modelVersions = [modelVersions]
278-
279-
if len(models) < len(modelVersions):
280-
raise ValueError(
281-
"There are too many versions for the amount of models specified."
282-
)
283-
284-
modelVersions = modelVersions + [""] * (len(models) - len(modelVersions))
285-
for model, modelVersionName in zip(models, modelVersions):
286-
287-
if (
288-
isinstance(modelVersionName, dict)
289-
and "modelVersionName" in modelVersionName
290-
):
291270

292-
modelVersionName = modelVersionName["modelVersionName"]
293-
elif (
294-
isinstance(modelVersionName, dict)
295-
and "modelVersionName" not in modelVersionName
296-
):
297-
298-
raise ValueError("Model version is not recognized.")
299-
300-
if modelVersionName != "":
301-
updated_models.append(model.id + ":" + modelVersionName)
302-
else:
303-
updated_models.append(model.id)
271+
# Creating the new array of modelIds with version names appended
272+
updated_models = cls.check_model_versions(models, modelVersions)
304273

305274
request = {
306275
"projectId": project.id,
@@ -350,6 +319,57 @@ def create_performance_definition(
350319
},
351320
)
352321

322+
@classmethod
323+
def check_model_versions(cls, models, modelVersions):
324+
"""
325+
Checking if the model version(s) are valid. Appending them to the model_id accordingly.
326+
327+
Parameters
328+
----------
329+
models: list of str
330+
List of models.
331+
modelVersions : list of str
332+
List of model versions associated with models.
333+
334+
Returns
335+
-------
336+
String list
337+
"""
338+
if not modelVersions:
339+
return [model.id for model in models]
340+
else:
341+
updated_models = []
342+
if not isinstance(modelVersions, list):
343+
modelVersions = [modelVersions]
344+
345+
if len(models) < len(modelVersions):
346+
raise ValueError(
347+
"There are too many versions for the amount of models specified."
348+
)
349+
350+
modelVersions = modelVersions + [""] * (len(models) - len(modelVersions))
351+
for model, modelVersionName in zip(models, modelVersions):
352+
353+
if (
354+
isinstance(modelVersionName, dict)
355+
and "modelVersionName" in modelVersionName
356+
):
357+
358+
modelVersionName = modelVersionName["modelVersionName"]
359+
elif (
360+
isinstance(modelVersionName, dict)
361+
and "modelVersionName" not in modelVersionName
362+
):
363+
364+
raise ValueError("Model version is not recognized.")
365+
366+
if modelVersionName != "":
367+
updated_models.append(model.id + ":" + modelVersionName)
368+
else:
369+
updated_models.append(model.id)
370+
371+
return updated_models
372+
353373
@classmethod
354374
def execute_performance_definition(cls, definition):
355375
"""Launches a job to run a performance definition.

src/sasctl/_services/score_definitions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def create_score_definition(
116116
table = cls._cas_management.get_table(table_name, library_name, server_name)
117117
if not table and not table_file:
118118
raise HTTPError(
119-
"This table may not exist in CAS. Please include the `table_file` argument."
119+
"This table may not exist in CAS. Include the `table_file` argument."
120120
)
121121
elif not table and table_file:
122122
cls._cas_management.upload_file(

0 commit comments

Comments
 (0)