Skip to content

Commit 747b0c1

Browse files
committed
Add single chain output file testing
1 parent 3d397e7 commit 747b0c1

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed

test/test_runset.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,57 @@ def test_output_filenames_threading() -> None:
176176
assert runset_other_files.profile_files[0].endswith("_profile.csv")
177177

178178

179+
def test_output_filenames_single_chain() -> None:
180+
exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
181+
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
182+
sampler_args = SamplerArgs()
183+
chain_ids = [1]
184+
cmdstan_args = CmdStanArgs(
185+
model_name='bernoulli',
186+
model_exe=exe,
187+
chain_ids=chain_ids,
188+
data=jdata,
189+
method_args=sampler_args,
190+
)
191+
runset = RunSet(args=cmdstan_args, chains=1, one_process_per_chain=False)
192+
base_file = runset._base_outfile
193+
assert len(runset.csv_files) == 1
194+
assert len(runset.stdout_files) == 1
195+
assert runset.csv_files[0].endswith(f"{base_file}.csv")
196+
assert runset.stdout_files[0].endswith(f"{base_file}_stdout.txt")
197+
198+
runset = RunSet(args=cmdstan_args, chains=1, one_process_per_chain=True)
199+
base_file = runset._base_outfile
200+
assert runset.stdout_files[0].endswith(f"{base_file}_stdout.txt")
201+
202+
cmdstan_args_other_files = CmdStanArgs(
203+
model_name='bernoulli',
204+
model_exe=exe,
205+
chain_ids=chain_ids,
206+
data=jdata,
207+
method_args=sampler_args,
208+
save_latent_dynamics=True,
209+
save_profile=True,
210+
)
211+
runset_other_files = RunSet(
212+
args=cmdstan_args_other_files, chains=1, one_process_per_chain=False
213+
)
214+
assert len(runset_other_files.diagnostic_files) == 1
215+
assert runset_other_files.diagnostic_files[0].endswith("_diagnostic.csv")
216+
217+
assert len(runset_other_files.profile_files) == 1
218+
assert runset_other_files.profile_files[0].endswith("_profile.csv")
219+
220+
runset_other_files = RunSet(
221+
args=cmdstan_args_other_files, chains=1, one_process_per_chain=True
222+
)
223+
assert len(runset_other_files.diagnostic_files) == 1
224+
assert runset_other_files.diagnostic_files[0].endswith("_diagnostic.csv")
225+
226+
assert len(runset_other_files.profile_files) == 1
227+
assert runset_other_files.profile_files[0].endswith("_profile.csv")
228+
229+
179230
def test_commands() -> None:
180231
exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
181232
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')

0 commit comments

Comments
 (0)