Skip to content

Commit ef7ca63

Browse files
committed
Add additional RunSet output file testing
1 parent a62e7d2 commit ef7ca63

File tree

1 file changed

+81
-5
lines changed

1 file changed

+81
-5
lines changed

test/test_runset.py

Lines changed: 81 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def test_get_err_msgs() -> None:
8282
assert 'Exception: variable does not exist' in errs
8383

8484

85-
def test_output_filenames() -> None:
85+
def test_output_filenames_one_proc_per_chain() -> None:
8686
exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
8787
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
8888
sampler_args = SamplerArgs()
@@ -94,10 +94,86 @@ def test_output_filenames() -> None:
9494
data=jdata,
9595
method_args=sampler_args,
9696
)
97-
runset = RunSet(args=cmdstan_args, chains=4)
98-
assert 'bernoulli-' in runset._csv_files[0]
99-
assert '_1.csv' in runset._csv_files[0]
100-
assert '_4.csv' in runset._csv_files[3]
97+
runset = RunSet(args=cmdstan_args, chains=4, one_process_per_chain=True)
98+
99+
assert all("bernoulli-" in csv_file for csv_file in runset.csv_files)
100+
assert all(
101+
csv_file.endswith(f"_{id}.csv")
102+
for id, csv_file in zip(chain_ids, runset.csv_files)
103+
)
104+
assert len(runset.stdout_files) == len(chain_ids)
105+
assert all(
106+
stdout_file.endswith(f"_stdout_{id}.txt")
107+
for id, stdout_file in zip(chain_ids, runset.stdout_files)
108+
)
109+
110+
cmdstan_args_other_files = CmdStanArgs(
111+
model_name='bernoulli',
112+
model_exe=exe,
113+
chain_ids=chain_ids,
114+
data=jdata,
115+
method_args=sampler_args,
116+
save_latent_dynamics=True,
117+
save_profile=True,
118+
)
119+
runset_other_files = RunSet(
120+
args=cmdstan_args_other_files, chains=4, one_process_per_chain=True
121+
)
122+
assert len(runset_other_files.diagnostic_files) == len(chain_ids)
123+
assert all(
124+
diag_file.endswith(f"_diagnostic_{id}.csv")
125+
for id, diag_file in zip(chain_ids, runset_other_files.diagnostic_files)
126+
)
127+
128+
assert len(runset_other_files.profile_files) == len(chain_ids)
129+
assert all(
130+
prof_file.endswith(f"_profile_{id}.csv")
131+
for id, prof_file in zip(chain_ids, runset_other_files.profile_files)
132+
)
133+
134+
135+
def test_output_filenames_threading() -> None:
136+
exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
137+
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
138+
sampler_args = SamplerArgs()
139+
chain_ids = [1, 2, 3, 4]
140+
cmdstan_args = CmdStanArgs(
141+
model_name='bernoulli',
142+
model_exe=exe,
143+
chain_ids=chain_ids,
144+
data=jdata,
145+
method_args=sampler_args,
146+
)
147+
runset = RunSet(args=cmdstan_args, chains=4, one_process_per_chain=False)
148+
149+
assert all("bernoulli-" in csv_file for csv_file in runset.csv_files)
150+
assert all(
151+
csv_file.endswith(f"_{id}.csv")
152+
for id, csv_file in zip(chain_ids, runset.csv_files)
153+
)
154+
assert len(runset.stdout_files) == 1
155+
assert runset.stdout_files[0].endswith("_stdout.txt")
156+
157+
cmdstan_args_other_files = CmdStanArgs(
158+
model_name='bernoulli',
159+
model_exe=exe,
160+
chain_ids=chain_ids,
161+
data=jdata,
162+
method_args=sampler_args,
163+
save_latent_dynamics=True,
164+
save_profile=True,
165+
)
166+
runset_other_files = RunSet(
167+
args=cmdstan_args_other_files, chains=4, one_process_per_chain=False
168+
)
169+
assert len(runset_other_files.diagnostic_files) == len(chain_ids)
170+
assert all(
171+
diag_file.endswith(f"_diagnostic_{id}.csv")
172+
for id, diag_file in zip(chain_ids, runset_other_files.diagnostic_files)
173+
)
174+
175+
assert len(runset_other_files.profile_files) == 1
176+
assert runset_other_files.profile_files[0].endswith("_profile.csv")
101177

102178

103179
def test_commands() -> None:

0 commit comments

Comments
 (0)