Skip to content

Commit 9c671a0

Browse files
committed
Reorganize new tests
1 parent 5d62bb5 commit 9c671a0

File tree

5 files changed

+83
-77
lines changed

5 files changed

+83
-77
lines changed

cmdstanpy/stanfit/__init__.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929

3030

3131
def from_csv(
32-
path: Union[str, List[str], None] = None, method: Optional[str] = None
32+
path: Union[str, List[str], os.PathLike, None] = None,
33+
method: Optional[str] = None,
3334
) -> Union[CmdStanMCMC, CmdStanMLE, CmdStanVB, None]:
3435
"""
3536
Instantiate a CmdStan object from a the Stan CSV files from a CmdStan run.
@@ -61,22 +62,22 @@ def from_csv(
6162
csvfiles = []
6263
if isinstance(path, list):
6364
csvfiles = path
64-
elif isinstance(path, str):
65-
if '*' in path:
66-
splits = os.path.split(path)
67-
if splits[0] is not None:
68-
if not (os.path.exists(splits[0]) and os.path.isdir(splits[0])):
69-
raise ValueError(
70-
'Invalid path specification, {} '
71-
' unknown directory: {}'.format(path, splits[0])
72-
)
73-
csvfiles = glob.glob(path)
74-
elif os.path.exists(path) and os.path.isdir(path):
65+
elif isinstance(path, str) and '*' in path:
66+
splits = os.path.split(path)
67+
if splits[0] is not None:
68+
if not (os.path.exists(splits[0]) and os.path.isdir(splits[0])):
69+
raise ValueError(
70+
'Invalid path specification, {} '
71+
' unknown directory: {}'.format(path, splits[0])
72+
)
73+
csvfiles = glob.glob(path)
74+
elif isinstance(path, (str, os.PathLike)):
75+
if os.path.exists(path) and os.path.isdir(path):
7576
for file in os.listdir(path):
76-
if file.endswith(".csv"):
77+
if os.path.splitext(path)[1] == ".csv":
7778
csvfiles.append(os.path.join(path, file))
7879
elif os.path.exists(path):
79-
csvfiles.append(path)
80+
csvfiles.append(str(path))
8081
else:
8182
raise ValueError('Invalid path specification: {}'.format(path))
8283
else:
@@ -85,7 +86,7 @@ def from_csv(
8586
if len(csvfiles) == 0:
8687
raise ValueError('No CSV files found in directory {}'.format(path))
8788
for file in csvfiles:
88-
if not (os.path.exists(file) and file.endswith('.csv')):
89+
if not (os.path.exists(file) and os.path.splitext(file)[1] == ".csv"):
8990
raise ValueError(
9091
'Bad CSV file path spec,'
9192
' includes non-csv file: {}'.format(file)

test/test_compliance.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""Testing for things like pickleability, deep copying"""
2+
3+
import copy
4+
import pathlib
5+
import pickle
6+
import unittest
7+
8+
import cmdstanpy
9+
10+
DATAFILES_PATH = pathlib.Path(__file__).parent.resolve() / 'data'
11+
12+
13+
class SampleCompliance(unittest.TestCase):
14+
def test_sample_pickle_ability(self):
15+
csvfiles_path = DATAFILES_PATH / 'lotka-volterra.csv'
16+
fit = cmdstanpy.from_csv(path=csvfiles_path)
17+
pickled = pickle.dumps(fit)
18+
unpickled = pickle.loads(pickled)
19+
self.assertSequenceEqual(
20+
fit.stan_variables().keys(), unpickled.stan_variables().keys()
21+
)
22+
23+
def test_sample_copy_ability(self):
24+
csvfiles_path = DATAFILES_PATH / 'lotka-volterra.csv'
25+
fit = cmdstanpy.from_csv(path=csvfiles_path)
26+
fit2 = copy.deepcopy(fit)
27+
self.assertSequenceEqual(
28+
fit.stan_variables().keys(), fit2.stan_variables().keys()
29+
)
30+
31+
32+
class OptimizeCompliance(unittest.TestCase):
33+
def test_optimize_pickle_ability(self):
34+
csvfiles_path = DATAFILES_PATH / 'optimize' / 'rosenbrock_mle.csv'
35+
fit = cmdstanpy.from_csv(path=csvfiles_path)
36+
pickled = pickle.dumps(fit)
37+
unpickled = pickle.loads(pickled)
38+
self.assertSequenceEqual(
39+
fit.stan_variables().keys(), unpickled.stan_variables().keys()
40+
)
41+
42+
def test_optimize_copy_ability(self):
43+
csvfiles_path = DATAFILES_PATH / 'optimize' / 'rosenbrock_mle.csv'
44+
fit = cmdstanpy.from_csv(path=csvfiles_path)
45+
fit2 = copy.deepcopy(fit)
46+
self.assertSequenceEqual(
47+
fit.stan_variables().keys(), fit2.stan_variables().keys()
48+
)
49+
50+
51+
class VariationalCompliance(unittest.TestCase):
52+
def test_variational_pickle_ability(self):
53+
csvfiles_path = DATAFILES_PATH / 'variational' / 'eta_big_output.csv'
54+
fit = cmdstanpy.from_csv(path=csvfiles_path)
55+
pickled = pickle.dumps(fit)
56+
unpickled = pickle.loads(pickled)
57+
self.assertSequenceEqual(
58+
fit.stan_variables().keys(), unpickled.stan_variables().keys()
59+
)
60+
61+
def test_variational_copy_ability(self):
62+
csvfiles_path = DATAFILES_PATH / 'variational' / 'eta_big_output.csv'
63+
fit = cmdstanpy.from_csv(path=csvfiles_path)
64+
fit2 = copy.deepcopy(fit)
65+
self.assertSequenceEqual(
66+
fit.stan_variables().keys(), fit2.stan_variables().keys()
67+
)

