@@ -991,7 +991,8 @@ def test_from_csv_no_param_hmc() -> None:
991991 assert no_parameters_sample .draws_pd ().shape == (100 , 93 )
992992
993993
994- def test_custom_metric () -> None :
994+ @pytest .mark .parametrize ('force_one_process_per_chain' , [True , False ])
995+ def test_custom_metric (force_one_process_per_chain : bool ) -> None :
995996 stan = os .path .join (DATAFILES_PATH , 'bernoulli.stan' )
996997 jdata = os .path .join (DATAFILES_PATH , 'bernoulli.data.json' )
997998 bern_model = CmdStanModel (stan_file = stan )
@@ -1011,6 +1012,7 @@ def test_custom_metric() -> None:
10111012 iter_warmup = 10 ,
10121013 iter_sampling = 10 ,
10131014 inv_metric = jmetric ,
1015+ force_one_process_per_chain = force_one_process_per_chain ,
10141016 )
10151017 np .testing .assert_allclose (
10161018 fit1 .inv_metric [0 ], metric_dict_1 ['inv_metric' ], atol = 1e-6
@@ -1027,6 +1029,7 @@ def test_custom_metric() -> None:
10271029 iter_warmup = 10 ,
10281030 iter_sampling = 10 ,
10291031 inv_metric = [jmetric , jmetric2 ],
1032+ force_one_process_per_chain = force_one_process_per_chain ,
10301033 )
10311034 np .testing .assert_allclose (
10321035 fit2 .inv_metric [0 ], metric_dict_1 ['inv_metric' ], atol = 1e-6
@@ -1043,6 +1046,7 @@ def test_custom_metric() -> None:
10431046 iter_warmup = 10 ,
10441047 iter_sampling = 10 ,
10451048 inv_metric = metric_dict_1 ,
1049+ force_one_process_per_chain = force_one_process_per_chain ,
10461050 )
10471051 for i in range (4 ):
10481052 np .testing .assert_allclose (
@@ -1055,6 +1059,7 @@ def test_custom_metric() -> None:
10551059 iter_warmup = 10 ,
10561060 iter_sampling = 10 ,
10571061 inv_metric = [metric_dict_1 , metric_dict_2 ],
1062+ force_one_process_per_chain = force_one_process_per_chain ,
10581063 )
10591064 np .testing .assert_allclose (
10601065 fit4 .inv_metric [0 ], metric_dict_1 ['inv_metric' ], atol = 1e-6
@@ -1070,6 +1075,7 @@ def test_custom_metric() -> None:
10701075 iter_warmup = 10 ,
10711076 iter_sampling = 10 ,
10721077 inv_metric = [np .array (metric_dict_1 ['inv_metric' ]), jmetric2 ],
1078+ force_one_process_per_chain = force_one_process_per_chain ,
10731079 )
10741080 np .testing .assert_allclose (
10751081 fit5 .inv_metric [0 ], metric_dict_1 ['inv_metric' ], atol = 1e-6
@@ -1090,6 +1096,7 @@ def test_custom_metric() -> None:
10901096 iter_warmup = 10 ,
10911097 iter_sampling = 10 ,
10921098 inv_metric = [metric_dict_1 , metric_dict_2 ],
1099+ force_one_process_per_chain = force_one_process_per_chain ,
10931100 )
10941101 # metric mismatches - (not appropriate for bernoulli)
10951102 with open (os .path .join (DATAFILES_PATH , 'metric_diag.data.json' )) as fd :
@@ -1104,6 +1111,7 @@ def test_custom_metric() -> None:
11041111 iter_warmup = 10 ,
11051112 iter_sampling = 10 ,
11061113 inv_metric = [metric_dict_1 , metric_dict_2 ],
1114+ force_one_process_per_chain = force_one_process_per_chain ,
11071115 )
11081116 # metric dict, no "inv_metric":
11091117 some_dict = {"foo" : [1 , 2 , 3 ]}
@@ -1117,6 +1125,7 @@ def test_custom_metric() -> None:
11171125 iter_warmup = 100 ,
11181126 iter_sampling = 200 ,
11191127 inv_metric = some_dict ,
1128+ force_one_process_per_chain = force_one_process_per_chain ,
11201129 )
11211130
11221131
0 commit comments