1616from io import StringIO
1717from multiprocessing import cpu_count
1818from pathlib import Path
19- from typing import Any , Callable , Dict , Iterable , List , Mapping , Optional , Union
19+ from typing import (
20+ Any ,
21+ Callable ,
22+ Dict ,
23+ Iterable ,
24+ List ,
25+ Mapping ,
26+ Optional ,
27+ TypeVar ,
28+ Union ,
29+ )
2030
2131import pandas as pd
2232from tqdm .auto import tqdm
5868from . import progress as progbar
5969
6070OptionalPath = Union [str , os .PathLike , None ]
71+ Fit = TypeVar ('Fit' , CmdStanMCMC , CmdStanMLE , CmdStanVB )
6172
6273
6374class CmdStanModel :
@@ -1202,22 +1213,26 @@ def sample(
12021213 def generate_quantities (
12031214 self ,
12041215 data : Union [Mapping [str , Any ], str , os .PathLike , None ] = None ,
1205- mcmc_sample : Union [CmdStanMCMC , List [str ], None ] = None ,
1216+ previous_fit : Union [Fit , List [str ], None ] = None ,
12061217 seed : Optional [int ] = None ,
12071218 gq_output_dir : OptionalPath = None ,
12081219 sig_figs : Optional [int ] = None ,
12091220 show_console : bool = False ,
12101221 refresh : Optional [int ] = None ,
12111222 time_fmt : str = "%Y%m%d%H%M%S" ,
12121223 timeout : Optional [float ] = None ,
1213- ) -> CmdStanGQ :
1224+ * ,
1225+ mcmc_sample : Union [CmdStanMCMC , List [str ], None ] = None ,
1226+ ) -> CmdStanGQ [Fit ]:
12141227 """
12151228 Run CmdStan's generate_quantities method which runs the generated
12161229 quantities block of a model given an existing sample.
12171230
1218- This function takes a :class:`CmdStanMCMC` object and the dataset used
1219- to generate that sample and calls to the CmdStan ``generate_quantities``
1220- method to generate additional quantities of interest.
1231+ This function takes one of the Stan fit objects
1232+ :class:`CmdStanMCMC`, :class:`CmdStanMLE`, or :class:`CmdStanVB`
1233+ and the data required for the model and calls to the CmdStan
1234+ ``generate_quantities`` method to generate additional quantities of
1235+ interest.
12211236
12221237 The :class:`CmdStanGQ` object records the command, the return code,
12231238 and the paths to the generate method output CSV and console files.
@@ -1236,9 +1251,10 @@ def generate_quantities(
12361251 either as a dictionary with entries matching the data variables,
12371252 or as the path of a data file in JSON or Rdump format.
12381253
1239- :param mcmc_sample: Can be either a :class:`CmdStanMCMC` object returned
1240- by the :meth:`sample` method or a list of stan-csv files generated
1241- by fitting the model to the data using any Stan interface.
1254+ :param previous_fit: Can be either a :class:`CmdStanMCMC`,
1255+ :class:`CmdStanMLE`, or :class:`CmdStanVB` or a list of
1256+ stan-csv files generated by fitting the model to the data
1257+ using any Stan interface.
12421258
12431259 :param seed: The seed for random number generator. Must be an integer
12441260 between 0 and 2^32 - 1. If unspecified,
@@ -1272,39 +1288,64 @@ def generate_quantities(
12721288
12731289 :return: CmdStanGQ object
12741290 """
1275- if isinstance (mcmc_sample , CmdStanMCMC ):
1276- mcmc_fit = mcmc_sample
1277- sample_csv_files = mcmc_sample .runset .csv_files
1278- elif isinstance (mcmc_sample , list ):
1279- if len (mcmc_sample ) < 1 :
1291+ if mcmc_sample is not None :
1292+ if previous_fit :
1293+ raise ValueError (
1294+ "Cannot supply both 'previous_fit' and "
1295+ "deprecated argument 'mcmc_sample'"
1296+ )
1297+ get_logger ().warning (
1298+ "Argument name `mcmc_sample` is deprecated, please "
1299+ "rename to `previous_fit`."
1300+ )
1301+
1302+ previous_fit = mcmc_sample # type: ignore
1303+
1304+ if isinstance (previous_fit , (CmdStanMCMC , CmdStanMLE , CmdStanVB )):
1305+ fit_object = previous_fit
1306+ fit_csv_files = previous_fit .runset .csv_files
1307+ elif isinstance (previous_fit , list ):
1308+ if len (previous_fit ) < 1 :
12801309 raise ValueError (
12811310 'Expecting list of Stan CSV files, found empty list'
12821311 )
12831312 try :
1284- sample_csv_files = mcmc_sample
1285- sample_fit = from_csv (sample_csv_files )
1286- mcmc_fit = sample_fit # type: ignore
1313+ fit_csv_files = previous_fit
1314+ fit_object = from_csv (fit_csv_files ) # type: ignore
12871315 except ValueError as e :
12881316 raise ValueError (
12891317 'Invalid sample from Stan CSV files, error:\n \t {}\n \t '
12901318 ' while processing files\n \t {}' .format (
1291- repr (e ), '\n \t ' .join (mcmc_sample )
1319+ repr (e ), '\n \t ' .join (previous_fit )
12921320 )
12931321 ) from e
12941322 else :
12951323 raise ValueError (
1296- 'MCMC sample must be either CmdStanMCMC object'
1297- ' or list of paths to sample Stan CSV files.'
1298- )
1299- chains = mcmc_fit .chains
1300- chain_ids = mcmc_fit .chain_ids
1301- if mcmc_fit .metadata .cmdstan_config ['save_warmup' ]:
1302- get_logger ().warning (
1303- 'Sample contains saved warmup draws which will be used '
1304- 'to generate additional quantities of interest.'
1324+ 'Previous fit must be either CmdStanPy fit object'
1325+ ' or list of paths to Stan CSV files.'
13051326 )
1327+ if isinstance (fit_object , CmdStanMCMC ):
1328+ chains = fit_object .chains
1329+ chain_ids = fit_object .chain_ids
1330+ if fit_object ._save_warmup :
1331+ get_logger ().warning (
1332+ 'Sample contains saved warmup draws which will be used '
1333+ 'to generate additional quantities of interest.'
1334+ )
1335+ elif isinstance (fit_object , CmdStanMLE ):
1336+ chains = 1
1337+ chain_ids = [1 ]
1338+ if fit_object ._save_iterations :
1339+ get_logger ().warning (
1340+ 'MLE contains saved iterations which will be used '
1341+ 'to generate additional quantities of interest.'
1342+ )
1343+ else : # isinstance(fit_object, CmdStanVB)
1344+ chains = 1
1345+ chain_ids = [1 ]
1346+
13061347 generate_quantities_args = GenerateQuantitiesArgs (
1307- csv_files = sample_csv_files
1348+ csv_files = fit_csv_files
13081349 )
13091350 generate_quantities_args .validate (chains )
13101351 with MaybeDictToFilePath (data , None ) as (_data , _inits ):
@@ -1345,7 +1386,7 @@ def generate_quantities(
13451386 + ' output is unclear!'
13461387 )
13471388 raise RuntimeError (msg )
1348- quantities = CmdStanGQ (runset = runset , mcmc_sample = mcmc_fit )
1389+ quantities = CmdStanGQ (runset = runset , previous_fit = fit_object )
13491390 return quantities
13501391
13511392 def variational (
0 commit comments