|
3 | 3 | import os |
4 | 4 |
|
5 | 5 | from cmdstanpy import _TMPDIR |
6 | | -from cmdstanpy.cmdstan_args import CmdStanArgs, SamplerArgs |
| 6 | +from cmdstanpy.cmdstan_args import CmdStanArgs, PathfinderArgs, SamplerArgs |
7 | 7 | from cmdstanpy.stanfit import RunSet |
8 | 8 | from cmdstanpy.utils import EXTENSION |
9 | 9 |
|
@@ -299,3 +299,58 @@ def test_chain_ids() -> None: |
299 | 299 | assert '_11.csv' in runset._csv_files[0] |
300 | 300 | assert 'id=14' in runset.cmd(3) |
301 | 301 | assert '_14.csv' in runset._csv_files[3] |
| 302 | + |
| 303 | + |
| 304 | +def test_output_filenames_pathfinder_single_paths() -> None: |
| 305 | + exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION) |
| 306 | + jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json') |
| 307 | + sampler_args = PathfinderArgs(num_paths=4, save_single_paths=True) |
| 308 | + chain_ids = [1] |
| 309 | + cmdstan_args = CmdStanArgs( |
| 310 | + model_name='bernoulli', |
| 311 | + model_exe=exe, |
| 312 | + chain_ids=chain_ids, |
| 313 | + data=jdata, |
| 314 | + method_args=sampler_args, |
| 315 | + ) |
| 316 | + runset = RunSet(args=cmdstan_args) |
| 317 | + assert len(runset.single_path_csv_files) == 4 |
| 318 | + assert len(runset.single_path_json_files) == 4 |
| 319 | + |
| 320 | + assert all( |
| 321 | + csv_file.endswith(f"_path_{id}.csv") |
| 322 | + for id, csv_file in zip(range(1, 5), runset.single_path_csv_files) |
| 323 | + ) |
| 324 | + assert all( |
| 325 | + json_file.endswith(f"_path_{id}.json") |
| 326 | + for id, json_file in zip(range(1, 5), runset.single_path_json_files) |
| 327 | + ) |
| 328 | + |
| 329 | + sampler_args = PathfinderArgs(num_paths=1, save_single_paths=True) |
| 330 | + cmdstan_args = CmdStanArgs( |
| 331 | + model_name='bernoulli', |
| 332 | + model_exe=exe, |
| 333 | + chain_ids=chain_ids, |
| 334 | + data=jdata, |
| 335 | + method_args=sampler_args, |
| 336 | + ) |
| 337 | + runset = RunSet(args=cmdstan_args) |
| 338 | + |
| 339 | + assert len(runset.single_path_csv_files) == 1 |
| 340 | + assert len(runset.single_path_json_files) == 1 |
| 341 | + |
| 342 | + assert runset.single_path_csv_files[0].endswith(".csv") |
| 343 | + assert runset.single_path_json_files[0].endswith(".json") |
| 344 | + |
| 345 | + sampler_args = PathfinderArgs(num_paths=1, save_single_paths=False) |
| 346 | + cmdstan_args = CmdStanArgs( |
| 347 | + model_name='bernoulli', |
| 348 | + model_exe=exe, |
| 349 | + chain_ids=chain_ids, |
| 350 | + data=jdata, |
| 351 | + method_args=sampler_args, |
| 352 | + ) |
| 353 | + runset = RunSet(args=cmdstan_args) |
| 354 | + |
| 355 | + assert len(runset.single_path_csv_files) == 0 |
| 356 | + assert len(runset.single_path_json_files) == 0 |
0 commit comments