Skip to content

Commit ec15849

Browse files
committed
Ensure stanfit objects can be deepcopied and pickled
1 parent 35b411c commit ec15849

File tree

6 files changed

+71
-0
lines changed

6 files changed

+71
-0
lines changed

cmdstanpy/stanfit/mcmc.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
get_logger,
4444
scan_generated_quantities_csv,
4545
)
46+
4647
from .metadata import InferenceMetadata
4748
from .runset import RunSet
4849

@@ -122,6 +123,8 @@ def __repr__(self) -> str:
122123

123124
def __getattr__(self, attr: str) -> np.ndarray:
124125
"""Synonymous with ``fit.stan_variable(attr)"""
126+
if attr.startswith("_"):
127+
raise AttributeError(f"Unknown variable name {attr}")
125128
try:
126129
return self.stan_variable(attr)
127130
except ValueError as e:
@@ -833,6 +836,8 @@ def __repr__(self) -> str:
833836

834837
def __getattr__(self, attr: str) -> np.ndarray:
835838
"""Synonymous with ``fit.stan_variable(attr)"""
839+
if attr.startswith("_"):
840+
raise AttributeError(f"Unknown variable name {attr}")
836841
try:
837842
return self.stan_variable(attr)
838843
except ValueError as e:

cmdstanpy/stanfit/mle.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ def __repr__(self) -> str:
5252

5353
def __getattr__(self, attr: str) -> Union[np.ndarray, float]:
5454
"""Synonymous with ``fit.stan_variable(attr)"""
55+
if attr.startswith("_"):
56+
raise AttributeError(f"Unknown variable name {attr}")
5557
try:
5658
return self.stan_variable(attr)
5759
except ValueError as e:

cmdstanpy/stanfit/vb.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ def __repr__(self) -> str:
4343

4444
def __getattr__(self, attr: str) -> Union[np.ndarray, float]:
4545
"""Synonymous with ``fit.stan_variable(attr)"""
46+
if attr.startswith("_"):
47+
raise AttributeError(f"Unknown variable name {attr}")
4648
try:
4749
return self.stan_variable(attr)
4850
except ValueError as e:

test/test_optimize.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,28 @@ def test_attrs(self):
627627
with self.assertRaisesRegex(AttributeError, 'Unknown variable name:'):
628628
dummy = fit.c
629629

630+
def test_pickle_ability(self):
631+
"""Ensure fit objects are pickle-able and copy-able"""
632+
633+
import pickle
634+
635+
csvfiles_path = os.path.join(
636+
DATAFILES_PATH, 'optimize', 'rosenbrock_mle.csv'
637+
)
638+
fit = from_csv(path=csvfiles_path)
639+
pickled = pickle.dumps(fit)
640+
unpickled = pickle.loads(pickled)
641+
self.assertSequenceEqual(
642+
fit.stan_variables().keys(), unpickled.stan_variables().keys()
643+
)
644+
645+
import copy
646+
647+
fit2 = copy.deepcopy(fit)
648+
self.assertSequenceEqual(
649+
fit.stan_variables().keys(), fit2.stan_variables().keys()
650+
)
651+
630652

631653
if __name__ == '__main__':
632654
unittest.main()

test/test_sample.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1881,6 +1881,26 @@ def test_diagnostics(self):
18811881
self.assertEqual(fit.max_treedepths, None)
18821882
self.assertEqual(fit.divergences, None)
18831883

1884+
def test_pickle_ability(self):
1885+
"""Ensure fit objects are pickle-able and copy-able"""
1886+
1887+
import pickle
1888+
1889+
csvfiles_path = os.path.join(DATAFILES_PATH, 'lotka-volterra.csv')
1890+
fit = from_csv(path=csvfiles_path)
1891+
pickled = pickle.dumps(fit)
1892+
unpickled = pickle.loads(pickled)
1893+
self.assertSequenceEqual(
1894+
fit.stan_variables().keys(), unpickled.stan_variables().keys()
1895+
)
1896+
1897+
import copy
1898+
1899+
fit2 = copy.deepcopy(fit)
1900+
self.assertSequenceEqual(
1901+
fit.stan_variables().keys(), fit2.stan_variables().keys()
1902+
)
1903+
18841904

18851905
if __name__ == '__main__':
18861906
unittest.main()

test/test_variational.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,26 @@ def test_attrs(self):
284284
with self.assertRaisesRegex(AttributeError, 'Unknown variable name:'):
285285
dummy = fit.c
286286

287+
def test_pickle_ability(self):
288+
"""Ensure fit objects are pickle-able and copy-able"""
289+
290+
import pickle
291+
292+
csvfiles_path = os.path.join(DATAFILES_PATH, 'variational')
293+
fit = from_csv(path=csvfiles_path)
294+
pickled = pickle.dumps(fit)
295+
unpickled = pickle.loads(pickled)
296+
self.assertSequenceEqual(
297+
fit.stan_variables().keys(), unpickled.stan_variables().keys()
298+
)
299+
300+
import copy
301+
302+
fit2 = copy.deepcopy(fit)
303+
self.assertSequenceEqual(
304+
fit.stan_variables().keys(), fit2.stan_variables().keys()
305+
)
306+
287307

288308
if __name__ == '__main__':
289309
unittest.main()

0 commit comments

Comments
 (0)