Skip to content

Commit e3114bb

Browse files
committed
Test fixes
1 parent fd721ff commit e3114bb

File tree

2 files changed

+5
-34
lines changed

2 files changed

+5
-34
lines changed

cmdstanpy/stanfit/gq.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -664,11 +664,14 @@ def _draws_start(self, inc_warmup: bool) -> tuple[int, int]:
664664
num_draws = opt_iters
665665
else:
666666
draw1 = opt_iters - 1
667-
else: # CmdStanVB:
667+
elif isinstance(p_fit, CmdStanVB):
668668
draw1 = 1 # skip mean
669669
num_draws = p_fit.variational_sample.shape[0]
670670
if inc_warmup:
671671
num_draws += 1
672+
else:
673+
num_draws = p_fit.draws().shape[0]
674+
draw1 = 0
672675

673676
return draw1, num_draws
674677

test/test_generate_quantities.py

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import pickle
99
import shutil
1010
from test import check_present, without_import
11-
from unittest.mock import MagicMock, patch
1211

1312
import numpy as np
1413
import pandas as pd
@@ -578,7 +577,7 @@ def test_from_optimization() -> None:
578577

579578
# stan_variable
580579
theta = bern_gqs.stan_variable(var='theta')
581-
assert theta.shape == (1,)
580+
assert theta.shape == ()
582581
y_rep = bern_gqs.stan_variable(var='y_rep')
583582
assert y_rep.shape == (1, 10)
584583

@@ -772,34 +771,3 @@ def test_vb_xarray() -> None:
772771
bern_gqs = model.generate_quantities(data=jdata, previous_fit=bern_fit)
773772
with pytest.raises(RuntimeError, match="via Sampling"):
774773
_ = bern_gqs.draws_xr()
775-
776-
777-
@patch(
778-
'cmdstanpy.utils.cmdstan.cmdstan_version',
779-
MagicMock(return_value=(2, 27)),
780-
)
781-
def test_from_non_hmc_old() -> None:
782-
stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
783-
bern_model = CmdStanModel(stan_file=stan)
784-
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
785-
bern_fit_v = bern_model.variational(
786-
data=jdata,
787-
show_console=True,
788-
require_converged=False,
789-
seed=12345,
790-
)
791-
792-
# gq_model
793-
stan = os.path.join(DATAFILES_PATH, 'bernoulli_ppc.stan')
794-
model = CmdStanModel(stan_file=stan)
795-
796-
with pytest.raises(RuntimeError, match="2.31"):
797-
model.generate_quantities(data=jdata, previous_fit=bern_fit_v)
798-
799-
bern_fit_opt = bern_model.optimize(
800-
data=jdata,
801-
seed=12345,
802-
)
803-
804-
with pytest.raises(RuntimeError, match="2.31"):
805-
model.generate_quantities(data=jdata, previous_fit=bern_fit_opt)

0 commit comments

Comments
 (0)