Skip to content

Commit 02e627d

Browse files
committed
Standardize output file naming
1 parent 58ce09e commit 02e627d

File tree

5 files changed

+54
-51
lines changed

5 files changed

+54
-51
lines changed

cmdstanpy/stanfit/runset.py

Lines changed: 50 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -38,62 +38,59 @@ def __init__(
3838
self._args = args
3939
self._chains = chains
4040
self._one_process_per_chain = one_process_per_chain
41-
if one_process_per_chain:
42-
self._num_procs = chains
43-
else:
44-
self._num_procs = 1
41+
self._num_procs = chains if one_process_per_chain else 1
4542
self._retcodes = [-1 for _ in range(self._num_procs)]
4643
self._timeout_flags = [False for _ in range(self._num_procs)]
4744
if chain_ids is None:
4845
chain_ids = [i + 1 for i in range(chains)]
4946
self._chain_ids = chain_ids
5047

5148
if args.output_dir is not None:
52-
self._output_dir = args.output_dir
53-
else:
54-
# make a per-run subdirectory of our master temp directory
55-
self._output_dir = tempfile.mkdtemp(
56-
prefix=args.model_name, dir=_TMPDIR
57-
)
49+
self._outdir = args.output_dir
50+
else: # make a per-run subdirectory of our master temp directory
51+
self._outdir = tempfile.mkdtemp(prefix=args.model_name, dir=_TMPDIR)
5852

5953
# output files prefix: ``<model_name>-<YYYYMMDDHHMM>_<chain_id>``
6054
self._base_outfile = (
6155
f'{args.model_name}-{datetime.now().strftime(time_fmt)}'
6256
)
63-
# per-process outputs
64-
self._stdout_files = [''] * self._num_procs
65-
self._profile_files = [''] * self._num_procs # optional
57+
self._stdout_files, self._profile_files = [], []
58+
self._csv_files, self._diagnostic_files = [], []
59+
60+
# per-process output files
6661
if one_process_per_chain:
67-
for i in range(chains):
68-
self._stdout_files[i] = self.file_path("-stdout.txt", id=i)
69-
if args.save_profile:
70-
self._profile_files[i] = self.file_path(
71-
".csv", extra="-profile", id=chain_ids[i]
72-
)
62+
self._stdout_files = [
63+
self.gen_file_name(".txt", extra="stdout", id=id)
64+
for id in self._chain_ids
65+
]
66+
if args.save_profile:
67+
self._profile_files = [
68+
self.gen_file_name(".csv", extra="profile", id=id)
69+
for id in self._chain_ids
70+
]
7371
else:
74-
self._stdout_files[0] = self.file_path("-stdout.txt")
72+
self._stdout_files = [self.gen_file_name(".txt", extra="stdout")]
7573
if args.save_profile:
76-
self._profile_files[0] = self.file_path(
77-
".csv", extra="-profile"
78-
)
74+
self._profile_files = [
75+
self.gen_file_name(".csv", extra="profile")
76+
]
7977

8078
# per-chain output files
81-
self._csv_files: list[str] = [''] * chains
82-
self._diagnostic_files = [''] * chains # optional
83-
8479
if chains == 1:
85-
self._csv_files[0] = self.file_path(".csv")
80+
self._csv_files = [self.gen_file_name(".csv")]
8681
if args.save_latent_dynamics:
87-
self._diagnostic_files[0] = self.file_path(
88-
".csv", extra="-diagnostic"
89-
)
82+
self._diagnostic_files = [
83+
self.gen_file_name(".csv", extra="diagnostic")
84+
]
9085
else:
91-
for i in range(chains):
92-
self._csv_files[i] = self.file_path(".csv", id=chain_ids[i])
93-
if args.save_latent_dynamics:
94-
self._diagnostic_files[i] = self.file_path(
95-
".csv", extra="-diagnostic", id=chain_ids[i]
96-
)
86+
self._csv_files = [
87+
self.gen_file_name(".csv", id=id) for id in self._chain_ids
88+
]
89+
if args.save_latent_dynamics:
90+
self._diagnostic_files = [
91+
self.gen_file_name(".csv", extra="diagnostic", id=id)
92+
for id in self._chain_ids
93+
]
9794

9895
def __repr__(self) -> str:
9996
repr = 'RunSet: chains={}, chain_ids={}, num_processes={}'.format(
@@ -173,14 +170,14 @@ def cmd(self, idx: int) -> list[str]:
173170
else:
174171
return self._args.compose_command(
175172
idx,
176-
csv_file=self.file_path('.csv'),
173+
csv_file=self.gen_file_name('.csv'),
177174
diagnostic_file=(
178-
self.file_path(".csv", extra="-diagnostic")
175+
self.gen_file_name(".csv", extra="diagnostic")
179176
if self._args.save_latent_dynamics
180177
else None
181178
),
182179
profile_file=(
183-
self.file_path(".csv", extra="-profile")
180+
self.gen_file_name(".csv", extra="profile")
184181
if self._args.save_profile
185182
else None
186183
),
@@ -216,16 +213,22 @@ def profile_files(self) -> list[str]:
216213
"""List of paths to CmdStan profiler files."""
217214
return self._profile_files
218215

219-
# pylint: disable=invalid-name
220-
def file_path(
216+
def gen_file_name(
221217
self, suffix: str, *, extra: str = "", id: int | None = None
222218
) -> str:
223-
if id is not None:
224-
suffix = f"_{id}{suffix}"
225-
file = os.path.join(
226-
self._output_dir, f"{self._base_outfile}{extra}{suffix}"
227-
)
228-
return file
219+
"""Generate a standard file name according to CmdStan output pattern"""
220+
match (id, extra):
221+
case (None, ""):
222+
file = f"{self._base_outfile}{suffix}"
223+
case (None, extra) if extra != "":
224+
file = f"{self._base_outfile}_{extra}{suffix}"
225+
case (id, ""):
226+
file = f"{self._base_outfile}_{id}{suffix}"
227+
case (id, extra) if extra != "":
228+
file = f"{self._base_outfile}_{id}_{extra}{suffix}"
229+
case _:
230+
raise ValueError("Cannot construct valid file name")
231+
return os.path.join(self._outdir, file)
229232

230233
def _retcode(self, idx: int) -> int:
231234
"""Get retcode for process[idx]."""

test/test_generate_quantities.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,7 @@ def test_serialization() -> None:
533533
fit1 = model.generate_quantities(data=jdata, previous_fit=fit_sampling)
534534

535535
dumped = pickle.dumps(fit1)
536-
shutil.rmtree(fit1.runset._output_dir)
536+
shutil.rmtree(fit1.runset._outdir)
537537
fit2: CmdStanGQ[CmdStanMCMC] = pickle.loads(dumped)
538538
variables1 = fit1.stan_variables()
539539
variables2 = fit2.stan_variables()

test/test_optimize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -664,7 +664,7 @@ def test_serialization() -> None:
664664
history_size=5,
665665
)
666666
dumped = pickle.dumps(mle1)
667-
shutil.rmtree(mle1.runset._output_dir)
667+
shutil.rmtree(mle1.runset._outdir)
668668
mle2: CmdStanMLE = pickle.loads(dumped)
669669
np.testing.assert_array_equal(
670670
mle1.optimized_params_np, mle2.optimized_params_np

test/test_sample.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2135,7 +2135,7 @@ def test_serialization(stanfile: str = 'bernoulli.stan') -> None:
21352135
)
21362136
# Dump the result (which assembles draws) and delete the source files.
21372137
dumped = pickle.dumps(bern_fit1)
2138-
shutil.rmtree(bern_fit1.runset._output_dir)
2138+
shutil.rmtree(bern_fit1.runset._outdir)
21392139
# Load the serialized result and compare results.
21402140
bern_fit2: CmdStanMCMC = pickle.loads(dumped)
21412141
variables1 = bern_fit1.stan_variables()

test/test_variational.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ def test_serialization() -> None:
335335
model = CmdStanModel(stan_file=stan)
336336
variational1 = model.variational(algorithm='meanfield', seed=999999)
337337
dumped = pickle.dumps(variational1)
338-
shutil.rmtree(variational1.runset._output_dir)
338+
shutil.rmtree(variational1.runset._outdir)
339339
variational2: CmdStanVB = pickle.loads(dumped)
340340
np.testing.assert_array_equal(
341341
variational1.variational_sample, variational2.variational_sample

0 commit comments

Comments
 (0)