Skip to content

Commit e778855

Browse files
committed
Update pickle_model test cases.
1 parent cd8f5e0 commit e778855

File tree

1 file changed

+51
-28
lines changed

1 file changed

+51
-28
lines changed

tests/unit/test_pickle_model.py

Lines changed: 51 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import tempfile
99
from pathlib import Path
1010

11+
import pytest
12+
1113
MODEL_PREFIX = "UNIT_TEST_MODEL"
1214
MODEL = []
1315

@@ -16,8 +18,8 @@ def test_pickle_trained_model():
1618
"""
1719
Test cases:
1820
- normal
19-
- h2o binary
20-
- h2o mojo
21+
- h2o binary (moved)
22+
- h2o mojo (moved)
2123
- binary string
2224
- mlflow model
2325
@@ -33,31 +35,6 @@ def test_pickle_trained_model():
3335
assert (Path(tmp_dir.name) / (MODEL_PREFIX + ".pickle")).exists()
3436
(Path(tmp_dir.name) / (MODEL_PREFIX + ".pickle")).unlink()
3537

36-
# H2O binary model case
37-
f = tempfile.NamedTemporaryFile(delete=False, dir=tmp_dir.name)
38-
shutil.copy(f.name, str(Path(tmp_dir.name) / MODEL_PREFIX))
39-
f.close()
40-
pm.pickle_trained_model(
41-
trained_model=None,
42-
model_prefix=MODEL_PREFIX,
43-
pickle_path=tmp_dir.name,
44-
is_h2o_model=True,
45-
is_binary_model=True,
46-
)
47-
assert (Path(tmp_dir.name) / (MODEL_PREFIX + ".pickle")).exists()
48-
(Path(tmp_dir.name) / (MODEL_PREFIX + ".pickle")).unlink()
49-
50-
# H2O MOJO model case
51-
f = tempfile.NamedTemporaryFile(delete=False, dir=tmp_dir.name)
52-
pm.pickle_trained_model(
53-
trained_model=f.name,
54-
model_prefix=MODEL_PREFIX,
55-
pickle_path=tmp_dir.name,
56-
is_h2o_model=True,
57-
)
58-
assert (Path(tmp_dir.name) / (MODEL_PREFIX + ".mojo")).exists()
59-
(Path(tmp_dir.name) / (MODEL_PREFIX + ".mojo")).unlink()
60-
6138
# Binary string case
6239
binary_string = pm.pickle_trained_model(
6340
trained_model=MODEL, model_prefix=MODEL_PREFIX, is_binary_string=True
@@ -70,7 +47,6 @@ def test_pickle_trained_model():
7047
delete=False, dir=mlflow_tmp_dir.name, suffix=".pickle"
7148
)
7249
mlflow_dict = {"mlflowPath": mlflow_tmp_dir.name, "model_path": Path(f.name).name}
73-
# import pdb; pdb.set_trace()
7450
pm.pickle_trained_model(
7551
trained_model=None,
7652
model_prefix=MODEL_PREFIX,
@@ -79,3 +55,50 @@ def test_pickle_trained_model():
7955
)
8056
assert (Path(tmp_dir.name) / (MODEL_PREFIX + ".pickle")).exists()
8157
(Path(tmp_dir.name) / (MODEL_PREFIX + ".pickle")).unlink()
58+
59+
60+
def test_pickle_trained_model_h2o():
61+
"""
62+
Side function for h2o models in case h2o is not installed.
63+
"""
64+
h2o = pytest.importorskip("h2o")
65+
from sasctl.pzmm.pickle_model import PickleModel as pm
66+
from h2o.estimators.glm import H2OGeneralizedLinearEstimator
67+
68+
h2o.init()
69+
data = h2o.import_file("examples/data/hmeq.csv")
70+
data["BAD"] = data["BAD"].asfactor()
71+
y = "BAD"
72+
x = list(data.columns)
73+
x.remove(y)
74+
75+
model = H2OGeneralizedLinearEstimator(
76+
family="binomial",
77+
model_id="test_model",
78+
lambda_search=True
79+
)
80+
model.train(x=x, y=y, training_frame=data)
81+
82+
tmp_dir = tempfile.TemporaryDirectory()
83+
# H2O binary model case
84+
pm.pickle_trained_model(
85+
trained_model=model,
86+
model_prefix=MODEL_PREFIX,
87+
pickle_path=tmp_dir.name,
88+
is_h2o_model=True,
89+
is_binary_model=True,
90+
)
91+
assert (Path(tmp_dir.name) / (MODEL_PREFIX + ".pickle")).exists()
92+
model = h2o.load_model(tmp_dir.name + "/" + MODEL_PREFIX + ".pickle")
93+
(Path(tmp_dir.name) / (MODEL_PREFIX + ".pickle")).unlink()
94+
95+
# H2O MOJO model case
96+
pm.pickle_trained_model(
97+
trained_model=model,
98+
model_prefix=MODEL_PREFIX,
99+
pickle_path=tmp_dir.name,
100+
is_h2o_model=True,
101+
)
102+
assert (Path(tmp_dir.name) / (MODEL_PREFIX + ".mojo")).exists()
103+
model = h2o.import_mojo(tmp_dir.name + "/" + MODEL_PREFIX + ".mojo")
104+
(Path(tmp_dir.name) / (MODEL_PREFIX + ".mojo")).unlink()

0 commit comments

Comments
 (0)