Skip to content

Commit 8e021ef

Browse files
authored
Merge pull request #583 from stan-dev/fix/pickle-copy-stanfits
Ensure stanfit objects can be deepcopied and pickled
2 parents 6871c4b + fa55feb commit 8e021ef

File tree

5 files changed

+92
-15
lines changed

5 files changed

+92
-15
lines changed

cmdstanpy/stanfit/__init__.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929

3030

3131
def from_csv(
32-
path: Union[str, List[str], None] = None, method: Optional[str] = None
32+
path: Union[str, List[str], os.PathLike, None] = None,
33+
method: Optional[str] = None,
3334
) -> Union[CmdStanMCMC, CmdStanMLE, CmdStanVB, None]:
3435
"""
3536
Instantiate a CmdStan object from a the Stan CSV files from a CmdStan run.
@@ -61,22 +62,22 @@ def from_csv(
6162
csvfiles = []
6263
if isinstance(path, list):
6364
csvfiles = path
64-
elif isinstance(path, str):
65-
if '*' in path:
66-
splits = os.path.split(path)
67-
if splits[0] is not None:
68-
if not (os.path.exists(splits[0]) and os.path.isdir(splits[0])):
69-
raise ValueError(
70-
'Invalid path specification, {} '
71-
' unknown directory: {}'.format(path, splits[0])
72-
)
73-
csvfiles = glob.glob(path)
74-
elif os.path.exists(path) and os.path.isdir(path):
65+
elif isinstance(path, str) and '*' in path:
66+
splits = os.path.split(path)
67+
if splits[0] is not None:
68+
if not (os.path.exists(splits[0]) and os.path.isdir(splits[0])):
69+
raise ValueError(
70+
'Invalid path specification, {} '
71+
' unknown directory: {}'.format(path, splits[0])
72+
)
73+
csvfiles = glob.glob(path)
74+
elif isinstance(path, (str, os.PathLike)):
75+
if os.path.exists(path) and os.path.isdir(path):
7576
for file in os.listdir(path):
76-
if file.endswith(".csv"):
77+
if os.path.splitext(file)[1] == ".csv":
7778
csvfiles.append(os.path.join(path, file))
7879
elif os.path.exists(path):
79-
csvfiles.append(path)
80+
csvfiles.append(str(path))
8081
else:
8182
raise ValueError('Invalid path specification: {}'.format(path))
8283
else:
@@ -85,7 +86,7 @@ def from_csv(
8586
if len(csvfiles) == 0:
8687
raise ValueError('No CSV files found in directory {}'.format(path))
8788
for file in csvfiles:
88-
if not (os.path.exists(file) and file.endswith('.csv')):
89+
if not (os.path.exists(file) and os.path.splitext(file)[1] == ".csv"):
8990
raise ValueError(
9091
'Bad CSV file path spec,'
9192
' includes non-csv file: {}'.format(file)

cmdstanpy/stanfit/mcmc.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
get_logger,
4444
scan_generated_quantities_csv,
4545
)
46+
4647
from .metadata import InferenceMetadata
4748
from .runset import RunSet
4849

@@ -122,6 +123,8 @@ def __repr__(self) -> str:
122123

123124
def __getattr__(self, attr: str) -> np.ndarray:
124125
"""Synonymous with ``fit.stan_variable(attr)"""
126+
if attr.startswith("_"):
127+
raise AttributeError(f"Unknown variable name {attr}")
125128
try:
126129
return self.stan_variable(attr)
127130
except ValueError as e:
@@ -833,6 +836,8 @@ def __repr__(self) -> str:
833836

834837
def __getattr__(self, attr: str) -> np.ndarray:
835838
"""Synonymous with ``fit.stan_variable(attr)"""
839+
if attr.startswith("_"):
840+
raise AttributeError(f"Unknown variable name {attr}")
836841
try:
837842
return self.stan_variable(attr)
838843
except ValueError as e:

cmdstanpy/stanfit/mle.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ def __repr__(self) -> str:
5252

5353
def __getattr__(self, attr: str) -> Union[np.ndarray, float]:
5454
"""Synonymous with ``fit.stan_variable(attr)"""
55+
if attr.startswith("_"):
56+
raise AttributeError(f"Unknown variable name {attr}")
5557
try:
5658
return self.stan_variable(attr)
5759
except ValueError as e:

cmdstanpy/stanfit/vb.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ def __repr__(self) -> str:
4343

4444
def __getattr__(self, attr: str) -> Union[np.ndarray, float]:
4545
"""Synonymous with ``fit.stan_variable(attr)"""
46+
if attr.startswith("_"):
47+
raise AttributeError(f"Unknown variable name {attr}")
4648
try:
4749
return self.stan_variable(attr)
4850
except ValueError as e:

test/test_compliance.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""Testing for things like pickleability, deep copying"""
2+
3+
import copy
4+
import pathlib
5+
import pickle
6+
import unittest
7+
8+
import cmdstanpy
9+
10+
DATAFILES_PATH = pathlib.Path(__file__).parent.resolve() / 'data'
11+
12+
13+
class SampleCompliance(unittest.TestCase):
14+
def test_sample_pickle_ability(self):
15+
csvfiles_path = DATAFILES_PATH / 'lotka-volterra.csv'
16+
fit = cmdstanpy.from_csv(path=csvfiles_path)
17+
keys = fit.stan_variables().keys()
18+
pickled = pickle.dumps(fit)
19+
del fit
20+
unpickled = pickle.loads(pickled)
21+
self.assertSequenceEqual(keys, unpickled.stan_variables().keys())
22+
23+
def test_sample_copy_ability(self):
24+
csvfiles_path = DATAFILES_PATH / 'lotka-volterra.csv'
25+
fit = cmdstanpy.from_csv(path=csvfiles_path)
26+
fit2 = copy.deepcopy(fit)
27+
self.assertSequenceEqual(
28+
fit.stan_variables().keys(), fit2.stan_variables().keys()
29+
)
30+
31+
32+
class OptimizeCompliance(unittest.TestCase):
33+
def test_optimize_pickle_ability(self):
34+
csvfiles_path = DATAFILES_PATH / 'optimize' / 'rosenbrock_mle.csv'
35+
fit = cmdstanpy.from_csv(path=csvfiles_path)
36+
keys = fit.stan_variables().keys()
37+
pickled = pickle.dumps(fit)
38+
del fit
39+
unpickled = pickle.loads(pickled)
40+
self.assertSequenceEqual(keys, unpickled.stan_variables().keys())
41+
42+
def test_optimize_copy_ability(self):
43+
csvfiles_path = DATAFILES_PATH / 'optimize' / 'rosenbrock_mle.csv'
44+
fit = cmdstanpy.from_csv(path=csvfiles_path)
45+
fit2 = copy.deepcopy(fit)
46+
self.assertSequenceEqual(
47+
fit.stan_variables().keys(), fit2.stan_variables().keys()
48+
)
49+
50+
51+
class VariationalCompliance(unittest.TestCase):
52+
def test_variational_pickle_ability(self):
53+
csvfiles_path = DATAFILES_PATH / 'variational'
54+
fit = cmdstanpy.from_csv(path=csvfiles_path)
55+
keys = fit.stan_variables().keys()
56+
pickled = pickle.dumps(fit)
57+
del fit
58+
unpickled = pickle.loads(pickled)
59+
self.assertSequenceEqual(keys, unpickled.stan_variables().keys())
60+
61+
def test_variational_copy_ability(self):
62+
csvfiles_path = DATAFILES_PATH / 'variational'
63+
fit = cmdstanpy.from_csv(path=csvfiles_path)
64+
fit2 = copy.deepcopy(fit)
65+
self.assertSequenceEqual(
66+
fit.stan_variables().keys(), fit2.stan_variables().keys()
67+
)

0 commit comments

Comments
 (0)