Skip to content

Commit fad7a69

Browse files
authored
Merge pull request #749 from stan-dev/fix/2.35-fixes
Fix tests for new cmdstan
2 parents cbea79f + 1af8596 commit fad7a69

File tree

3 files changed

+10
-7
lines changed

3 files changed

+10
-7
lines changed

cmdstanpy/stanfit/pathfinder.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,10 @@ def is_resampled(self) -> bool:
214214
"""
215215
return ( # type: ignore
216216
self._metadata.cmdstan_config.get("num_paths", 4) > 1
217-
and self._metadata.cmdstan_config.get('psis_resample', 1) == 1
218-
and self._metadata.cmdstan_config.get('calculate_lp', 1) == 1
217+
and self._metadata.cmdstan_config.get('psis_resample', 1)
218+
in (1, 'true')
219+
and self._metadata.cmdstan_config.get('calculate_lp', 1)
220+
in (1, 'true')
219221
)
220222

221223
def save_csvfiles(self, dir: Optional[str] = None) -> None:

cmdstanpy/utils/stancsv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def check_sampler_csv(
5252
)
5353
)
5454
if save_warmup:
55-
if not ('save_warmup' in meta and meta['save_warmup'] == 1):
55+
if not ('save_warmup' in meta and meta['save_warmup'] in (1, 'true')):
5656
raise ValueError(
5757
'bad Stan CSV file {}, '
5858
'config error, expected save_warmup = 1'.format(path)

test/test_optimize.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -262,12 +262,13 @@ def test_variables_3d() -> None:
262262
)
263263
vars_iters = multidim_mle_iters.stan_variables(inc_iterations=True)
264264
assert len(vars_iters) == len(multidim_mle_iters.metadata.stan_vars)
265+
assert 'frac_60' in vars_iters
266+
n_iter = vars_iters['frac_60'].shape[0]
267+
assert n_iter > 1
265268
assert 'y_rep' in vars_iters
266-
assert vars_iters['y_rep'].shape == (8, 5, 4, 3)
269+
assert vars_iters['y_rep'].shape == (n_iter, 5, 4, 3)
267270
assert 'beta' in vars_iters
268-
assert vars_iters['beta'].shape == (8, 2)
269-
assert 'frac_60' in vars_iters
270-
assert vars_iters['frac_60'].shape == (8,)
271+
assert vars_iters['beta'].shape == (n_iter, 2)
271272

272273

273274
def test_optimize_good() -> None:

0 commit comments

Comments
 (0)