Skip to content

Commit 6a44ec4

Browse files
committed
New gq tests
1 parent cf624de commit 6a44ec4

File tree

2 files changed

+85
-2
lines changed

2 files changed

+85
-2
lines changed

cmdstanpy/stanfit/gq.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -690,12 +690,14 @@ def _previous_draws(self, inc_warmup: bool) -> np.ndarray:
690690
return np.atleast_2d( # type: ignore
691691
p_fit.optimized_params_np,
692692
)[:, None]
693-
else: # CmdStanVB:
693+
elif isinstance(p_fit, CmdStanVB):
694694
if inc_warmup:
695695
return np.vstack(
696696
[p_fit.variational_params_np, p_fit.variational_sample]
697697
)[:, None]
698698
return p_fit.variational_sample[:, None]
699+
else: # CmdStanLaplace, CmdStanPathfinder
700+
return p_fit.draws()[:, None, :]
699701

700702
def _previous_draws_pd(
701703
self, vars: list[str], inc_warmup: bool
@@ -714,8 +716,12 @@ def _previous_draws_pd(
714716
return p_fit.optimized_iterations_pd[sel] # type: ignore
715717
else:
716718
return p_fit.optimized_params_pd[sel]
717-
else: # CmdStanVB:
719+
elif isinstance(p_fit, CmdStanVB):
718720
return p_fit.variational_sample_pd[sel]
721+
elif isinstance(p_fit, CmdStanLaplace):
722+
return p_fit.draws_pd(vars or None)
723+
else: # CmdStanPathfinder
724+
return pd.DataFrame(p_fit.draws(), columns=p_fit.column_names)[sel]
719725

720726
def save_csvfiles(self, dir: str | None = None) -> None:
721727
"""

test/test_generate_quantities.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -771,3 +771,80 @@ def test_vb_xarray() -> None:
771771
bern_gqs = model.generate_quantities(data=jdata, previous_fit=bern_fit)
772772
with pytest.raises(RuntimeError, match="via Sampling"):
773773
_ = bern_gqs.draws_xr()
774+
775+
776+
def test_from_pathfinder() -> None:
777+
stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
778+
bern_model = CmdStanModel(stan_file=stan)
779+
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
780+
bern_fit = bern_model.pathfinder(
781+
data=jdata,
782+
seed=12345,
783+
)
784+
785+
# gq_model
786+
stan = os.path.join(DATAFILES_PATH, 'bernoulli_ppc.stan')
787+
model = CmdStanModel(stan_file=stan)
788+
789+
bern_gqs = model.generate_quantities(data=jdata, previous_fit=bern_fit)
790+
791+
assert bern_gqs.runset._args.method == Method.GENERATE_QUANTITIES
792+
assert 'CmdStanGQ: model=bernoulli_ppc' in repr(bern_gqs)
793+
assert 'method=generate_quantities' in repr(bern_gqs)
794+
assert bern_gqs.runset.chains == 1
795+
assert bern_gqs.runset._retcode(0) == 0
796+
csv_file = bern_gqs.runset.csv_files[0]
797+
assert os.path.exists(csv_file)
798+
799+
assert bern_gqs.draws().shape == (1000, 1, 10)
800+
assert bern_gqs.draws(inc_sample=True).shape == (1000, 1, 14)
801+
802+
# draws_pd()
803+
assert bern_gqs.draws_pd().shape == (1000, 13)
804+
805+
# stan_variable
806+
theta = bern_gqs.stan_variable(var='theta')
807+
assert theta.shape == (1000,)
808+
y_rep = bern_gqs.stan_variable(var='y_rep')
809+
assert y_rep.shape == (1000, 10)
810+
811+
812+
def test_from_laplace() -> None:
813+
stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
814+
bern_model = CmdStanModel(stan_file=stan)
815+
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
816+
bern_fit = bern_model.laplace_sample(
817+
data=jdata,
818+
seed=12345,
819+
)
820+
821+
# gq_model
822+
stan = os.path.join(DATAFILES_PATH, 'bernoulli_ppc.stan')
823+
model = CmdStanModel(stan_file=stan)
824+
825+
bern_gqs = model.generate_quantities(data=jdata, previous_fit=bern_fit)
826+
827+
assert bern_gqs.runset._args.method == Method.GENERATE_QUANTITIES
828+
assert 'CmdStanGQ: model=bernoulli_ppc' in repr(bern_gqs)
829+
assert 'method=generate_quantities' in repr(bern_gqs)
830+
assert bern_gqs.runset.chains == 1
831+
assert bern_gqs.runset._retcode(0) == 0
832+
csv_file = bern_gqs.runset.csv_files[0]
833+
assert os.path.exists(csv_file)
834+
835+
assert bern_gqs.draws().shape == (1000, 1, 10)
836+
assert bern_gqs.draws(inc_sample=True).shape == (1000, 1, 13)
837+
838+
# draws_pd()
839+
assert bern_gqs.draws_pd().shape == (1000, 13)
840+
assert (
841+
bern_gqs.draws_pd(inc_sample=True).shape[1]
842+
== bern_gqs.previous_fit.draws_pd().shape[1]
843+
+ bern_gqs.draws_pd().shape[1]
844+
)
845+
846+
# stan_variable
847+
theta = bern_gqs.stan_variable(var='theta')
848+
assert theta.shape == (1000,)
849+
y_rep = bern_gqs.stan_variable(var='y_rep')
850+
assert y_rep.shape == (1000, 10)

0 commit comments

Comments
 (0)