Skip to content

Commit d9f39c7

Browse files
committed
Refactor to PEP8 standards and clarify doc_strings
1 parent 5bd985e commit d9f39c7

File tree

1 file changed

+72
-70
lines changed

1 file changed

+72
-70
lines changed

src/sasctl/pzmm/modelParameters.py

Lines changed: 72 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -7,55 +7,70 @@
77
# TODO: Convert STRINGIO calls to string or dict format
88

99

10-
def find_file(model, fileName):
10+
def _find_file(model, file_name):
1111
"""
12-
Retrieves first file from specified model that contains fileName as a substring.
12+
Retrieves the first file from a registered model on SAS Model Manager that contains the provided
13+
file_name as an exact match or substring.
14+
1315
Parameters
1416
----------
15-
model : str
16-
ID of the model the desired file is located in
17-
fileName : str
18-
The name of the desired file, or a substring that is contained within the file name
17+
model : str or dict
18+
The name or id of the model, or a dictionary representation of the model.
19+
file_name : str
20+
The name of the desired file, or a substring that is contained within the file name.
21+
1922
Returns
2023
-------
2124
RestObj
22-
The first file with a name containing fileName
23-
25+
The first file with a name containing file_name.
2426
"""
2527
from ..core import current_session
2628

2729
sess = current_session()
28-
fileList = mr.get_model_contents(model)
29-
for file in fileList:
30+
file_list = mr.get_model_contents(model)
31+
for file in file_list:
3032
print(file.name)
31-
if fileName.lower() in file.name.lower():
32-
correctFile = sess.get(
33-
"modelRepository/models/{}/contents/{}/content".format(model, file.id)
34-
)
35-
break
36-
return correctFile
33+
if file_name.lower() in file.name.lower():
34+
correct_file = sess.get(f"modelRepository/models/{model}/contents/{file.id}/content")
35+
return correct_file
3736

3837

3938
class ModelParameters:
4039
@classmethod
41-
def generate_hyperparameters(cls, model, modelPrefix, pPath):
40+
def generate_hyperparameters(cls, model, model_prefix, pickle_path):
4241
"""
43-
Generates hyperparameters for a given model
42+
Generates hyperparameters for a given model and creates a JSON file representation.
43+
44+
Currently only supports generation of scikit-learn model hyperparameters.
45+
4446
Parameters
4547
----------
46-
model : str, list, dict
47-
Name, id, or dictionary representation of the model
48-
modelPrefix : str
49-
Name used to create model files
50-
e.g. (modelPrefix) + "Hyperparameters.json")
51-
pPath : str, Path
52-
Directory location of model files
48+
model : Python object
49+
Python object representing the model.
50+
model_prefix : str
51+
Name used to create model files. (e.g. (modelPrefix) + "Hyperparameters.json")
52+
pickle_path : str, Path
53+
Directory location of model files.
54+
55+
Yields
56+
------
57+
JSON file
58+
Named {model_prefix}Hyperparameters.json.
5359
"""
60+
def sklearn_params():
61+
"""
62+
Generates hyperparameters for the models generated by scikit-learn.
63+
"""
64+
hyperparameters = model.get_params()
65+
model_json = {"hyperparameters": hyperparameters}
66+
with open(Path(pickle_path) / f"{model_prefix}Hyperparameters.json", "w") as f:
67+
f.write(json.dumps(model_json, indent=4))
68+
5469
if all(hasattr(model, attr) for attr in ["_estimator_type", "get_params"]):
55-
cls.sklearn_params(model, modelPrefix, pPath)
70+
sklearn_params(model, model_prefix, pickle_path)
5671
else:
5772
raise ValueError(
58-
"Other model types not currently supported for hyperparameter generation."
73+
"This model type is not currently supported for hyperparameter generation."
5974
)
6075

