Skip to content

Commit 5ac4dc3

Browse files
committed
Add config output tests
1 parent 10361f7 commit 5ac4dc3

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

test/test_cmdstan_args.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -808,3 +808,15 @@ def test_args_pathfinder_bad(arg: str, require_int: bool) -> None:
808808
args = PathfinderArgs(**{arg: 1.1}) # type: ignore
809809
with pytest.raises(ValueError):
810810
args.validate()
811+
812+
813+
def test_save_cmdstan_config() -> None:
814+
sampler_args = SamplerArgs()
815+
cmdstan_args = CmdStanArgs(
816+
model_name='bernoulli',
817+
model_exe='',
818+
chain_ids=[1, 2, 3, 4],
819+
method_args=sampler_args,
820+
)
821+
command = cmdstan_args.compose_command(0, csv_file="foo")
822+
assert "save_cmdstan_config=1" in command

test/test_sample.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2204,3 +2204,27 @@ def test_no_output_draws() -> None:
22042204
mcmc = model.sample(data=data, iter_sampling=0, save_warmup=False, chains=2)
22052205
draws = mcmc.draws()
22062206
assert np.array_equal(draws, np.empty((0, 2, len(mcmc.column_names))))
2207+
2208+
2209+
def test_config_output() -> None:
2210+
stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
2211+
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
2212+
model = CmdStanModel(stan_file=stan)
2213+
fit = model.sample(
2214+
data=jdata,
2215+
chains=2,
2216+
seed=12345,
2217+
iter_warmup=100,
2218+
iter_sampling=200,
2219+
)
2220+
assert all(os.path.exists(cf) for cf in fit.runset.config_files)
2221+
2222+
# Config file naming differs when only a single chain is output
2223+
fit_one_chain = model.sample(
2224+
data=jdata,
2225+
chains=1,
2226+
seed=12345,
2227+
iter_warmup=100,
2228+
iter_sampling=200,
2229+
)
2230+
assert all(os.path.exists(cf) for cf in fit_one_chain.runset.config_files)

0 commit comments

Comments
 (0)