Skip to content

Commit cc6daf2

Browse files
authored
Merge pull request #551 from stan-dev/fix/stansummary-defaults
Update default precision for summary()
2 parents 6edb7ad + 902ec71 commit cc6daf2

File tree

1 file changed

+34
-34
lines changed

1 file changed

+34
-34
lines changed

cmdstanpy/stanfit/mcmc.py

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
List,
1515
MutableMapping,
1616
Optional,
17+
Sequence,
1718
Tuple,
1819
Union,
1920
)
@@ -390,8 +391,8 @@ def _assemble_draws(self) -> None:
390391

391392
def summary(
392393
self,
393-
percentiles: Optional[List[int]] = None,
394-
sig_figs: Optional[int] = None,
394+
percentiles: Sequence[int] = (5, 50, 95),
395+
sig_figs: int = 6,
395396
) -> pd.DataFrame:
396397
"""
397398
Run cmdstan/bin/stansummary over all output CSV files, assemble
@@ -401,8 +402,9 @@ def summary(
401402
quantities variables listed in the order in which they were declared
402403
in the Stan program.
403404
404-
:param percentiles: Ordered non-empty list of percentiles to report.
405-
Must be integers from (1, 99), inclusive.
405+
:param percentiles: Ordered non-empty sequence of percentiles to report.
406+
Must be integers from (1, 99), inclusive. Defaults to
407+
``(5, 50, 95)``
406408
407409
:param sig_figs: Number of significant figures to report.
408410
Must be an integer between 1 and 18. If unspecified, the default
@@ -413,40 +415,38 @@ def summary(
413415
414416
:return: pandas.DataFrame
415417
"""
416-
percentiles_str = '--percentiles=5,50,95'
417-
if percentiles is not None:
418-
if len(percentiles) == 0:
418+
419+
if len(percentiles) == 0:
420+
raise ValueError(
421+
'Invalid percentiles argument, must be ordered'
422+
' non-empty list from (1, 99), inclusive.'
423+
)
424+
cur_pct = 0
425+
for pct in percentiles:
426+
if pct > 99 or not pct > cur_pct:
419427
raise ValueError(
420-
'Invalid percentiles argument, must be ordered'
428+
'Invalid percentiles spec, must be ordered'
421429
' non-empty list from (1, 99), inclusive.'
422430
)
423-
cur_pct = 0
424-
for pct in percentiles:
425-
if pct > 99 or not pct > cur_pct:
426-
raise ValueError(
427-
'Invalid percentiles spec, must be ordered'
428-
' non-empty list from (1, 99), inclusive.'
429-
)
430-
cur_pct = pct
431-
percentiles_str = '='.join(
432-
['--percentiles', ','.join([str(x) for x in percentiles])]
431+
cur_pct = pct
432+
percentiles_str = (
433+
f"--percentiles= {','.join(str(x) for x in percentiles)}"
434+
)
435+
436+
if not isinstance(sig_figs, int) or sig_figs < 1 or sig_figs > 18:
437+
raise ValueError(
438+
'Keyword "sig_figs" must be an integer between 1 and 18,'
439+
' found {}'.format(sig_figs)
433440
)
434-
sig_figs_str = '--sig_figs=2'
435-
if sig_figs is not None:
436-
if not isinstance(sig_figs, int) or sig_figs < 1 or sig_figs > 18:
437-
raise ValueError(
438-
'Keyword "sig_figs" must be an integer between 1 and 18,'
439-
' found {}'.format(sig_figs)
440-
)
441-
csv_sig_figs = self._sig_figs or 6
442-
if sig_figs > csv_sig_figs:
443-
get_logger().warning(
444-
'Requesting %d significant digits of output, but CSV files'
445-
' only have %d digits of precision.',
446-
sig_figs,
447-
csv_sig_figs,
448-
)
449-
sig_figs_str = '--sig_figs=' + str(sig_figs)
441+
csv_sig_figs = self._sig_figs or 6
442+
if sig_figs > csv_sig_figs:
443+
get_logger().warning(
444+
'Requesting %d significant digits of output, but CSV files'
445+
' only have %d digits of precision.',
446+
sig_figs,
447+
csv_sig_figs,
448+
)
449+
sig_figs_str = f'--sig_figs={sig_figs}'
450450
cmd_path = os.path.join(
451451
cmdstan_path(), 'bin', 'stansummary' + EXTENSION
452452
)

0 commit comments

Comments
 (0)