Skip to content

Commit 92b803c

Browse files
committed
Add single path pathfinder outputs to runset
1 parent 689e0ac commit 92b803c

File tree

1 file changed

+40
-1
lines changed

1 file changed

+40
-1
lines changed

cmdstanpy/stanfit/runset.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from time import time
1212

1313
from cmdstanpy import _TMPDIR
14-
from cmdstanpy.cmdstan_args import CmdStanArgs, Method
14+
from cmdstanpy.cmdstan_args import CmdStanArgs, Method, PathfinderArgs
1515
from cmdstanpy.utils import get_logger
1616

1717

@@ -57,6 +57,8 @@ def __init__(
5757
self._stdout_files, self._profile_files = [], []
5858
self._csv_files, self._diagnostic_files = [], []
5959
self._config_files = []
60+
self._single_path_csv_files: list[str] = []
61+
self._single_path_json_files: list[str] = []
6062

6163
# per-process output files
6264
if one_process_per_chain and chains > 1:
@@ -101,6 +103,9 @@ def __init__(
101103
for id in self._chain_ids
102104
]
103105

106+
if args.method == Method.PATHFINDER:
107+
self.populate_pathfinder_single_path_files()
108+
104109
def __repr__(self) -> str:
105110
lines = [
106111
f"RunSet: chains={self._chains}, chain_ids={self._chain_ids}, "
@@ -222,6 +227,18 @@ def profile_files(self) -> list[str]:
222227
"""List of paths to CmdStan profiler files."""
223228
return self._profile_files
224229

230+
@property
231+
def single_path_csv_files(self) -> list[str]:
232+
"""List of paths to single-path Pathfinder output CSV files.
233+
Only populated when method is Pathfinder and save_single_paths=True"""
234+
return self._single_path_csv_files
235+
236+
@property
237+
def single_path_json_files(self) -> list[str]:
238+
"""List of paths to single-path Pathfinder output ELBO JSON files.
239+
Only populated when method is Pathfinder and save_single_paths=True"""
240+
return self._single_path_json_files
241+
225242
def gen_file_name(
226243
self, suffix: str, *, extra: str = "", id: int | None = None
227244
) -> str:
@@ -317,3 +334,25 @@ def raise_for_timeouts(self) -> None:
317334
f"{sum(self._timeout_flags)} of {self.num_procs} "
318335
"processes timed out"
319336
)
337+
338+
def populate_pathfinder_single_path_files(self) -> None:
339+
if not isinstance(self._args.method_args, PathfinderArgs):
340+
return
341+
if self._args.method_args.save_single_paths:
342+
num_paths = self._args.method_args.num_paths
343+
if num_paths > 1:
344+
self._single_path_csv_files = [
345+
self.gen_file_name(".csv", extra="path", id=id)
346+
for id in range(1, num_paths + 1)
347+
]
348+
self._single_path_json_files = [
349+
self.gen_file_name(".json", extra="path", id=id)
350+
for id in range(1, num_paths + 1)
351+
]
352+
else: # num_paths == 1
353+
self._single_path_csv_files = [
354+
self.gen_file_name(".csv", extra="path")
355+
]
356+
self._single_path_json_files = [
357+
self.gen_file_name(".csv", extra="json")
358+
]

0 commit comments

Comments
 (0)