File tree Expand file tree Collapse file tree 1 file changed +18
-0
lines changed
Expand file tree Collapse file tree 1 file changed +18
-0
lines changed Original file line number Diff line number Diff line change 33import contextlib
44import io
55import os
6+ import pickle
67import shutil
78import unittest
89from math import fabs
@@ -297,6 +298,23 @@ def test_timeout(self):
297298 with self .assertRaises (TimeoutError ):
298299 timeout_model .variational (timeout = 0.1 , data = {'loop' : 1 })
299300
301+ def test_serialization (self ):
302+ stan = os .path .join (
303+ DATAFILES_PATH , 'variational' , 'eta_should_be_big.stan'
304+ )
305+ model = CmdStanModel (stan_file = stan )
306+ variational1 = model .variational (algorithm = 'meanfield' , seed = 999999 )
307+ dumped = pickle .dumps (variational1 )
308+ shutil .rmtree (variational1 .runset ._output_dir )
309+ variational2 : CmdStanVB = pickle .loads (dumped )
310+ np .testing .assert_array_equal (
311+ variational1 .variational_sample , variational2 .variational_sample
312+ )
313+ self .assertEqual (
314+ variational1 .variational_params_dict ,
315+ variational2 .variational_params_dict ,
316+ )
317+
300318
301319if __name__ == '__main__' :
302320 unittest .main ()
You can’t perform that action at this time.
0 commit comments