Skip to content

Commit dbc6ab9

Browse files
committed
Merge remote-tracking branch 'origin/refactorPZMM' into refactorPZMM
# Conflicts: # src/sasctl/pzmm/model_parameters.py
2 parents b3fbb90 + c680342 commit dbc6ab9

File tree

2 files changed

+20
-12
lines changed

2 files changed

+20
-12
lines changed

src/sasctl/_services/model_management.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,13 +204,22 @@ def create_performance_definition(
204204

205205
# Separate single models from multiple models
206206
if not isinstance(models, list):
207-
models = mr.get_model(models)
207+
models = [models]
208+
if project:
209+
project = mr.get_project(project)
210+
project_models = mr.get(f'/projects/{project.id}/models')
211+
project_models = [m for m in project_models if m.name in models]
212+
models = project_models
213+
# Necessary to eventually provide variables to the performance definition
214+
models[0] = mr.get_model(project_models[0].id)
208215
else:
209216
# Collect all models into a list. This converts the PagedList response from mr.list_models to a normal list.
210217
for i, model in enumerate(models):
211218
models[i] = mr.get_model(model)
212-
if not project:
213219
project = mr.get_project(models[0].projectId)
220+
# Ensures that all models are in the same project
221+
if not all([model.projectId == project.id for model in models]):
222+
raise ValueError("Not all models are contained within the same project. Try specifying a project.")
214223

215224
# Performance data cannot be captured unless certain project properties have been configured.
216225
for required in ["targetVariable", "targetLevel"]:

src/sasctl/pzmm/model_parameters.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -141,14 +141,13 @@ def update_kpis(
141141
models_to_update = kpis["ModelUUID"].unique().tolist()
142142

143143
for model in models_to_update:
144-
try:
145-
current_params, file_name = _find_file(model, "hyperparameters")
146-
except:
147-
model_name = {kpis.loc[kpis["ModelUUID"]==model, "ModelName"].iloc[0]}
148-
print(f'No hyperparameter file for current model {model_name}. Attempting for next model...')
149-
else:
150-
updated_json = cls._update_json(model, current_params.json(), kpis)
151-
mr.add_model_content(model, json.dumps(updated_json, indent=4), file_name)
144+
try:
145+
current_params, file_name = _find_file(model, "hyperparameters")
146+
except:
147+
print(f'No hyperparamter file for current model {kpis.loc[kpis["ModelUUID"]==model, "ModelName"].iloc[0]}. Attempting for next model...')
148+
else:
149+
updated_json = cls._update_json(model, current_params, kpis)
150+
mr.add_model_content(model, json.dumps(updated_json, indent=4), file_name)
152151

153152
@staticmethod
154153
def get_hyperparameters(model: Union[str, dict, RestObj]) -> Tuple[dict, str]:
@@ -175,7 +174,7 @@ def get_hyperparameters(model: Union[str, dict, RestObj]) -> Tuple[dict, str]:
175174
model = mr.get_model(model)
176175
id_ = model["id"]
177176
file_contents, file_name = _find_file(id_, "hyperparameters")
178-
return file_contents.hyperparameters, file_name
177+
return file_contents, file_name
179178

180179
@classmethod
181180
def add_hyperparameters(cls, model: Union[str, dict, RestObj], **kwargs) -> None:
@@ -201,7 +200,7 @@ def add_hyperparameters(cls, model: Union[str, dict, RestObj], **kwargs) -> None
201200
id_ = model["id"]
202201
hyperparameters, file_name = cls.get_hyperparameters(id_)
203202
for key, value in kwargs.items():
204-
hyperparameters[key] = value
203+
hyperparameters["hyperparameters"][key] = value
205204
mr.add_model_content(
206205
model,
207206
json.dumps(hyperparameters, indent=4),

0 commit comments

Comments
 (0)