@@ -771,3 +771,80 @@ def test_vb_xarray() -> None:
771771 bern_gqs = model .generate_quantities (data = jdata , previous_fit = bern_fit )
772772 with pytest .raises (RuntimeError , match = "via Sampling" ):
773773 _ = bern_gqs .draws_xr ()
774+
775+
776+ def test_from_pathfinder () -> None :
777+ stan = os .path .join (DATAFILES_PATH , 'bernoulli.stan' )
778+ bern_model = CmdStanModel (stan_file = stan )
779+ jdata = os .path .join (DATAFILES_PATH , 'bernoulli.data.json' )
780+ bern_fit = bern_model .pathfinder (
781+ data = jdata ,
782+ seed = 12345 ,
783+ )
784+
785+ # gq_model
786+ stan = os .path .join (DATAFILES_PATH , 'bernoulli_ppc.stan' )
787+ model = CmdStanModel (stan_file = stan )
788+
789+ bern_gqs = model .generate_quantities (data = jdata , previous_fit = bern_fit )
790+
791+ assert bern_gqs .runset ._args .method == Method .GENERATE_QUANTITIES
792+ assert 'CmdStanGQ: model=bernoulli_ppc' in repr (bern_gqs )
793+ assert 'method=generate_quantities' in repr (bern_gqs )
794+ assert bern_gqs .runset .chains == 1
795+ assert bern_gqs .runset ._retcode (0 ) == 0
796+ csv_file = bern_gqs .runset .csv_files [0 ]
797+ assert os .path .exists (csv_file )
798+
799+ assert bern_gqs .draws ().shape == (1000 , 1 , 10 )
800+ assert bern_gqs .draws (inc_sample = True ).shape == (1000 , 1 , 14 )
801+
802+ # draws_pd()
803+ assert bern_gqs .draws_pd ().shape == (1000 , 13 )
804+
805+ # stan_variable
806+ theta = bern_gqs .stan_variable (var = 'theta' )
807+ assert theta .shape == (1000 ,)
808+ y_rep = bern_gqs .stan_variable (var = 'y_rep' )
809+ assert y_rep .shape == (1000 , 10 )
810+
811+
812+ def test_from_laplace () -> None :
813+ stan = os .path .join (DATAFILES_PATH , 'bernoulli.stan' )
814+ bern_model = CmdStanModel (stan_file = stan )
815+ jdata = os .path .join (DATAFILES_PATH , 'bernoulli.data.json' )
816+ bern_fit = bern_model .laplace_sample (
817+ data = jdata ,
818+ seed = 12345 ,
819+ )
820+
821+ # gq_model
822+ stan = os .path .join (DATAFILES_PATH , 'bernoulli_ppc.stan' )
823+ model = CmdStanModel (stan_file = stan )
824+
825+ bern_gqs = model .generate_quantities (data = jdata , previous_fit = bern_fit )
826+
827+ assert bern_gqs .runset ._args .method == Method .GENERATE_QUANTITIES
828+ assert 'CmdStanGQ: model=bernoulli_ppc' in repr (bern_gqs )
829+ assert 'method=generate_quantities' in repr (bern_gqs )
830+ assert bern_gqs .runset .chains == 1
831+ assert bern_gqs .runset ._retcode (0 ) == 0
832+ csv_file = bern_gqs .runset .csv_files [0 ]
833+ assert os .path .exists (csv_file )
834+
835+ assert bern_gqs .draws ().shape == (1000 , 1 , 10 )
836+ assert bern_gqs .draws (inc_sample = True ).shape == (1000 , 1 , 13 )
837+
838+ # draws_pd()
839+ assert bern_gqs .draws_pd ().shape == (1000 , 13 )
840+ assert (
841+ bern_gqs .draws_pd (inc_sample = True ).shape [1 ]
842+ == bern_gqs .previous_fit .draws_pd ().shape [1 ]
843+ + bern_gqs .draws_pd ().shape [1 ]
844+ )
845+
846+ # stan_variable
847+ theta = bern_gqs .stan_variable (var = 'theta' )
848+ assert theta .shape == (1000 ,)
849+ y_rep = bern_gqs .stan_variable (var = 'y_rep' )
850+ assert y_rep .shape == (1000 , 10 )
0 commit comments