Skip to content

Commit 311bc0c

Browse files
committed
Black reformatting
1 parent bb809f8 commit 311bc0c

File tree

3 files changed

+7
-9
lines changed

3 files changed

+7
-9
lines changed

src/sasctl/pzmm/pickle_model.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,14 +112,12 @@ def pickle_trained_model(
112112
model=trained_model,
113113
force=True,
114114
path=pickle_path,
115-
filename=f"{model_prefix}.pickle"
115+
filename=f"{model_prefix}.pickle",
116116
)
117117
# For MOJO H2O models, gzip the model file and adjust the file extension
118118
elif is_h2o_model and pickle_path:
119119
trained_model.save_mojo(
120-
force=True,
121-
path=pickle_path,
122-
filename=f"{model_prefix}.mojo"
120+
force=True, path=pickle_path, filename=f"{model_prefix}.mojo"
123121
)
124122
elif is_binary_model or is_h2o_model:
125123
raise ValueError(

tests/unit/test_pickle_model.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,7 @@ def test_pickle_trained_model_h2o():
7373
x.remove(y)
7474

7575
model = H2OGeneralizedLinearEstimator(
76-
family="binomial",
77-
model_id="test_model",
78-
lambda_search=True
76+
family="binomial", model_id="test_model", lambda_search=True
7977
)
8078
model.train(x=x, y=y, training_frame=data)
8179

tests/unit/test_write_json_files.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,9 @@ def test_write_file_metadata_json():
272272
with patch.object(jf, "notebook_output", True):
273273
capture_output = io.StringIO()
274274
sys.stdout = capture_output
275-
jf.write_file_metadata_json(model_prefix="Test_Model", json_path=Path(tmp_dir))
275+
jf.write_file_metadata_json(
276+
model_prefix="Test_Model", json_path=Path(tmp_dir)
277+
)
276278
assert (Path(tmp_dir) / "fileMetadata.json").exists()
277279
sys.stdout = sys.__stdout__
278280
assert "was successfully written and saved to " in capture_output.getvalue()
@@ -636,4 +638,4 @@ def test_create_requirements_json(change_dir):
636638
unittest.TestCase.maxDiff = None
637639
unittest.TestCase().assertCountEqual(
638640
json.loads(json_dict["requirements.json"]), expected
639-
)
641+
)

0 commit comments

Comments
 (0)