Skip to content

Commit 2eccbda

Browse files
committed
Add serialization test for variational.
1 parent 69ac3e7 commit 2eccbda

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

test/test_variational.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import contextlib
44
import io
55
import os
6+
import pickle
67
import shutil
78
import unittest
89
from math import fabs
@@ -297,6 +298,23 @@ def test_timeout(self):
297298
with self.assertRaises(TimeoutError):
298299
timeout_model.variational(timeout=0.1, data={'loop': 1})
299300

301+
def test_serialization(self):
302+
stan = os.path.join(
303+
DATAFILES_PATH, 'variational', 'eta_should_be_big.stan'
304+
)
305+
model = CmdStanModel(stan_file=stan)
306+
variational1 = model.variational(algorithm='meanfield', seed=999999)
307+
dumped = pickle.dumps(variational1)
308+
shutil.rmtree(variational1.runset._output_dir)
309+
variational2: CmdStanVB = pickle.loads(dumped)
310+
np.testing.assert_array_equal(
311+
variational1.variational_sample, variational2.variational_sample
312+
)
313+
self.assertEqual(
314+
variational1.variational_params_dict,
315+
variational2.variational_params_dict,
316+
)
317+
300318

301319
if __name__ == '__main__':
302320
unittest.main()

0 commit comments

Comments
 (0)