Skip to content

Commit 1eed427

Browse files
committed
Reformat to black style
1 parent c3e260a commit 1eed427

File tree

5 files changed

+921
-584
lines changed

5 files changed

+921
-584
lines changed

src/sasctl/pzmm/importModel.py

Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313

1414
def project_exists(response, project):
15-
'''Checks if project exists on SAS Viya. If the project does not exist, then a new
15+
"""Checks if project exists on SAS Viya. If the project does not exist, then a new
1616
project is created or an error is raised.
1717
1818
Parameters
@@ -31,26 +31,26 @@ def project_exists(response, project):
3131
------
3232
SystemError
3333
Alerts user that API calls cannot continue until a valid project is provided.
34-
'''
34+
"""
3535
if response is None:
3636
try:
37-
warn('No project with the name or UUID {} was found.'.format(project))
37+
warn("No project with the name or UUID {} was found.".format(project))
3838
UUID(project)
3939
raise SystemError(
40-
'The provided UUID does not match any projects found in SAS Model Manager. '
41-
+ 'Please enter a valid UUID or a new name for a project to be created.'
40+
"The provided UUID does not match any projects found in SAS Model Manager. "
41+
+ "Please enter a valid UUID or a new name for a project to be created."
4242
)
4343
except ValueError:
44-
repo = mr.default_repository().get('id')
44+
repo = mr.default_repository().get("id")
4545
response = mr.create_project(project, repo)
46-
print('A new project named {} was created.'.format(response.name))
46+
print("A new project named {} was created.".format(response.name))
4747
return response
4848
else:
4949
return response
5050

5151

5252
def model_exists(project, name, force):
53-
'''Checks if model already exists and either raises an error or deletes the redundant model.
53+
"""Checks if model already exists and either raises an error or deletes the redundant model.
5454
5555
Parameters
5656
----------
@@ -66,30 +66,30 @@ def model_exists(project, name, force):
6666
ValueError
6767
Model repository API cannot overwrite an already existing model with the upload model call.
6868
Alerts user of the force argument to allow multi-call API overwriting.
69-
'''
69+
"""
7070
project = mr.get_project(project)
71-
projectId = project['id']
72-
projectModels = mr.get('/projects/{}/models'.format(projectId))
71+
projectId = project["id"]
72+
projectModels = mr.get("/projects/{}/models".format(projectId))
7373

7474
for model in projectModels:
7575
# Throws a TypeError if only one model is in the project
7676
try:
77-
if model['name'] == name:
77+
if model["name"] == name:
7878
if force:
7979
mr.delete_model(model.id)
8080
else:
8181
raise ValueError(
82-
'A model with the same model name exists in project {}. Include the force=True argument to overwrite models with the same name.'.format(
82+
"A model with the same model name exists in project {}. Include the force=True argument to overwrite models with the same name.".format(
8383
project.name
8484
)
8585
)
8686
except TypeError:
87-
if projectModels['name'] == name:
87+
if projectModels["name"] == name:
8888
if force:
8989
mr.delete_model(projectModels.id)
9090
else:
9191
raise ValueError(
92-
'A model with the same model name exists in project {}. Include the force=True argument to overwrite models with the same name.'.format(
92+
"A model with the same model name exists in project {}. Include the force=True argument to overwrite models with the same name.".format(
9393
project.name
9494
)
9595
)
@@ -105,7 +105,7 @@ def pzmmImportModel(
105105
inputDF,
106106
targetDF,
107107
predictmethod,
108-
metrics=['EM_EVENTPROBABILITY', 'EM_CLASSIFICATION'],
108+
metrics=["EM_EVENTPROBABILITY", "EM_CLASSIFICATION"],
109109
modelFileName=None,
110110
pyPath=None,
111111
threshPrediction=None,
@@ -114,7 +114,7 @@ def pzmmImportModel(
114114
force=False,
115115
binaryString=None,
116116
):
117-
'''Import model to SAS Model Manager using pzmm submodule.
117+
"""Import model to SAS Model Manager using pzmm submodule.
118118
119119
Using pzmm, generate Python score code and import the model files into
120120
SAS Model Manager. This function automatically checks the version of SAS
@@ -172,7 +172,7 @@ def pzmmImportModel(
172172
Sets whether to overwrite models with the same name upon upload. By default False.
173173
binaryString : string, optional
174174
Binary string representation of the model object. By default None.
175-
'''
175+
"""
176176
# Initialize no score code or binary H2O model flags
177177
noScoreCode = False
178178
binaryModel = False
@@ -192,36 +192,36 @@ def getFiles(extensions):
192192
# If the model file name is not provided, set a default value depending on H2O and binary model status
193193
if modelFileName is None:
194194
if isH2OModel:
195-
binaryOrMOJO = getFiles(['*.mojo', '*.pickle'])
195+
binaryOrMOJO = getFiles(["*.mojo", "*.pickle"])
196196
if len(binaryOrMOJO) == 0:
197197
print(
198-
'WARNING: An H2O model file was not found at {}. Score code will not be automatically generated.'.format(
198+
"WARNING: An H2O model file was not found at {}. Score code will not be automatically generated.".format(
199199
str(pyPath)
200200
)
201201
)
202202
noScoreCode = True
203203
elif len(binaryOrMOJO) == 1:
204-
if str(binaryOrMOJO[0]).endswith('.pickle'):
204+
if str(binaryOrMOJO[0]).endswith(".pickle"):
205205
binaryModel = True
206-
modelFileName = modelPrefix + '.pickle'
206+
modelFileName = modelPrefix + ".pickle"
207207
else:
208-
modelFileName = modelPrefix + '.mojo'
208+
modelFileName = modelPrefix + ".mojo"
209209
else:
210210
print(
211-
'WARNING: Both a MOJO and binary model file are present at {}. Score code will not be automatically generated.'.format(
211+
"WARNING: Both a MOJO and binary model file are present at {}. Score code will not be automatically generated.".format(
212212
str(pyPath)
213213
)
214214
)
215215
noScoreCode = True
216216
else:
217-
modelFileName = modelPrefix + '.pickle'
217+
modelFileName = modelPrefix + ".pickle"
218218

219219
# Check the SAS Viya version number being used
220-
isViya35 = platform_version() == '3.5'
220+
isViya35 = platform_version() == "3.5"
221221
# For SAS Viya 4, the score code can be written beforehand and imported with all of the model files
222222
if not isViya35:
223223
if noScoreCode:
224-
print('No score code was generated.')
224+
print("No score code was generated.")
225225
else:
226226
sc.writeScoreCode(
227227
inputDF,
@@ -238,12 +238,12 @@ def getFiles(extensions):
238238
binaryString=binaryString,
239239
)
240240
print(
241-
'Model score code was written successfully to {}.'.format(
242-
Path(pyPath) / (modelPrefix + 'Score.py')
241+
"Model score code was written successfully to {}.".format(
242+
Path(pyPath) / (modelPrefix + "Score.py")
243243
)
244244
)
245245
zipIOFile = zm.zipFiles(Path(zPath), modelPrefix)
246-
print('All model files were zipped to {}.'.format(Path(zPath)))
246+
print("All model files were zipped to {}.".format(Path(zPath)))
247247

248248
# Check if project name provided exists and raise an error or create a new project
249249
projectResponse = mr.get_project(project)
@@ -255,16 +255,16 @@ def getFiles(extensions):
255255
response = mr.import_model_from_zip(modelPrefix, project, zipIOFile)
256256
try:
257257
print(
258-
'Model was successfully imported into SAS Model Manager as {} with UUID: {}.'.format(
258+
"Model was successfully imported into SAS Model Manager as {} with UUID: {}.".format(
259259
response.name, response.id
260260
)
261261
)
262262
except AttributeError:
263-
print('Model failed to import to SAS Model Manager.')
263+
print("Model failed to import to SAS Model Manager.")
264264
# For SAS Viya 3.5, the score code is written after upload in order to know the model UUID
265265
else:
266266
zipIOFile = zm.zipFiles(Path(zPath), modelPrefix)
267-
print('All model files were zipped to {}.'.format(Path(zPath)))
267+
print("All model files were zipped to {}.".format(Path(zPath)))
268268

269269
# Check if project name provided exists and raise an error or create a new project
270270
projectResponse = mr.get_project(project)
@@ -276,14 +276,14 @@ def getFiles(extensions):
276276
response = mr.import_model_from_zip(modelPrefix, project, zipIOFile, force)
277277
try:
278278
print(
279-
'Model was successfully imported into SAS Model Manager as {} with UUID: {}.'.format(
279+
"Model was successfully imported into SAS Model Manager as {} with UUID: {}.".format(
280280
response.name, response.id
281281
)
282282
)
283283
except AttributeError:
284-
print('Model failed to import to SAS Model Manager.')
284+
print("Model failed to import to SAS Model Manager.")
285285
if noScoreCode:
286-
print('No score code was generated.')
286+
print("No score code was generated.")
287287
else:
288288
sc.writeScoreCode(
289289
inputDF,
@@ -301,7 +301,7 @@ def getFiles(extensions):
301301
binaryString=binaryString,
302302
)
303303
print(
304-
'Model score code was written successfully to {} and uploaded to SAS Model Manager'.format(
305-
Path(pyPath) / (modelPrefix + 'Score.py')
304+
"Model score code was written successfully to {} and uploaded to SAS Model Manager".format(
305+
Path(pyPath) / (modelPrefix + "Score.py")
306306
)
307307
)

src/sasctl/pzmm/pickleModel.py

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,61 +10,78 @@
1010
import codecs
1111

1212
# %%
13-
class PickleModel():
14-
15-
def pickleTrainedModel(self, trainedModel, modelPrefix, pPath=Path.cwd(), isH2OModel=False, isBinaryModel=False, isBinaryString=False):
16-
'''
17-
Write trained model to a binary pickle file, H2O MOJO file, or a binary string object.
18-
13+
class PickleModel:
14+
def pickleTrainedModel(
15+
self,
16+
trainedModel,
17+
modelPrefix,
18+
pPath=Path.cwd(),
19+
isH2OModel=False,
20+
isBinaryModel=False,
21+
isBinaryString=False,
22+
):
23+
"""
24+
Write trained model to a binary pickle file, H2O MOJO file, or a binary string object.
25+
1926
The following files are generated by this function:
20-
* '*.pickle'
21-
Binary pickle file containing a trained model.
27+
* '*.pickle'
28+
Binary pickle file containing a trained model.
2229
* '*.mojo'
2330
Archived H2O.ai MOJO file containing a trained model.
24-
31+
2532
Parameters
2633
---------------
2734
trainedModel : model or string or Path
2835
For non-H2O models, this argument contains the model variable. Otherwise,
2936
this should be the file path of the MOJO file.
3037
modelPrefix : string
31-
Variable name for the model to be displayed in SAS Open Model Manager
38+
Variable name for the model to be displayed in SAS Open Model Manager
3239
(i.e. hmeqClassTree + [Score.py || .pickle]).
3340
pPath : string, optional
3441
File location for the output pickle file. Default is the current
3542
working directory.
3643
isH2OModel : boolean, optional
37-
Sets whether the model file is an H2O.ai MOJO file. If set as True,
44+
Sets whether the model file is an H2O.ai MOJO file. If set as True,
3845
the MOJO file will be gzipped before uploading to SAS Model Manager.
3946
The default value is False.
4047
isBinaryModel : boolean, optional
4148
Sets whether the H2O model provided is a binary model or a MOJO model. By default False.
4249
isBinaryString : boolean, optional
4350
Sets whether the model is to be set as a binary string instead of a pickle file. By default False.
44-
51+
4552
Returns
4653
-------
4754
binaryString : binary string
4855
When the isBinaryString flag is set to True, return a binary string representation of the model instead
4956
of a pickle or MOJO file.
5057
51-
'''
52-
58+
"""
59+
5360
if isBinaryString:
54-
binaryString = codecs.encode(pickle.dumps(trainedModel), 'base64').decode()
61+
binaryString = codecs.encode(pickle.dumps(trainedModel), "base64").decode()
5562
return binaryString
5663
else:
5764
# For non-H2O models, pickle the model object
5865
if not isH2OModel:
59-
with open(Path(pPath) / (modelPrefix + '.pickle'), 'wb') as pFile:
66+
with open(Path(pPath) / (modelPrefix + ".pickle"), "wb") as pFile:
6067
pickle.dump(trainedModel, pFile)
61-
print('Model {} was successfully pickled and saved to {}.'.format(modelPrefix, Path(pPath) / (modelPrefix + '.pickle')))
68+
print(
69+
"Model {} was successfully pickled and saved to {}.".format(
70+
modelPrefix, Path(pPath) / (modelPrefix + ".pickle")
71+
)
72+
)
6273
# For H2O models that are binary files, rename the binary file as a pickle file
6374
elif isBinaryModel:
6475
binaryFile = Path(pPath) / modelPrefix
65-
binaryFile.rename(binaryFile.with_suffix('.pickle'))
76+
binaryFile.rename(binaryFile.with_suffix(".pickle"))
6677
# For H2O models in the MOJO format, gzip the model file and rename it with a .MOJO extension
6778
else:
68-
with open(Path(trainedModel), 'rb') as fileIn, gzip.open(Path(pPath) / (modelPrefix + '.mojo'), 'wb') as fileOut:
79+
with open(Path(trainedModel), "rb") as fileIn, gzip.open(
80+
Path(pPath) / (modelPrefix + ".mojo"), "wb"
81+
) as fileOut:
6982
fileOut.writelines(fileIn)
70-
print('MOJO model {} was successfully gzipped and saved to {}.'.format(modelPrefix, Path(pPath) / (modelPrefix + '.mojo')))
83+
print(
84+
"MOJO model {} was successfully gzipped and saved to {}.".format(
85+
modelPrefix, Path(pPath) / (modelPrefix + ".mojo")
86+
)
87+
)

0 commit comments

Comments
 (0)