Skip to content

Commit 902ec71

Browse files
committed
Update default precision for summary()
1 parent ee3692e commit 902ec71

File tree

1 file changed

+34
-34
lines changed

1 file changed

+34
-34
lines changed

cmdstanpy/stanfit/mcmc.py

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
List,
1515
MutableMapping,
1616
Optional,
17+
Sequence,
1718
Tuple,
1819
Union,
1920
)
@@ -382,8 +383,8 @@ def _assemble_draws(self) -> None:
382383

383384
def summary(
384385
self,
385-
percentiles: Optional[List[int]] = None,
386-
sig_figs: Optional[int] = None,
386+
percentiles: Sequence[int] = (5, 50, 95),
387+
sig_figs: int = 6,
387388
) -> pd.DataFrame:
388389
"""
389390
Run cmdstan/bin/stansummary over all output CSV files, assemble
@@ -393,8 +394,9 @@ def summary(
393394
quantities variables listed in the order in which they were declared
394395
in the Stan program.
395396
396-
:param percentiles: Ordered non-empty list of percentiles to report.
397-
Must be integers from (1, 99), inclusive.
397+
:param percentiles: Ordered non-empty sequence of percentiles to report.
398+
Must be integers from (1, 99), inclusive. Defaults to
399+
``(5, 50, 95)``
398400
399401
:param sig_figs: Number of significant figures to report.
400402
Must be an integer between 1 and 18. If unspecified, the default
@@ -405,40 +407,38 @@ def summary(
405407
406408
:return: pandas.DataFrame
407409
"""
408-
percentiles_str = '--percentiles=5,50,95'
409-
if percentiles is not None:
410-
if len(percentiles) == 0:
410+
411+
if len(percentiles) == 0:
412+
raise ValueError(
413+
'Invalid percentiles argument, must be ordered'
414+
' non-empty list from (1, 99), inclusive.'
415+
)
416+
cur_pct = 0
417+
for pct in percentiles:
418+
if pct > 99 or not pct > cur_pct:
411419
raise ValueError(
412-
'Invalid percentiles argument, must be ordered'
420+
'Invalid percentiles spec, must be ordered'
413421
' non-empty list from (1, 99), inclusive.'
414422
)
415-
cur_pct = 0
416-
for pct in percentiles:
417-
if pct > 99 or not pct > cur_pct:
418-
raise ValueError(
419-
'Invalid percentiles spec, must be ordered'
420-
' non-empty list from (1, 99), inclusive.'
421-
)
422-
cur_pct = pct
423-
percentiles_str = '='.join(
424-
['--percentiles', ','.join([str(x) for x in percentiles])]
423+
cur_pct = pct
424+
percentiles_str = (
425+
f"--percentiles= {','.join(str(x) for x in percentiles)}"
426+
)
427+
428+
if not isinstance(sig_figs, int) or sig_figs < 1 or sig_figs > 18:
429+
raise ValueError(
430+
'Keyword "sig_figs" must be an integer between 1 and 18,'
431+
' found {}'.format(sig_figs)
425432
)
426-
sig_figs_str = '--sig_figs=2'
427-
if sig_figs is not None:
428-
if not isinstance(sig_figs, int) or sig_figs < 1 or sig_figs > 18:
429-
raise ValueError(
430-
'Keyword "sig_figs" must be an integer between 1 and 18,'
431-
' found {}'.format(sig_figs)
432-
)
433-
csv_sig_figs = self._sig_figs or 6
434-
if sig_figs > csv_sig_figs:
435-
get_logger().warning(
436-
'Requesting %d significant digits of output, but CSV files'
437-
' only have %d digits of precision.',
438-
sig_figs,
439-
csv_sig_figs,
440-
)
441-
sig_figs_str = '--sig_figs=' + str(sig_figs)
433+
csv_sig_figs = self._sig_figs or 6
434+
if sig_figs > csv_sig_figs:
435+
get_logger().warning(
436+
'Requesting %d significant digits of output, but CSV files'
437+
' only have %d digits of precision.',
438+
sig_figs,
439+
csv_sig_figs,
440+
)
441+
sig_figs_str = f'--sig_figs={sig_figs}'
442442
cmd_path = os.path.join(
443443
cmdstan_path(), 'bin', 'stansummary' + EXTENSION
444444
)

0 commit comments

Comments
 (0)