Skip to content

Commit f8b8610

Browse files
committed
Add serialization of samples.
1 parent 061c438 commit f8b8610

File tree

2 files changed

+34
-10
lines changed

2 files changed

+34
-10
lines changed

cmdstanpy/stanfit/mcmc.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,10 @@ def __getattr__(self, attr: str) -> np.ndarray:
129129
# pylint: disable=raise-missing-from
130130
raise AttributeError(*e.args)
131131

132+
def __getstate__(self) -> dict:
133+
self._assemble_draws()
134+
return self.__dict__
135+
132136
@property
133137
def chains(self) -> int:
134138
"""Number of chains."""
@@ -259,8 +263,7 @@ def draws(
259263
CmdStanMCMC.draws_xr
260264
CmdStanGQ.draws
261265
"""
262-
if self._draws.shape == (0,):
263-
self._assemble_draws()
266+
self._assemble_draws()
264267

265268
if inc_warmup and not self._save_warmup:
266269
get_logger().warning(
@@ -591,8 +594,7 @@ def draws_pd(
591594
' must run sampler with "save_warmup=True".'
592595
)
593596

594-
if self._draws.shape == (0,):
595-
self._assemble_draws()
597+
self._assemble_draws()
596598
cols = []
597599
if vars is not None:
598600
for var in dict.fromkeys(vars_list):
@@ -648,8 +650,7 @@ def draws_xr(
648650
else:
649651
vars_list = vars
650652

651-
if self._draws.shape == (0,):
652-
self._assemble_draws()
653+
self._assemble_draws()
653654

654655
num_draws = self.num_draws_sampling
655656
meta = self._metadata.cmdstan_config
@@ -735,8 +736,7 @@ def stan_variable(
735736
'Available variables are '
736737
+ ", ".join(self._metadata.stan_vars_dims)
737738
)
738-
if self._draws.shape == (0,):
739-
self._assemble_draws()
739+
self._assemble_draws()
740740
draw1 = 0
741741
if not inc_warmup and self._save_warmup:
742742
draw1 = self.num_draws_warmup
@@ -783,8 +783,7 @@ def method_variables(self) -> Dict[str, np.ndarray]:
783783
containing per-draw diagnostic values.
784784
"""
785785
result = {}
786-
if self._draws.shape == (0,):
787-
self._assemble_draws()
786+
self._assemble_draws()
788787
for idxs in self.metadata.method_vars_cols.values():
789788
for idx in idxs:
790789
result[self.column_names[idx]] = self._draws[:, :, idx]

test/test_sample.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import tempfile
1212
import unittest
1313
from multiprocessing import cpu_count
14+
import pickle
1415
from test import CustomTestCase
1516
from time import time
1617

@@ -1928,6 +1929,30 @@ def test_json_edges(self):
19281929
self.assertTrue(np.isnan(fit.stan_variable("nan_out")[0]))
19291930
self.assertTrue(np.isinf(fit.stan_variable("inf_out")[0]))
19301931

1932+
def test_serialization(self, stanfile='bernoulli.stan'):
1933+
stan = os.path.join(DATAFILES_PATH, stanfile)
1934+
bern_model = CmdStanModel(stan_file=stan)
1935+
1936+
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
1937+
bern_fit1 = bern_model.sample(
1938+
data=jdata,
1939+
chains=1,
1940+
iter_warmup=200,
1941+
iter_sampling=100,
1942+
show_progress=False,
1943+
)
1944+
# Dump the result (which assembles draws) and delete the source files.
1945+
dumped = pickle.dumps(bern_fit1)
1946+
for filename in bern_fit1.runset.csv_files:
1947+
os.unlink(filename)
1948+
# Load the serialized result and compare results.
1949+
bern_fit2: CmdStanMCMC = pickle.loads(dumped)
1950+
variables1 = bern_fit1.stan_variables()
1951+
variables2 = bern_fit2.stan_variables()
1952+
self.assertEqual(set(variables1), set(variables2))
1953+
for key, value1 in variables1.items():
1954+
np.testing.assert_array_equal(value1, variables2[key])
1955+
19311956

19321957
if __name__ == '__main__':
19331958
unittest.main()

0 commit comments

Comments
 (0)