Skip to content

Commit 69ac3e7

Browse files
committed
Add serialization test for optimize.
1 parent be5c1e9 commit 69ac3e7

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

test/test_optimize.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import io
55
import json
66
import os
7+
import pickle
78
import shutil
89
import unittest
910

@@ -649,6 +650,32 @@ def test_timeout(self):
649650
with self.assertRaises(TimeoutError):
650651
timeout_model.optimize(data={'loop': 1}, timeout=0.1)
651652

653+
def test_serialization(self):
654+
stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
655+
model = CmdStanModel(stan_file=stan)
656+
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
657+
jinit = os.path.join(DATAFILES_PATH, 'bernoulli.init.json')
658+
mle1 = model.optimize(
659+
data=jdata,
660+
seed=1239812093,
661+
inits=jinit,
662+
algorithm='LBFGS',
663+
init_alpha=0.001,
664+
iter=100,
665+
tol_obj=1e-12,
666+
tol_rel_obj=1e4,
667+
tol_grad=1e-8,
668+
tol_rel_grad=1e7,
669+
tol_param=1e-8,
670+
history_size=5,
671+
)
672+
dumped = pickle.dumps(mle1)
673+
shutil.rmtree(mle1.runset._output_dir)
674+
mle2: CmdStanMLE = pickle.loads(dumped)
675+
np.testing.assert_array_equal(
676+
mle1.optimized_params_np, mle2.optimized_params_np
677+
)
678+
652679

653680
if __name__ == '__main__':
654681
unittest.main()

0 commit comments

Comments
 (0)