Skip to content

Commit b98a7fd

Browse files
authored
[CmdStan 2.31] Support generate_quantities after non-sampling runs (#634)
* Checkpointing * Basic refactor, passing mypy * Testing new functionality and tweaking * Maintain column headers in pd * Fix docstring
1 parent 78a7fef commit b98a7fd

File tree

8 files changed

+570
-142
lines changed

8 files changed

+570
-142
lines changed

cmdstanpy/model.py

Lines changed: 70 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,17 @@
1616
from io import StringIO
1717
from multiprocessing import cpu_count
1818
from pathlib import Path
19-
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union
19+
from typing import (
20+
Any,
21+
Callable,
22+
Dict,
23+
Iterable,
24+
List,
25+
Mapping,
26+
Optional,
27+
TypeVar,
28+
Union,
29+
)
2030

2131
import pandas as pd
2232
from tqdm.auto import tqdm
@@ -58,6 +68,7 @@
5868
from . import progress as progbar
5969

6070
OptionalPath = Union[str, os.PathLike, None]
71+
Fit = TypeVar('Fit', CmdStanMCMC, CmdStanMLE, CmdStanVB)
6172

6273

6374
class CmdStanModel:
@@ -1202,22 +1213,26 @@ def sample(
12021213
def generate_quantities(
12031214
self,
12041215
data: Union[Mapping[str, Any], str, os.PathLike, None] = None,
1205-
mcmc_sample: Union[CmdStanMCMC, List[str], None] = None,
1216+
previous_fit: Union[Fit, List[str], None] = None,
12061217
seed: Optional[int] = None,
12071218
gq_output_dir: OptionalPath = None,
12081219
sig_figs: Optional[int] = None,
12091220
show_console: bool = False,
12101221
refresh: Optional[int] = None,
12111222
time_fmt: str = "%Y%m%d%H%M%S",
12121223
timeout: Optional[float] = None,
1213-
) -> CmdStanGQ:
1224+
*,
1225+
mcmc_sample: Union[CmdStanMCMC, List[str], None] = None,
1226+
) -> CmdStanGQ[Fit]:
12141227
"""
12151228
Run CmdStan's generate_quantities method which runs the generated
12161229
quantities block of a model given an existing sample.
12171230
1218-
This function takes a :class:`CmdStanMCMC` object and the dataset used
1219-
to generate that sample and calls to the CmdStan ``generate_quantities``
1220-
method to generate additional quantities of interest.
1231+
This function takes one of the Stan fit objects
1232+
:class:`CmdStanMCMC`, :class:`CmdStanMLE`, or :class:`CmdStanVB`
1233+
and the data required for the model and calls to the CmdStan
1234+
``generate_quantities`` method to generate additional quantities of
1235+
interest.
12211236
12221237
The :class:`CmdStanGQ` object records the command, the return code,
12231238
and the paths to the generate method output CSV and console files.
@@ -1236,9 +1251,10 @@ def generate_quantities(
12361251
either as a dictionary with entries matching the data variables,
12371252
or as the path of a data file in JSON or Rdump format.
12381253
1239-
:param mcmc_sample: Can be either a :class:`CmdStanMCMC` object returned
1240-
by the :meth:`sample` method or a list of stan-csv files generated
1241-
by fitting the model to the data using any Stan interface.
1254+
:param previous_fit: Can be either a :class:`CmdStanMCMC`,
1255+
:class:`CmdStanMLE`, or :class:`CmdStanVB` or a list of
1256+
stan-csv files generated by fitting the model to the data
1257+
using any Stan interface.
12421258
12431259
:param seed: The seed for random number generator. Must be an integer
12441260
between 0 and 2^32 - 1. If unspecified,
@@ -1272,39 +1288,64 @@ def generate_quantities(
12721288
12731289
:return: CmdStanGQ object
12741290
"""
1275-
if isinstance(mcmc_sample, CmdStanMCMC):
1276-
mcmc_fit = mcmc_sample
1277-
sample_csv_files = mcmc_sample.runset.csv_files
1278-
elif isinstance(mcmc_sample, list):
1279-
if len(mcmc_sample) < 1:
1291+
if mcmc_sample is not None:
1292+
if previous_fit:
1293+
raise ValueError(
1294+
"Cannot supply both 'previous_fit' and "
1295+
"deprecated argument 'mcmc_sample'"
1296+
)
1297+
get_logger().warning(
1298+
"Argument name `mcmc_sample` is deprecated, please "
1299+
"rename to `previous_fit`."
1300+
)
1301+
1302+
previous_fit = mcmc_sample # type: ignore
1303+
1304+
if isinstance(previous_fit, (CmdStanMCMC, CmdStanMLE, CmdStanVB)):
1305+
fit_object = previous_fit
1306+
fit_csv_files = previous_fit.runset.csv_files
1307+
elif isinstance(previous_fit, list):
1308+
if len(previous_fit) < 1:
12801309
raise ValueError(
12811310
'Expecting list of Stan CSV files, found empty list'
12821311
)
12831312
try:
1284-
sample_csv_files = mcmc_sample
1285-
sample_fit = from_csv(sample_csv_files)
1286-
mcmc_fit = sample_fit # type: ignore
1313+
fit_csv_files = previous_fit
1314+
fit_object = from_csv(fit_csv_files) # type: ignore
12871315
except ValueError as e:
12881316
raise ValueError(
12891317
'Invalid sample from Stan CSV files, error:\n\t{}\n\t'
12901318
' while processing files\n\t{}'.format(
1291-
repr(e), '\n\t'.join(mcmc_sample)
1319+
repr(e), '\n\t'.join(previous_fit)
12921320
)
12931321
) from e
12941322
else:
12951323
raise ValueError(
1296-
'MCMC sample must be either CmdStanMCMC object'
1297-
' or list of paths to sample Stan CSV files.'
1298-
)
1299-
chains = mcmc_fit.chains
1300-
chain_ids = mcmc_fit.chain_ids
1301-
if mcmc_fit.metadata.cmdstan_config['save_warmup']:
1302-
get_logger().warning(
1303-
'Sample contains saved warmup draws which will be used '
1304-
'to generate additional quantities of interest.'
1324+
'Previous fit must be either CmdStanPy fit object'
1325+
' or list of paths to Stan CSV files.'
13051326
)
1327+
if isinstance(fit_object, CmdStanMCMC):
1328+
chains = fit_object.chains
1329+
chain_ids = fit_object.chain_ids
1330+
if fit_object._save_warmup:
1331+
get_logger().warning(
1332+
'Sample contains saved warmup draws which will be used '
1333+
'to generate additional quantities of interest.'
1334+
)
1335+
elif isinstance(fit_object, CmdStanMLE):
1336+
chains = 1
1337+
chain_ids = [1]
1338+
if fit_object._save_iterations:
1339+
get_logger().warning(
1340+
'MLE contains saved iterations which will be used '
1341+
'to generate additional quantities of interest.'
1342+
)
1343+
else: # isinstance(fit_object, CmdStanVB)
1344+
chains = 1
1345+
chain_ids = [1]
1346+
13061347
generate_quantities_args = GenerateQuantitiesArgs(
1307-
csv_files=sample_csv_files
1348+
csv_files=fit_csv_files
13081349
)
13091350
generate_quantities_args.validate(chains)
13101351
with MaybeDictToFilePath(data, None) as (_data, _inits):
@@ -1345,7 +1386,7 @@ def generate_quantities(
13451386
+ ' output is unclear!'
13461387
)
13471388
raise RuntimeError(msg)
1348-
quantities = CmdStanGQ(runset=runset, mcmc_sample=mcmc_fit)
1389+
quantities = CmdStanGQ(runset=runset, previous_fit=fit_object)
13491390
return quantities
13501391

13511392
def variational(

0 commit comments

Comments
 (0)