File tree Expand file tree Collapse file tree 1 file changed +27
-0
lines changed
Expand file tree Collapse file tree 1 file changed +27
-0
lines changed Original file line number Diff line number Diff line change 44import io
55import json
66import os
7+ import pickle
78import shutil
89import unittest
910
@@ -649,6 +650,32 @@ def test_timeout(self):
649650 with self .assertRaises (TimeoutError ):
650651 timeout_model .optimize (data = {'loop' : 1 }, timeout = 0.1 )
651652
653+ def test_serialization (self ):
654+ stan = os .path .join (DATAFILES_PATH , 'bernoulli.stan' )
655+ model = CmdStanModel (stan_file = stan )
656+ jdata = os .path .join (DATAFILES_PATH , 'bernoulli.data.json' )
657+ jinit = os .path .join (DATAFILES_PATH , 'bernoulli.init.json' )
658+ mle1 = model .optimize (
659+ data = jdata ,
660+ seed = 1239812093 ,
661+ inits = jinit ,
662+ algorithm = 'LBFGS' ,
663+ init_alpha = 0.001 ,
664+ iter = 100 ,
665+ tol_obj = 1e-12 ,
666+ tol_rel_obj = 1e4 ,
667+ tol_grad = 1e-8 ,
668+ tol_rel_grad = 1e7 ,
669+ tol_param = 1e-8 ,
670+ history_size = 5 ,
671+ )
672+ dumped = pickle .dumps (mle1 )
673+ shutil .rmtree (mle1 .runset ._output_dir )
674+ mle2 : CmdStanMLE = pickle .loads (dumped )
675+ np .testing .assert_array_equal (
676+ mle1 .optimized_params_np , mle2 .optimized_params_np
677+ )
678+
652679
653680if __name__ == '__main__' :
654681 unittest .main ()
You can’t perform that action at this time.
0 commit comments