6176
@classmethod
@@ -68,10 +83,11 @@ def update_kpis(
6883
"""
6984
Updates hyperparameter file to include KPIs generated by performance definitions, as well
7085
as any custom KPIs imported by user to the SAS KPI data table.
86+
7187
Parameters
7288
----------
73-
project : dict, str, list
74-
Name, id, or dictionary representation of the project
89+
project : str or dict
90+
The name or id of the project, or a dictionary representation of the project.
7591
server : str, optional
7692
Server on which the KPI data table is stored. Defaults to "cas-shared-default".
7793
caslib : str, optional
@@ -81,36 +97,38 @@ def update_kpis(
8197
from io import StringIO
8298

8399
kpis = get_project_kpis(project, server, caslib)
84-
modelsToUpdate = kpis["ModelUUID"].unique().tolist()
85-
for model in modelsToUpdate:
86-
currentParams = find_file(model, "hyperparameters")
87-
currentJSON = currentParams.json()
88-
modelRows = kpis.loc[kpis["ModelUUID"] == model]
89-
modelRows.set_index("TimeLabel", inplace=True)
90-
kpiJSON = modelRows.to_json(orient="index")
91-
parsedJSON = json.loads(kpiJSON)
92-
currentJSON["kpis"] = parsedJSON
93-
fileName = "{}Hyperparameters.json".format(
94-
currentJSON["kpis"][list(currentJSON["kpis"].keys())[0]]["ModelName"]
100+
models_to_update = kpis["ModelUUID"].unique().tolist()
101+
for model in models_to_update:
102+
current_params = _find_file(model, "hyperparameters")
103+
current_json = current_params.json()
104+
model_rows = kpis.loc[kpis["ModelUUID"] == model]
105+
model_rows.set_index("TimeLabel", inplace=True)
106+
kpi_json = model_rows.to_json(orient="index")
107+
parsed_json = json.loads(kpi_json)
108+
current_json["kpis"] = parsed_json
109+
file_name = "{}Hyperparameters.json".format(
110+
current_json["kpis"][list(current_json["kpis"].keys())[0]]["ModelName"]
95111
)
96112
mr.add_model_content(
97113
model,
98-
StringIO((json.dumps(currentJSON, indent=4))),
99-
fileName,
114+
StringIO(json.dumps(current_json, indent=4)),
115+
file_name,
100116
)
101117

102118
@classmethod
103119
def get_hyperparameters(cls, model):
104120
"""
105-
Gets hyperparameter json file from specified model.
121+
Retrieves the hyperparameter json file from specified model on SAS Model Manager.
122+
106123
Parameters
107124
----------
108-
model : str, dict, list
109-
Name, id, or dictionary representation of the model
125+
model : str or dict
126+
The name or id of the model, or a dictionary representation of the model.
127+
110128
Returns
111129
-------
112130
dict
113-
dictionary containing the contents of the hyperparameter file
131+
Dictionary containing the contents of the hyperparameter file.
114132
"""
115133
if mr.is_uuid(model):
116134
id_ = model
@@ -119,19 +137,20 @@ def get_hyperparameters(cls, model):
119137
else:
120138
model = mr.get_model(model)
121139
id_ = model["id"]
122-
file = find_file(id_, "hyperparameters")
140+
file = _find_file(id_, "hyperparameters")
123141
return file.json()
124142

125143
@classmethod
126144
def add_hyperparameters(cls, model, **kwargs):
127145
"""
128-
Adds custom hyperparameters to the hyperparameter file contained within the model in Model Manager.
146+
Adds custom hyperparameters to the hyperparameter file contained within the model in SAS Model Manager.
147+
129148
Parameters
130149
----------
131-
model : str, dict
132-
name, id, or dictionary representation of the model
150+
model : str or dict
151+
The name or id of the model, or a dictionary representation of the model.
133152
kwargs
134-
named variables representing hyperparameters to be added to the hyperparameter file
153+
Named variables pairs representing hyperparameters to be added to the hyperparameter file.
135154
"""
136155
from io import StringIO
137156

@@ -142,23 +161,6 @@ def add_hyperparameters(cls, model, **kwargs):
142161
hyperparameters["hyperparameters"][key] = value
143162
mr.add_model_content(
144163
model,
145-
StringIO((json.dumps(hyperparameters, indent=4))),
146-
"{}Hyperparameters.json".format(model.name),
164+
StringIO(json.dumps(hyperparameters, indent=4)),
165+
f"{model.name}Hyperparameters.json"
147166
)
148-
149-
def sklearn_params(model, modelPrefix, pPath):
150-
"""
151-
Generates hyperparameters for the models genereated by SciKit Learn.
152-
Parameters
153-
----------
154-
modelPrefix : str
155-
Name used to create model files
156-
pPath : str, Path
157-
Directory location of model files
158-
"""
159-
hyperparameters = model.get_params()
160-
modelJson = {"hyperparameters": hyperparameters}
161-
with open(
162-
Path(pPath) / ("{}Hyperparameters.json".format(modelPrefix)), "w"
163-
) as f:
164-
f.write(json.dumps(modelJson, indent=4))

0 commit comments

Comments
 (0)