Skip to content

Commit 45d2163

Browse files
committed
Add RunSet tests for single path Pathfinder files
1 parent 4169cee commit 45d2163

File tree

1 file changed

+56
-1
lines changed

1 file changed

+56
-1
lines changed

test/test_runset.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44

55
from cmdstanpy import _TMPDIR
6-
from cmdstanpy.cmdstan_args import CmdStanArgs, SamplerArgs
6+
from cmdstanpy.cmdstan_args import CmdStanArgs, PathfinderArgs, SamplerArgs
77
from cmdstanpy.stanfit import RunSet
88
from cmdstanpy.utils import EXTENSION
99

@@ -299,3 +299,58 @@ def test_chain_ids() -> None:
299299
assert '_11.csv' in runset._csv_files[0]
300300
assert 'id=14' in runset.cmd(3)
301301
assert '_14.csv' in runset._csv_files[3]
302+
303+
304+
def test_output_filenames_pathfinder_single_paths() -> None:
305+
exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
306+
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
307+
sampler_args = PathfinderArgs(num_paths=4, save_single_paths=True)
308+
chain_ids = [1]
309+
cmdstan_args = CmdStanArgs(
310+
model_name='bernoulli',
311+
model_exe=exe,
312+
chain_ids=chain_ids,
313+
data=jdata,
314+
method_args=sampler_args,
315+
)
316+
runset = RunSet(args=cmdstan_args)
317+
assert len(runset.single_path_csv_files) == 4
318+
assert len(runset.single_path_json_files) == 4
319+
320+
assert all(
321+
csv_file.endswith(f"_path_{id}.csv")
322+
for id, csv_file in zip(range(1, 5), runset.single_path_csv_files)
323+
)
324+
assert all(
325+
json_file.endswith(f"_path_{id}.json")
326+
for id, json_file in zip(range(1, 5), runset.single_path_json_files)
327+
)
328+
329+
sampler_args = PathfinderArgs(num_paths=1, save_single_paths=True)
330+
cmdstan_args = CmdStanArgs(
331+
model_name='bernoulli',
332+
model_exe=exe,
333+
chain_ids=chain_ids,
334+
data=jdata,
335+
method_args=sampler_args,
336+
)
337+
runset = RunSet(args=cmdstan_args)
338+
339+
assert len(runset.single_path_csv_files) == 1
340+
assert len(runset.single_path_json_files) == 1
341+
342+
assert runset.single_path_csv_files[0].endswith(".csv")
343+
assert runset.single_path_json_files[0].endswith(".json")
344+
345+
sampler_args = PathfinderArgs(num_paths=1, save_single_paths=False)
346+
cmdstan_args = CmdStanArgs(
347+
model_name='bernoulli',
348+
model_exe=exe,
349+
chain_ids=chain_ids,
350+
data=jdata,
351+
method_args=sampler_args,
352+
)
353+
runset = RunSet(args=cmdstan_args)
354+
355+
assert len(runset.single_path_csv_files) == 0
356+
assert len(runset.single_path_json_files) == 0

0 commit comments

Comments
 (0)