@@ -36,7 +36,7 @@ def test_pathfinder_outputs():
3636
3737 assert pathfinder .is_resampled
3838
39- assert pathfinder .draws ().shape == (draws , 3 )
39+ assert pathfinder .draws ().shape == (draws , 4 )
4040
4141
4242def test_pathfinder_from_csv ():
@@ -159,7 +159,7 @@ def test_pathfinder_no_psis():
159159 pathfinder = bern_model .pathfinder (data = jdata , psis_resample = False )
160160
161161 assert not pathfinder .is_resampled
162- assert pathfinder .draws ().shape == (4000 , 3 )
162+ assert pathfinder .draws ().shape == (4000 , 4 )
163163
164164
165165def test_pathfinder_no_lp_calc ():
@@ -170,7 +170,7 @@ def test_pathfinder_no_lp_calc():
170170 pathfinder = bern_model .pathfinder (data = jdata , calculate_lp = False )
171171
172172 assert not pathfinder .is_resampled
173- assert pathfinder .draws ().shape == (4000 , 3 )
173+ assert pathfinder .draws ().shape == (4000 , 4 )
174174 n_lp_nan = np .sum (np .isnan (pathfinder .method_variables ()['lp__' ]))
175175 assert n_lp_nan < 4000 # some lp still calculated during pathfinder
176176 assert n_lp_nan > 3000 # but most are not
@@ -190,4 +190,4 @@ def test_pathfinder_threads():
190190 stan_file = stan , cpp_options = {'STAN_THREADS' : True }, force_compile = True
191191 )
192192 pathfinder = bern_model .pathfinder (data = jdata , num_threads = 4 )
193- assert pathfinder .draws ().shape == (1000 , 3 )
193+ assert pathfinder .draws ().shape == (1000 , 4 )
0 commit comments