test/test_optimize.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -627,28 +627,6 @@ def test_attrs(self):
627627
with self.assertRaisesRegex(AttributeError, 'Unknown variable name:'):
628628
dummy = fit.c
629629

630-
def test_pickle_ability(self):
631-
"""Ensure fit objects are pickle-able and copy-able"""
632-
633-
import pickle
634-
635-
csvfiles_path = os.path.join(
636-
DATAFILES_PATH, 'optimize', 'rosenbrock_mle.csv'
637-
)
638-
fit = from_csv(path=csvfiles_path)
639-
pickled = pickle.dumps(fit)
640-
unpickled = pickle.loads(pickled)
641-
self.assertSequenceEqual(
642-
fit.stan_variables().keys(), unpickled.stan_variables().keys()
643-
)
644-
645-
import copy
646-
647-
fit2 = copy.deepcopy(fit)
648-
self.assertSequenceEqual(
649-
fit.stan_variables().keys(), fit2.stan_variables().keys()
650-
)
651-
652630

653631
if __name__ == '__main__':
654632
unittest.main()

test/test_sample.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1881,26 +1881,6 @@ def test_diagnostics(self):
18811881
self.assertEqual(fit.max_treedepths, None)
18821882
self.assertEqual(fit.divergences, None)
18831883

1884-
def test_pickle_ability(self):
1885-
"""Ensure fit objects are pickle-able and copy-able"""
1886-
1887-
import pickle
1888-
1889-
csvfiles_path = os.path.join(DATAFILES_PATH, 'lotka-volterra.csv')
1890-
fit = from_csv(path=csvfiles_path)
1891-
pickled = pickle.dumps(fit)
1892-
unpickled = pickle.loads(pickled)
1893-
self.assertSequenceEqual(
1894-
fit.stan_variables().keys(), unpickled.stan_variables().keys()
1895-
)
1896-
1897-
import copy
1898-
1899-
fit2 = copy.deepcopy(fit)
1900-
self.assertSequenceEqual(
1901-
fit.stan_variables().keys(), fit2.stan_variables().keys()
1902-
)
1903-
19041884

19051885
if __name__ == '__main__':
19061886
unittest.main()

test/test_variational.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -284,26 +284,6 @@ def test_attrs(self):
284284
with self.assertRaisesRegex(AttributeError, 'Unknown variable name:'):
285285
dummy = fit.c
286286

287-
def test_pickle_ability(self):
288-
"""Ensure fit objects are pickle-able and copy-able"""
289-
290-
import pickle
291-
292-
csvfiles_path = os.path.join(DATAFILES_PATH, 'variational')
293-
fit = from_csv(path=csvfiles_path)
294-
pickled = pickle.dumps(fit)
295-
unpickled = pickle.loads(pickled)
296-
self.assertSequenceEqual(
297-
fit.stan_variables().keys(), unpickled.stan_variables().keys()
298-
)
299-
300-
import copy
301-
302-
fit2 = copy.deepcopy(fit)
303-
self.assertSequenceEqual(
304-
fit.stan_variables().keys(), fit2.stan_variables().keys()
305-
)
306-
307287

308288
if __name__ == '__main__':
309289
unittest.main()

0 commit comments

Comments
 (0)