@@ -38,62 +38,59 @@ def __init__(
3838 self ._args = args
3939 self ._chains = chains
4040 self ._one_process_per_chain = one_process_per_chain
41- if one_process_per_chain :
42- self ._num_procs = chains
43- else :
44- self ._num_procs = 1
41+ self ._num_procs = chains if one_process_per_chain else 1
4542 self ._retcodes = [- 1 for _ in range (self ._num_procs )]
4643 self ._timeout_flags = [False for _ in range (self ._num_procs )]
4744 if chain_ids is None :
4845 chain_ids = [i + 1 for i in range (chains )]
4946 self ._chain_ids = chain_ids
5047
5148 if args .output_dir is not None :
52- self ._output_dir = args .output_dir
53- else :
54- # make a per-run subdirectory of our master temp directory
55- self ._output_dir = tempfile .mkdtemp (
56- prefix = args .model_name , dir = _TMPDIR
57- )
49+ self ._outdir = args .output_dir
50+ else : # make a per-run subdirectory of our master temp directory
51+ self ._outdir = tempfile .mkdtemp (prefix = args .model_name , dir = _TMPDIR )
5852
5953 # output files prefix: ``<model_name>-<YYYYMMDDHHMM>_<chain_id>``
6054 self ._base_outfile = (
6155 f'{ args .model_name } -{ datetime .now ().strftime (time_fmt )} '
6256 )
63- # per-process outputs
64- self ._stdout_files = ['' ] * self ._num_procs
65- self ._profile_files = ['' ] * self ._num_procs # optional
57+ self ._stdout_files , self ._profile_files = [], []
58+ self ._csv_files , self ._diagnostic_files = [], []
59+
60+ # per-process output files
6661 if one_process_per_chain :
67- for i in range (chains ):
68- self ._stdout_files [i ] = self .file_path ("-stdout.txt" , id = i )
69- if args .save_profile :
70- self ._profile_files [i ] = self .file_path (
71- ".csv" , extra = "-profile" , id = chain_ids [i ]
72- )
62+ self ._stdout_files = [
63+ self .gen_file_name (".txt" , extra = "stdout" , id = id )
64+ for id in self ._chain_ids
65+ ]
66+ if args .save_profile :
67+ self ._profile_files = [
68+ self .gen_file_name (".csv" , extra = "profile" , id = id )
69+ for id in self ._chain_ids
70+ ]
7371 else :
74- self ._stdout_files [ 0 ] = self .file_path ( "-stdout .txt")
72+ self ._stdout_files = [ self .gen_file_name ( " .txt", extra = "stdout" )]
7573 if args .save_profile :
76- self ._profile_files [ 0 ] = self . file_path (
77- ".csv" , extra = "- profile"
78- )
74+ self ._profile_files = [
75+ self . gen_file_name ( ".csv" , extra = "profile" )
76+ ]
7977
8078 # per-chain output files
81- self ._csv_files : list [str ] = ['' ] * chains
82- self ._diagnostic_files = ['' ] * chains # optional
83-
8479 if chains == 1 :
85- self ._csv_files [ 0 ] = self .file_path (".csv" )
80+ self ._csv_files = [ self .gen_file_name (".csv" )]
8681 if args .save_latent_dynamics :
87- self ._diagnostic_files [ 0 ] = self . file_path (
88- ".csv" , extra = "- diagnostic"
89- )
82+ self ._diagnostic_files = [
83+ self . gen_file_name ( ".csv" , extra = "diagnostic" )
84+ ]
9085 else :
91- for i in range (chains ):
92- self ._csv_files [i ] = self .file_path (".csv" , id = chain_ids [i ])
93- if args .save_latent_dynamics :
94- self ._diagnostic_files [i ] = self .file_path (
95- ".csv" , extra = "-diagnostic" , id = chain_ids [i ]
96- )
86+ self ._csv_files = [
87+ self .gen_file_name (".csv" , id = id ) for id in self ._chain_ids
88+ ]
89+ if args .save_latent_dynamics :
90+ self ._diagnostic_files = [
91+ self .gen_file_name (".csv" , extra = "diagnostic" , id = id )
92+ for id in self ._chain_ids
93+ ]
9794
9895 def __repr__ (self ) -> str :
9996 repr = 'RunSet: chains={}, chain_ids={}, num_processes={}' .format (
@@ -173,14 +170,14 @@ def cmd(self, idx: int) -> list[str]:
173170 else :
174171 return self ._args .compose_command (
175172 idx ,
176- csv_file = self .file_path ('.csv' ),
173+ csv_file = self .gen_file_name ('.csv' ),
177174 diagnostic_file = (
178- self .file_path (".csv" , extra = "- diagnostic" )
175+ self .gen_file_name (".csv" , extra = "diagnostic" )
179176 if self ._args .save_latent_dynamics
180177 else None
181178 ),
182179 profile_file = (
183- self .file_path (".csv" , extra = "- profile" )
180+ self .gen_file_name (".csv" , extra = "profile" )
184181 if self ._args .save_profile
185182 else None
186183 ),
@@ -216,16 +213,22 @@ def profile_files(self) -> list[str]:
216213 """List of paths to CmdStan profiler files."""
217214 return self ._profile_files
218215
219- # pylint: disable=invalid-name
220- def file_path (
216+ def gen_file_name (
221217 self , suffix : str , * , extra : str = "" , id : int | None = None
222218 ) -> str :
223- if id is not None :
224- suffix = f"_{ id } { suffix } "
225- file = os .path .join (
226- self ._output_dir , f"{ self ._base_outfile } { extra } { suffix } "
227- )
228- return file
219+ """Generate a standard file name according to CmdStan output pattern"""
220+ match (id , extra ):
221+ case (None , "" ):
222+ file = f"{ self ._base_outfile } { suffix } "
223+ case (None , extra ) if extra != "" :
224+ file = f"{ self ._base_outfile } _{ extra } { suffix } "
225+ case (id , "" ):
226+ file = f"{ self ._base_outfile } _{ id } { suffix } "
227+ case (id , extra ) if extra != "" :
228+ file = f"{ self ._base_outfile } _{ id } _{ extra } { suffix } "
229+ case _:
230+ raise ValueError ("Cannot construct valid file name" )
231+ return os .path .join (self ._outdir , file )
229232
230233 def _retcode (self , idx : int ) -> int :
231234 """Get retcode for process[idx]."""
0 commit comments