|
11 | 11 | from time import time |
12 | 12 |
|
13 | 13 | from cmdstanpy import _TMPDIR |
14 | | -from cmdstanpy.cmdstan_args import CmdStanArgs, Method |
| 14 | +from cmdstanpy.cmdstan_args import CmdStanArgs, Method, PathfinderArgs |
15 | 15 | from cmdstanpy.utils import get_logger |
16 | 16 |
|
17 | 17 |
|
@@ -57,6 +57,8 @@ def __init__( |
57 | 57 | self._stdout_files, self._profile_files = [], [] |
58 | 58 | self._csv_files, self._diagnostic_files = [], [] |
59 | 59 | self._config_files = [] |
| 60 | + self._single_path_csv_files: list[str] = [] |
| 61 | + self._single_path_json_files: list[str] = [] |
60 | 62 |
|
61 | 63 | # per-process output files |
62 | 64 | if one_process_per_chain and chains > 1: |
@@ -101,6 +103,9 @@ def __init__( |
101 | 103 | for id in self._chain_ids |
102 | 104 | ] |
103 | 105 |
|
| 106 | + if args.method == Method.PATHFINDER: |
| 107 | + self.populate_pathfinder_single_path_files() |
| 108 | + |
104 | 109 | def __repr__(self) -> str: |
105 | 110 | lines = [ |
106 | 111 | f"RunSet: chains={self._chains}, chain_ids={self._chain_ids}, " |
@@ -222,6 +227,18 @@ def profile_files(self) -> list[str]: |
222 | 227 | """List of paths to CmdStan profiler files.""" |
223 | 228 | return self._profile_files |
224 | 229 |
|
| 230 | + @property |
| 231 | + def single_path_csv_files(self) -> list[str]: |
| 232 | + """List of paths to single-path Pathfinder output CSV files. |
| 233 | + Only populated when method is Pathfinder and save_single_paths=True""" |
| 234 | + return self._single_path_csv_files |
| 235 | + |
| 236 | + @property |
| 237 | + def single_path_json_files(self) -> list[str]: |
| 238 | + """List of paths to single-path Pathfinder output ELBO JSON files. |
| 239 | + Only populated when method is Pathfinder and save_single_paths=True""" |
| 240 | + return self._single_path_json_files |
| 241 | + |
225 | 242 | def gen_file_name( |
226 | 243 | self, suffix: str, *, extra: str = "", id: int | None = None |
227 | 244 | ) -> str: |
@@ -317,3 +334,25 @@ def raise_for_timeouts(self) -> None: |
317 | 334 | f"{sum(self._timeout_flags)} of {self.num_procs} " |
318 | 335 | "processes timed out" |
319 | 336 | ) |
| 337 | + |
| 338 | + def populate_pathfinder_single_path_files(self) -> None: |
| 339 | + if not isinstance(self._args.method_args, PathfinderArgs): |
| 340 | + return |
| 341 | + if self._args.method_args.save_single_paths: |
| 342 | + num_paths = self._args.method_args.num_paths |
| 343 | + if num_paths > 1: |
| 344 | + self._single_path_csv_files = [ |
| 345 | + self.gen_file_name(".csv", extra="path", id=id) |
| 346 | + for id in range(1, num_paths + 1) |
| 347 | + ] |
| 348 | + self._single_path_json_files = [ |
| 349 | + self.gen_file_name(".json", extra="path", id=id) |
| 350 | + for id in range(1, num_paths + 1) |
| 351 | + ] |
| 352 | + else: # num_paths == 1 |
| 353 | + self._single_path_csv_files = [ |
| 354 | + self.gen_file_name(".csv", extra="path") |
| 355 | + ] |
| 356 | + self._single_path_json_files = [ |
| 357 | + self.gen_file_name(".csv", extra="json") |
| 358 | + ] |
0 commit comments