Skip to content

Commit 6b5e6e2

Browse files
djm21smlindauer
authored andcommitted
Documentation for modelParameters
1 parent 1a92a22 commit 6b5e6e2

File tree

1 file changed

+68
-2
lines changed

1 file changed

+68
-2
lines changed

src/sasctl/pzmm/modelParameters.py

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,20 @@
55

66

77
def find_file(model, fileName):
8+
"""
9+
Retrieves first file from specified model that contains fileName as a substring.
10+
Parameters
11+
----------
12+
model : str
13+
ID of the model the desired file is located in
14+
fileName : str
15+
The name of the desired file, or a substring that is contained within the file name
16+
Returns
17+
-------
18+
RestObj
19+
The first file with a name containing fileName
20+
21+
"""
822
from ..core import current_session
923

1024
sess = current_session()
@@ -22,6 +36,18 @@ def find_file(model, fileName):
2236
class ModelParameters:
2337
@classmethod
2438
def generate_hyperparameters(cls, model, modelPrefix, pPath):
39+
"""
40+
Generates hyperparameters for a given model
41+
Parameters
42+
----------
43+
model : str, list, dict
44+
Name, id, or dictionary representation of the model
45+
modelPrefix : str
46+
Name used to create model files
47+
e.g. (modelPrefix) + "Hyperparameters.json")
48+
pPath : str, Path
49+
Directory location of model files
50+
"""
2551
if all(hasattr(model, attr) for attr in ["_estimator_type", "get_params"]):
2652
cls.sklearn_params(model, modelPrefix, pPath)
2753
else:
@@ -36,7 +62,18 @@ def update_kpis(
3662
server="cas-shared-default",
3763
caslib="ModelPerformanceData",
3864
):
39-
"""Updates"""
65+
"""
66+
Updates hyperparameter file to include KPIs generated by performance definitions, as well
67+
as any custom KPIs imported by user to the SAS KPI data table.
68+
Parameters
69+
----------
70+
project : dict, str, list
71+
Name, id, or dictionary representation of the project
72+
server : str, optional
73+
Server on which the KPI data table is stored. Defaults to "cas-shared-default".
74+
caslib : str, optional
75+
CAS Library on which the KPI data table is stored. Defaults to "ModelPerformanceData".
76+
"""
4077
from ..tasks import get_project_kpis
4178
from io import StringIO
4279

@@ -61,6 +98,17 @@ def update_kpis(
6198

6299
@classmethod
63100
def get_hyperparameters(cls, model):
101+
"""
102+
Gets hyperparameter json file from specified model.
103+
Parameters
104+
----------
105+
model : str, dict, list
106+
Name, id, or dictionary representation of the model
107+
Returns
108+
-------
109+
dict
110+
dictionary containing the contents of the hyperparameter file
111+
"""
64112
if mr.is_uuid(model):
65113
id_ = model
66114
elif isinstance(model, dict) and "id" in model:
@@ -73,6 +121,15 @@ def get_hyperparameters(cls, model):
73121

74122
@classmethod
75123
def add_hyperparameters(cls, model, **kwargs):
124+
"""
125+
Adds custom hyperparameters to the hyperparameter file contained within the model in Model Manager.
126+
Parameters
127+
----------
128+
model : str, dict
129+
name, id, or dictionary representation of the model
130+
kwargs
131+
named variables representing hyperparameters to be added to the hyperparameter file
132+
"""
76133
from io import StringIO
77134

78135
if not isinstance(model, dict):
@@ -87,9 +144,18 @@ def add_hyperparameters(cls, model, **kwargs):
87144
)
88145

89146
def sklearn_params(model, modelPrefix, pPath):
147+
"""
148+
Generates hyperparameters for the models genereated by SciKit Learn.
149+
Parameters
150+
----------
151+
modelPrefix : str
152+
Name used to create model files
153+
pPath : str, Path
154+
Directory location of model files
155+
"""
90156
hyperparameters = model.get_params()
91157
modelJson = {"hyperparameters": hyperparameters}
92158
with open(
93159
Path(pPath) / ("{}Hyperparameters.json".format(modelPrefix)), "w"
94160
) as f:
95-
f.write(json.dumps(modelJson, indent=4))
161+
f.write(json.dumps(modelJson, indent=4))

0 commit comments

Comments
 (0)