Skip to content

Commit de18290

Browse files
committed
Test with both parallelism setups
1 parent ad1fde0 commit de18290

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

test/test_sample.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -991,7 +991,8 @@ def test_from_csv_no_param_hmc() -> None:
991991
assert no_parameters_sample.draws_pd().shape == (100, 93)
992992

993993

994-
def test_custom_metric() -> None:
994+
@pytest.mark.parametrize('force_one_process_per_chain', [True, False])
995+
def test_custom_metric(force_one_process_per_chain: bool) -> None:
995996
stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
996997
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
997998
bern_model = CmdStanModel(stan_file=stan)
@@ -1011,6 +1012,7 @@ def test_custom_metric() -> None:
10111012
iter_warmup=10,
10121013
iter_sampling=10,
10131014
inv_metric=jmetric,
1015+
force_one_process_per_chain=force_one_process_per_chain,
10141016
)
10151017
np.testing.assert_allclose(
10161018
fit1.inv_metric[0], metric_dict_1['inv_metric'], atol=1e-6
@@ -1027,6 +1029,7 @@ def test_custom_metric() -> None:
10271029
iter_warmup=10,
10281030
iter_sampling=10,
10291031
inv_metric=[jmetric, jmetric2],
1032+
force_one_process_per_chain=force_one_process_per_chain,
10301033
)
10311034
np.testing.assert_allclose(
10321035
fit2.inv_metric[0], metric_dict_1['inv_metric'], atol=1e-6
@@ -1043,6 +1046,7 @@ def test_custom_metric() -> None:
10431046
iter_warmup=10,
10441047
iter_sampling=10,
10451048
inv_metric=metric_dict_1,
1049+
force_one_process_per_chain=force_one_process_per_chain,
10461050
)
10471051
for i in range(4):
10481052
np.testing.assert_allclose(
@@ -1055,6 +1059,7 @@ def test_custom_metric() -> None:
10551059
iter_warmup=10,
10561060
iter_sampling=10,
10571061
inv_metric=[metric_dict_1, metric_dict_2],
1062+
force_one_process_per_chain=force_one_process_per_chain,
10581063
)
10591064
np.testing.assert_allclose(
10601065
fit4.inv_metric[0], metric_dict_1['inv_metric'], atol=1e-6
@@ -1070,6 +1075,7 @@ def test_custom_metric() -> None:
10701075
iter_warmup=10,
10711076
iter_sampling=10,
10721077
inv_metric=[np.array(metric_dict_1['inv_metric']), jmetric2],
1078+
force_one_process_per_chain=force_one_process_per_chain,
10731079
)
10741080
np.testing.assert_allclose(
10751081
fit5.inv_metric[0], metric_dict_1['inv_metric'], atol=1e-6
@@ -1090,6 +1096,7 @@ def test_custom_metric() -> None:
10901096
iter_warmup=10,
10911097
iter_sampling=10,
10921098
inv_metric=[metric_dict_1, metric_dict_2],
1099+
force_one_process_per_chain=force_one_process_per_chain,
10931100
)
10941101
# metric mismatches - (not appropriate for bernoulli)
10951102
with open(os.path.join(DATAFILES_PATH, 'metric_diag.data.json')) as fd:
@@ -1104,6 +1111,7 @@ def test_custom_metric() -> None:
11041111
iter_warmup=10,
11051112
iter_sampling=10,
11061113
inv_metric=[metric_dict_1, metric_dict_2],
1114+
force_one_process_per_chain=force_one_process_per_chain,
11071115
)
11081116
# metric dict, no "inv_metric":
11091117
some_dict = {"foo": [1, 2, 3]}
@@ -1117,6 +1125,7 @@ def test_custom_metric() -> None:
11171125
iter_warmup=100,
11181126
iter_sampling=200,
11191127
inv_metric=some_dict,
1128+
force_one_process_per_chain=force_one_process_per_chain,
11201129
)
11211130

11221131

0 commit comments

Comments
 (0)