Skip to content

Commit 8eb371b

Browse files
committed
Add single path output test
1 parent fe37daa commit 8eb371b

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

test/test_pathfinder.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
import contextlib
6+
import os
67
from io import StringIO
78
from pathlib import Path
89

@@ -193,3 +194,24 @@ def test_pathfinder_threads() -> None:
193194
)
194195
pathfinder = bern_model.pathfinder(data=jdata, num_threads=4)
195196
assert pathfinder.draws().shape == (1000, 4)
197+
198+
199+
def test_pathfinder_single_path_output() -> None:
200+
201+
stan = DATAFILES_PATH / 'bernoulli.stan'
202+
bern_model = cmdstanpy.CmdStanModel(stan_file=stan)
203+
jdata = str(DATAFILES_PATH / 'bernoulli.data.json')
204+
205+
fit = bern_model.pathfinder(data=jdata, num_paths=4, save_single_paths=True)
206+
assert len(fit.runset.single_path_csv_files) == 4
207+
assert len(fit.runset.single_path_json_files) == 4
208+
209+
assert all(os.path.exists(f) for f in fit.runset.single_path_csv_files)
210+
assert all(os.path.exists(f) for f in fit.runset.single_path_json_files)
211+
212+
fit = bern_model.pathfinder(data=jdata, num_paths=1, save_single_paths=True)
213+
assert len(fit.runset.single_path_csv_files) == 1
214+
assert len(fit.runset.single_path_json_files) == 1
215+
216+
assert all(os.path.exists(f) for f in fit.runset.single_path_csv_files)
217+
assert all(os.path.exists(f) for f in fit.runset.single_path_json_files)

0 commit comments

Comments
 (0)