Skip to content

Commit b382cc3

Browse files
authored
Merge pull request #632 from tillahoffmann/getstate
Add serialization of samples.
2 parents 061c438 + 92a83f2 commit b382cc3

File tree

6 files changed

+121
-18
lines changed

6 files changed

+121
-18
lines changed

cmdstanpy/stanfit/gq.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ def __getattr__(self, attr: str) -> np.ndarray:
8585
# pylint: disable=raise-missing-from
8686
raise AttributeError(*e.args)
8787

88+
def __getstate__(self) -> dict:
89+
self._assemble_generated_quantities()
90+
return self.__dict__
91+
8892
def _validate_csv_files(self) -> Dict[str, Any]:
8993
"""
9094
Checks that Stan CSV output files for all chains are consistent
@@ -189,8 +193,7 @@ def draws(
189193
CmdStanGQ.draws_xr
190194
CmdStanMCMC.draws
191195
"""
192-
if self._draws.shape == (0,):
193-
self._assemble_generated_quantities()
196+
self._assemble_generated_quantities()
194197
if (
195198
inc_warmup
196199
and not self.mcmc_sample.metadata.cmdstan_config['save_warmup']
@@ -277,8 +280,7 @@ def draws_pd(
277280
'Draws from warmup iterations not available,'
278281
' must run sampler with "save_warmup=True".'
279282
)
280-
if self._draws.shape == (0,):
281-
self._assemble_generated_quantities()
283+
self._assemble_generated_quantities()
282284

283285
gq_cols = []
284286
mcmc_vars = []
@@ -400,8 +402,7 @@ def draws_xr(
400402
for var in dup_vars:
401403
vars_list.remove(var)
402404

403-
if self._draws.shape == (0,):
404-
self._assemble_generated_quantities()
405+
self._assemble_generated_quantities()
405406

406407
num_draws = self.mcmc_sample.num_draws_sampling
407408
sample_config = self.mcmc_sample.metadata.cmdstan_config
@@ -505,8 +506,7 @@ def stan_variable(
505506
if var not in gq_var_names:
506507
return self.mcmc_sample.stan_variable(var, inc_warmup=inc_warmup)
507508
else: # is gq variable
508-
if self._draws.shape == (0,):
509-
self._assemble_generated_quantities()
509+
self._assemble_generated_quantities()
510510
draw1 = 0
511511
if (
512512
not inc_warmup
@@ -561,6 +561,8 @@ def stan_variables(self, inc_warmup: bool = False) -> Dict[str, np.ndarray]:
561561
return result
562562

563563
def _assemble_generated_quantities(self) -> None:
564+
if self._draws.shape != (0,):
565+
return
564566
# use numpy loadtxt
565567
warmup = self.mcmc_sample.metadata.cmdstan_config['save_warmup']
566568
num_draws = self.mcmc_sample.draws(inc_warmup=warmup).shape[0]

cmdstanpy/stanfit/mcmc.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,14 @@ def __getattr__(self, attr: str) -> np.ndarray:
129129
# pylint: disable=raise-missing-from
130130
raise AttributeError(*e.args)
131131

132+
def __getstate__(self) -> dict:
133+
# This function returns the mapping of objects to serialize with pickle.
134+
# See https://docs.python.org/3/library/pickle.html#object.__getstate__
135+
# for details. We call _assemble_draws to ensure posterior samples have
136+
# been loaded prior to serialization.
137+
self._assemble_draws()
138+
return self.__dict__
139+
132140
@property
133141
def chains(self) -> int:
134142
"""Number of chains."""
@@ -259,8 +267,7 @@ def draws(
259267
CmdStanMCMC.draws_xr
260268
CmdStanGQ.draws
261269
"""
262-
if self._draws.shape == (0,):
263-
self._assemble_draws()
270+
self._assemble_draws()
264271

265272
if inc_warmup and not self._save_warmup:
266273
get_logger().warning(
@@ -591,8 +598,7 @@ def draws_pd(
591598
' must run sampler with "save_warmup=True".'
592599
)
593600

594-
if self._draws.shape == (0,):
595-
self._assemble_draws()
601+
self._assemble_draws()
596602
cols = []
597603
if vars is not None:
598604
for var in dict.fromkeys(vars_list):
@@ -648,8 +654,7 @@ def draws_xr(
648654
else:
649655
vars_list = vars
650656

651-
if self._draws.shape == (0,):
652-
self._assemble_draws()
657+
self._assemble_draws()
653658

654659
num_draws = self.num_draws_sampling
655660
meta = self._metadata.cmdstan_config
@@ -735,8 +740,7 @@ def stan_variable(
735740
'Available variables are '
736741
+ ", ".join(self._metadata.stan_vars_dims)
737742
)
738-
if self._draws.shape == (0,):
739-
self._assemble_draws()
743+
self._assemble_draws()
740744
draw1 = 0
741745
if not inc_warmup and self._save_warmup:
742746
draw1 = self.num_draws_warmup
@@ -783,8 +787,7 @@ def method_variables(self) -> Dict[str, np.ndarray]:
783787
containing per-draw diagnostic values.
784788
"""
785789
result = {}
786-
if self._draws.shape == (0,):
787-
self._assemble_draws()
790+
self._assemble_draws()
788791
for idxs in self.metadata.method_vars_cols.values():
789792
for idx in idxs:
790793
result[self.column_names[idx]] = self._draws[:, :, idx]

test/test_generate_quantities.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,19 @@
55
import json
66
import logging
77
import os
8+
import pickle
9+
import shutil
810
import unittest
911
from test import CustomTestCase
1012

1113
import numpy as np
1214
import pandas as pd
15+
import pytest
1316
from numpy.testing import assert_array_equal, assert_raises
1417
from testfixtures import LogCapture
1518

1619
import cmdstanpy.stanfit
20+
from cmdstanpy.stanfit import CmdStanGQ
1721
from cmdstanpy.cmdstan_args import Method
1822
from cmdstanpy.model import CmdStanModel
1923

@@ -485,6 +489,26 @@ def test_timeout(self):
485489
timeout=0.1, mcmc_sample=fit, data={'loop': 1}
486490
)
487491

492+
@pytest.mark.order(before="test_no_xarray")
493+
def test_serialization(self):
494+
stan_bern = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
495+
model_bern = CmdStanModel(stan_file=stan_bern)
496+
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
497+
fit_sampling = model_bern.sample(chains=1, iter_sampling=10, data=jdata)
498+
499+
stan = os.path.join(DATAFILES_PATH, 'named_output.stan')
500+
model = CmdStanModel(stan_file=stan)
501+
fit1 = model.generate_quantities(data=jdata, mcmc_sample=fit_sampling)
502+
503+
dumped = pickle.dumps(fit1)
504+
shutil.rmtree(fit1.runset._output_dir)
505+
fit2: CmdStanGQ = pickle.loads(dumped)
506+
variables1 = fit1.stan_variables()
507+
variables2 = fit2.stan_variables()
508+
self.assertEqual(set(variables1), set(variables2))
509+
for key, value1 in variables1.items():
510+
np.testing.assert_array_equal(value1, variables2[key])
511+
488512

489513
if __name__ == '__main__':
490514
unittest.main()

test/test_optimize.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import io
55
import json
66
import os
7+
import pickle
78
import shutil
89
import unittest
910

@@ -649,6 +650,32 @@ def test_timeout(self):
649650
with self.assertRaises(TimeoutError):
650651
timeout_model.optimize(data={'loop': 1}, timeout=0.1)
651652

653+
def test_serialization(self):
654+
stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
655+
model = CmdStanModel(stan_file=stan)
656+
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
657+
jinit = os.path.join(DATAFILES_PATH, 'bernoulli.init.json')
658+
mle1 = model.optimize(
659+
data=jdata,
660+
seed=1239812093,
661+
inits=jinit,
662+
algorithm='LBFGS',
663+
init_alpha=0.001,
664+
iter=100,
665+
tol_obj=1e-12,
666+
tol_rel_obj=1e4,
667+
tol_grad=1e-8,
668+
tol_rel_grad=1e7,
669+
tol_param=1e-8,
670+
history_size=5,
671+
)
672+
dumped = pickle.dumps(mle1)
673+
shutil.rmtree(mle1.runset._output_dir)
674+
mle2: CmdStanMLE = pickle.loads(dumped)
675+
np.testing.assert_array_equal(
676+
mle1.optimized_params_np, mle2.optimized_params_np
677+
)
678+
652679

653680
if __name__ == '__main__':
654681
unittest.main()

test/test_sample.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@
1111
import tempfile
1212
import unittest
1313
from multiprocessing import cpu_count
14+
import pickle
1415
from test import CustomTestCase
1516
from time import time
1617

1718
import numpy as np
19+
import pytest
1820
from testfixtures import LogCapture, StringComparison
1921

2022
import cmdstanpy.stanfit
@@ -1928,6 +1930,33 @@ def test_json_edges(self):
19281930
self.assertTrue(np.isnan(fit.stan_variable("nan_out")[0]))
19291931
self.assertTrue(np.isinf(fit.stan_variable("inf_out")[0]))
19301932

1933+
@pytest.mark.order(before="test_no_xarray")
1934+
def test_serialization(self, stanfile='bernoulli.stan'):
1935+
# This test must run before any test that uses the `without_import` context
1936+
# manager because the latter uses `reload` with side effects that affect
1937+
# the consistency of classes.
1938+
stan = os.path.join(DATAFILES_PATH, stanfile)
1939+
bern_model = CmdStanModel(stan_file=stan)
1940+
1941+
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
1942+
bern_fit1 = bern_model.sample(
1943+
data=jdata,
1944+
chains=1,
1945+
iter_warmup=200,
1946+
iter_sampling=100,
1947+
show_progress=False,
1948+
)
1949+
# Dump the result (which assembles draws) and delete the source files.
1950+
dumped = pickle.dumps(bern_fit1)
1951+
shutil.rmtree(bern_fit1.runset._output_dir)
1952+
# Load the serialized result and compare results.
1953+
bern_fit2: CmdStanMCMC = pickle.loads(dumped)
1954+
variables1 = bern_fit1.stan_variables()
1955+
variables2 = bern_fit2.stan_variables()
1956+
self.assertEqual(set(variables1), set(variables2))
1957+
for key, value1 in variables1.items():
1958+
np.testing.assert_array_equal(value1, variables2[key])
1959+
19311960

19321961
if __name__ == '__main__':
19331962
unittest.main()

test/test_variational.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import contextlib
44
import io
55
import os
6+
import pickle
67
import shutil
78
import unittest
89
from math import fabs
@@ -297,6 +298,23 @@ def test_timeout(self):
297298
with self.assertRaises(TimeoutError):
298299
timeout_model.variational(timeout=0.1, data={'loop': 1})
299300

301+
def test_serialization(self):
302+
stan = os.path.join(
303+
DATAFILES_PATH, 'variational', 'eta_should_be_big.stan'
304+
)
305+
model = CmdStanModel(stan_file=stan)
306+
variational1 = model.variational(algorithm='meanfield', seed=999999)
307+
dumped = pickle.dumps(variational1)
308+
shutil.rmtree(variational1.runset._output_dir)
309+
variational2: CmdStanVB = pickle.loads(dumped)
310+
np.testing.assert_array_equal(
311+
variational1.variational_sample, variational2.variational_sample
312+
)
313+
self.assertEqual(
314+
variational1.variational_params_dict,
315+
variational2.variational_params_dict,
316+
)
317+
300318

301319
if __name__ == '__main__':
302320
unittest.main()

0 commit comments

Comments
 (0)