Skip to content

Commit 5967ab9

Browse files
authored
Merge pull request #838 from amas0/enable-save-cmdstan-config
Enable save_cmdstan_config options for all CmdStan calls
2 parents b8bcfab + 5ac4dc3 commit 5967ab9

File tree

5 files changed: +75 additions, -17 deletions

cmdstanpy/cmdstan_args.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -866,6 +866,7 @@ def compose_command(
866866
cmd.append(f'init={self.inits[idx]}')
867867
cmd.append('output')
868868
cmd.append(f'file={csv_file}')
869+
cmd.append('save_cmdstan_config=1')
869870
if diagnostic_file:
870871
cmd.append(f'diagnostic_file={diagnostic_file}')
871872
if profile_file:

cmdstanpy/stanfit/runset.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -56,20 +56,29 @@ def __init__(
5656
)
5757
self._stdout_files, self._profile_files = [], []
5858
self._csv_files, self._diagnostic_files = [], []
59+
self._config_files = []
5960

6061
# per-process output files
6162
if one_process_per_chain and chains > 1:
6263
self._stdout_files = [
6364
self.gen_file_name(".txt", extra="stdout", id=id)
6465
for id in self._chain_ids
6566
]
67+
self._config_files = [
68+
os.path.join(
69+
self._outdir, f"{self._base_outfile}_{id}_config.json"
70+
)
71+
for id in self._chain_ids
72+
]
73+
6674
if args.save_profile:
6775
self._profile_files = [
6876
self.gen_file_name(".csv", extra="profile", id=id)
6977
for id in self._chain_ids
7078
]
7179
else:
7280
self._stdout_files = [self.gen_file_name(".txt", extra="stdout")]
81+
self._config_files = [self.gen_file_name(".json", extra="config")]
7382
if args.save_profile:
7483
self._profile_files = [
7584
self.gen_file_name(".csv", extra="profile")
@@ -93,25 +102,21 @@ def __init__(
93102
]
94103

95104
def __repr__(self) -> str:
    """
    Return a multi-line, human-readable summary of this run.

    The summary shows the chain configuration, the command composed for
    the first chain, the per-chain return codes, and the first entry of
    each per-chain output file list. The diagnostics and profile entries
    are included only when the corresponding option is set on the
    arguments object.
    """

    def first(files: list[str]) -> str:
        # repr() is used in logs and debuggers and should never raise, so
        # an empty file list is rendered as an empty path instead of
        # triggering an IndexError.
        return files[0] if files else ''

    lines = [
        f"RunSet: chains={self._chains}, chain_ids={self._chain_ids}, "
        f"num_processes={self._num_procs}",
        f" cmd (chain 1):\n\t{self.cmd(0)}",
        f" retcodes={self._retcodes}",
        " per-chain output files (showing chain 1 only):",
        f" csv_file:\n\t{first(self._csv_files)}",
    ]
    if self._args.save_latent_dynamics:
        lines.append(
            f" diagnostics_file:\n\t{first(self._diagnostic_files)}"
        )
    if self._args.save_profile:
        lines.append(f" profile_file:\n\t{first(self._profile_files)}")
    lines.append(f" console_msgs (if any):\n\t{first(self._stdout_files)}")
    lines.append(f" config_files:\n\t{first(self._config_files)}")
    return '\n'.join(lines)
115120

116121
@property
117122
def model(self) -> str:
@@ -196,6 +201,13 @@ def stdout_files(self) -> list[str]:
196201
"""
197202
return self._stdout_files
198203

204+
@property
def config_files(self) -> list[str]:
    """
    List of paths to CmdStan config json files.

    The list is the one assembled when the run set is constructed and is
    returned as-is (not copied).
    """
    return self._config_files
210+
199211
def _check_retcodes(self) -> bool:
200212
"""Returns ``True`` when all chains have retcode 0."""
201213
return all(retcode == 0 for retcode in self._retcodes)

test/test_cmdstan_args.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -808,3 +808,15 @@ def test_args_pathfinder_bad(arg: str, require_int: bool) -> None:
808808
args = PathfinderArgs(**{arg: 1.1}) # type: ignore
809809
with pytest.raises(ValueError):
810810
args.validate()
811+
812+
813+
def test_save_cmdstan_config() -> None:
    """Check that the composed command includes ``save_cmdstan_config=1``."""
    sampler_args = SamplerArgs()
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe='',
        chain_ids=[1, 2, 3, 4],
        method_args=sampler_args,
    )
    # Index 0 selects the first chain; the CSV file name is a placeholder
    # because only the composed arguments are inspected.
    command = cmdstan_args.compose_command(0, csv_file="foo")
    assert "save_cmdstan_config=1" in command

test/test_runset.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def test_check_repr() -> None:
3030
assert 'csv_file' in repr(runset)
3131
assert 'console_msgs' in repr(runset)
3232
assert 'diagnostics_file' not in repr(runset)
33+
assert 'config_file' in repr(runset)
3334

3435

3536
def test_check_retcodes() -> None:
@@ -106,6 +107,11 @@ def test_output_filenames_one_proc_per_chain() -> None:
106107
stdout_file.endswith(f"_stdout_{id}.txt")
107108
for id, stdout_file in zip(chain_ids, runset.stdout_files)
108109
)
110+
assert len(runset.config_files) == len(chain_ids)
111+
assert all(
112+
config_file.endswith(f"_{id}_config.json")
113+
for id, config_file in zip(chain_ids, runset.config_files)
114+
)
109115

110116
cmdstan_args_other_files = CmdStanArgs(
111117
model_name='bernoulli',
@@ -153,6 +159,8 @@ def test_output_filenames_threading() -> None:
153159
)
154160
assert len(runset.stdout_files) == 1
155161
assert runset.stdout_files[0].endswith("_stdout.txt")
162+
assert len(runset.config_files) == 1
163+
assert runset.config_files[0].endswith("_config.json")
156164

157165
cmdstan_args_other_files = CmdStanArgs(
158166
model_name='bernoulli',
@@ -198,6 +206,7 @@ def test_output_filenames_single_chain() -> None:
198206
runset = RunSet(args=cmdstan_args, chains=1, one_process_per_chain=True)
199207
base_file = runset._base_outfile
200208
assert runset.stdout_files[0].endswith(f"{base_file}_stdout.txt")
209+
assert runset.config_files[0].endswith(f"{base_file}_config.json")
201210

202211
cmdstan_args_other_files = CmdStanArgs(
203212
model_name='bernoulli',

test/test_sample.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2204,3 +2204,27 @@ def test_no_output_draws() -> None:
22042204
mcmc = model.sample(data=data, iter_sampling=0, save_warmup=False, chains=2)
22052205
draws = mcmc.draws()
22062206
assert np.array_equal(draws, np.empty((0, 2, len(mcmc.column_names))))
2207+
2208+
2209+
def test_config_output() -> None:
    """Check that sampling writes every config file the run set reports."""
    stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
    jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    model = CmdStanModel(stan_file=stan)
    fit = model.sample(
        data=jdata,
        chains=2,
        seed=12345,
        iter_warmup=100,
        iter_sampling=200,
    )
    config_files = fit.runset.config_files
    # all() is vacuously true for an empty list, so require at least one
    # reported path before checking that the files exist on disk.
    assert config_files
    assert all(os.path.exists(cf) for cf in config_files)

    # Config file naming differs when only a single chain is output
    fit_one_chain = model.sample(
        data=jdata,
        chains=1,
        seed=12345,
        iter_warmup=100,
        iter_sampling=200,
    )
    one_chain_config_files = fit_one_chain.runset.config_files
    assert one_chain_config_files
    assert all(os.path.exists(cf) for cf in one_chain_config_files)

0 commit comments

Comments
 (0)