|
8 | 8 | import pickle |
9 | 9 | import shutil |
10 | 10 | from test import check_present, without_import |
11 | | -from unittest.mock import MagicMock, patch |
12 | 11 |
|
13 | 12 | import numpy as np |
14 | 13 | import pandas as pd |
@@ -578,7 +577,7 @@ def test_from_optimization() -> None: |
578 | 577 |
|
579 | 578 | # stan_variable |
580 | 579 | theta = bern_gqs.stan_variable(var='theta') |
581 | | - assert theta.shape == (1,) |
| 580 | + assert theta.shape == () |
582 | 581 | y_rep = bern_gqs.stan_variable(var='y_rep') |
583 | 582 | assert y_rep.shape == (1, 10) |
584 | 583 |
|
@@ -772,34 +771,3 @@ def test_vb_xarray() -> None: |
772 | 771 | bern_gqs = model.generate_quantities(data=jdata, previous_fit=bern_fit) |
773 | 772 | with pytest.raises(RuntimeError, match="via Sampling"): |
774 | 773 | _ = bern_gqs.draws_xr() |
775 | | - |
776 | | - |
777 | | -@patch( |
778 | | - 'cmdstanpy.utils.cmdstan.cmdstan_version', |
779 | | - MagicMock(return_value=(2, 27)), |
780 | | -) |
781 | | -def test_from_non_hmc_old() -> None: |
782 | | - stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan') |
783 | | - bern_model = CmdStanModel(stan_file=stan) |
784 | | - jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json') |
785 | | - bern_fit_v = bern_model.variational( |
786 | | - data=jdata, |
787 | | - show_console=True, |
788 | | - require_converged=False, |
789 | | - seed=12345, |
790 | | - ) |
791 | | - |
792 | | - # gq_model |
793 | | - stan = os.path.join(DATAFILES_PATH, 'bernoulli_ppc.stan') |
794 | | - model = CmdStanModel(stan_file=stan) |
795 | | - |
796 | | - with pytest.raises(RuntimeError, match="2.31"): |
797 | | - model.generate_quantities(data=jdata, previous_fit=bern_fit_v) |
798 | | - |
799 | | - bern_fit_opt = bern_model.optimize( |
800 | | - data=jdata, |
801 | | - seed=12345, |
802 | | - ) |
803 | | - |
804 | | - with pytest.raises(RuntimeError, match="2.31"): |
805 | | - model.generate_quantities(data=jdata, previous_fit=bern_fit_opt) |
0 commit comments