Skip to content

Commit 9f33f12

Browse files
authored
Merge pull request #602 from stan-dev/issue/600-summary-report
Cleanup output from cmdstan/bin/stansummary
2 parents 26374f8 + 93aec5f commit 9f33f12

File tree

2 files changed

+21
-8
lines changed

2 files changed

+21
-8
lines changed

cmdstanpy/stanfit/mcmc.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -448,11 +448,11 @@ def summary(
448448
) -> pd.DataFrame:
449449
"""
450450
Run cmdstan/bin/stansummary over all output CSV files, assemble
451-
summary into DataFrame object; first row contains summary statistics
452-
for total joint log probability `lp__`, remaining rows contain summary
451+
summary into DataFrame object. The first row contains statistics
452+
for the total joint log probability `lp__`, but is omitted when the
453+
Stan model has no parameters. The remaining rows contain summary
453454
statistics for all parameters, transformed parameters, and generated
454-
quantities variables listed in the order in which they were declared
455-
in the Stan program.
455+
quantities variables, in program declaration order.
456456
457457
:param percentiles: Ordered non-empty sequence of percentiles to report.
458458
Must be integers from (1, 99), inclusive. Defaults to
@@ -467,7 +467,6 @@ def summary(
467467
468468
:return: pandas.DataFrame
469469
"""
470-
471470
if len(percentiles) == 0:
472471
raise ValueError(
473472
'Invalid percentiles argument, must be ordered'
@@ -526,7 +525,14 @@ def summary(
526525
comment='#',
527526
float_precision='high',
528527
)
529-
mask = [x == 'lp__' or not x.endswith('__') for x in summary_data.index]
528+
mask = (
529+
[not x.endswith('__') for x in summary_data.index]
530+
if self._is_fixed_param
531+
else [
532+
x == 'lp__' or not x.endswith('__') for x in summary_data.index
533+
]
534+
)
535+
summary_data.index.name = None
530536
return summary_data[mask]
531537

532538
def diagnose(self) -> Optional[str]:

test/test_sample.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,8 @@ def test_fixed_param_unspecified(self):
598598
iter_sampling=100, show_progress=False
599599
)
600600
self.assertEqual(datagen_fit.step_size, None)
601+
summary = datagen_fit.summary()
602+
self.assertNotIn('lp__', list(summary.index))
601603

602604
exe_only = os.path.join(DATAFILES_PATH, 'exe_only')
603605
shutil.copyfile(datagen_model.exe_file, exe_only)
@@ -608,6 +610,8 @@ def test_fixed_param_unspecified(self):
608610
)
609611
self.assertEqual(datagen2_fit.chains, 4)
610612
self.assertEqual(datagen2_fit.step_size, None)
613+
summary = datagen2_fit.summary()
614+
self.assertNotIn('lp__', list(summary.index))
611615

612616
def test_bernoulli_file_with_space(self):
613617
self.test_bernoulli_good('bernoulli with space in name.stan')
@@ -743,11 +747,11 @@ def test_validate_good_run(self):
743747

744748
self.assertEqual(
745749
list(fit.draws_pd(vars=['theta', 'lp__']).columns),
746-
['theta', 'lp__']
750+
['theta', 'lp__'],
747751
)
748752
self.assertEqual(
749753
list(fit.draws_pd(vars=['lp__', 'theta']).columns),
750-
['lp__', 'theta']
754+
['lp__', 'theta'],
751755
)
752756

753757
summary = fit.summary()
@@ -756,6 +760,9 @@ def test_validate_good_run(self):
756760
self.assertIn('95%', list(summary.columns))
757761
self.assertNotIn('1%', list(summary.columns))
758762
self.assertNotIn('99%', list(summary.columns))
763+
self.assertEqual(summary.index.name, None)
764+
self.assertIn('lp__', list(summary.index))
765+
self.assertIn('theta', list(summary.index))
759766

760767
summary = fit.summary(percentiles=[1, 45, 99])
761768
self.assertIn('1%', list(summary.columns))

0 commit comments

Comments
 (0)