Skip to content

Commit 26374f8

Browse files
authored
Merge pull request #598 from stan-dev/update/column_order
Keep column order for dataframe
2 parents 5bc1040 + ef975dc commit 26374f8

File tree

4 files changed

+19
-3
lines changed

4 files changed

+19
-3
lines changed

cmdstanpy/stanfit/gq.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,7 @@ def draws_pd(
268268
vars_list = [vars]
269269
else:
270270
vars_list = vars
271+
vars_list = list(dict.fromkeys(vars_list))
271272
if (
272273
inc_warmup
273274
and not self.mcmc_sample.metadata.cmdstan_config['save_warmup']
@@ -282,7 +283,7 @@ def draws_pd(
282283
gq_cols = []
283284
mcmc_vars = []
284285
if vars is not None:
285-
for var in set(vars_list):
286+
for var in vars_list:
286287
if var in self.metadata.stan_vars_cols:
287288
for idx in self.metadata.stan_vars_cols[var]:
288289
gq_cols.append(self.column_names[idx])
@@ -295,6 +296,7 @@ def draws_pd(
295296
raise ValueError('Unknown variable: {}'.format(var))
296297
else:
297298
gq_cols = list(self.column_names)
299+
vars_list = gq_cols
298300

299301
if inc_sample and mcmc_vars:
300302
if gq_cols:
@@ -311,7 +313,7 @@ def draws_pd(
311313
)[gq_cols],
312314
],
313315
axis='columns',
314-
)
316+
)[vars_list]
315317
else:
316318
return self.mcmc_sample.draws_pd(
317319
vars=mcmc_vars, inc_warmup=inc_warmup

cmdstanpy/stanfit/mcmc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,7 @@ def draws_pd(
589589
self._assemble_draws()
590590
cols = []
591591
if vars is not None:
592-
for var in set(vars_list):
592+
for var in dict.fromkeys(vars_list):
593593
if (
594594
var not in self.metadata.method_vars_cols
595595
and var not in self.metadata.stan_vars_cols

test/test_generate_quantities.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,11 @@ def test_from_csv_files(self):
8484
+ bern_gqs.draws_pd().shape[1],
8585
)
8686

87+
self.assertEqual(
88+
list(bern_gqs.draws_pd(vars=['y_rep']).columns),
89+
column_names,
90+
)
91+
8792
def test_from_csv_files_bad(self):
8893
# gq model
8994
stan = os.path.join(DATAFILES_PATH, 'bernoulli_ppc.stan')

test/test_sample.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,15 @@ def test_validate_good_run(self):
741741
self.assertEqual(fit.draws_pd(vars=['theta', 'lp__']).shape, (400, 2))
742742
self.assertEqual(fit.draws_pd(vars='theta').shape, (400, 1))
743743

744+
self.assertEqual(
745+
list(fit.draws_pd(vars=['theta', 'lp__']).columns),
746+
['theta', 'lp__']
747+
)
748+
self.assertEqual(
749+
list(fit.draws_pd(vars=['lp__', 'theta']).columns),
750+
['lp__', 'theta']
751+
)
752+
744753
summary = fit.summary()
745754
self.assertIn('5%', list(summary.columns))
746755
self.assertIn('50%', list(summary.columns))

0 commit comments

Comments
 (0)