Skip to content

Commit 6edb7ad

Browse files
authored
Merge pull request #543 from stan-dev/stanvariables-dot
Support `.` syntax for Stan variable access
2 parents ee3692e + 5c38497 commit 6edb7ad

File tree

9 files changed

+173
-17
lines changed

9 files changed

+173
-17
lines changed

cmdstanpy/stanfit/mcmc.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,14 @@ def __repr__(self) -> str:
117117
# TODO - hamiltonian, profiling files
118118
return repr
119119

120+
def __getattr__(self, attr: str) -> np.ndarray:
121+
"""Synonymous with ``fit.stan_variable(attr)"""
122+
try:
123+
return self.stan_variable(attr)
124+
except ValueError as e:
125+
# pylint: disable=raise-missing-from
126+
raise AttributeError(*e.args)
127+
120128
@property
121129
def chains(self) -> int:
122130
"""Number of chains."""
@@ -619,7 +627,7 @@ def draws_xr(
619627

620628
def stan_variable(
621629
self,
622-
var: Optional[str] = None,
630+
var: str,
623631
inc_warmup: bool = False,
624632
) -> np.ndarray:
625633
"""
@@ -647,6 +655,9 @@ def stan_variable(
647655
and the sample consists of 4 chains with 1000 post-warmup draws,
648656
this function will return a numpy.ndarray with shape (4000,3,3).
649657
658+
This functionaltiy is also available via a shortcut using ``.`` -
659+
writing ``fit.a`` is a synonym for ``fit.stan_variable("a")``
660+
650661
:param var: variable name
651662
652663
:param inc_warmup: When ``True`` and the warmup draws are present in
@@ -660,10 +671,12 @@ def stan_variable(
660671
CmdStanVB.stan_variable
661672
CmdStanGQ.stan_variable
662673
"""
663-
if var is None:
664-
raise ValueError('No variable name specified.')
665674
if var not in self._metadata.stan_vars_dims:
666-
raise ValueError('Unknown variable name: {}'.format(var))
675+
raise ValueError(
676+
f'Unknown variable name: {var}\n'
677+
'Available variables are '
678+
+ ", ".join(self._metadata.stan_vars_dims)
679+
)
667680
if self._draws.shape == (0,):
668681
self._assemble_draws()
669682
draw1 = 0
@@ -767,6 +780,14 @@ def __repr__(self) -> str:
767780
)
768781
return repr
769782

783+
def __getattr__(self, attr: str) -> np.ndarray:
784+
"""Synonymous with ``fit.stan_variable(attr)"""
785+
try:
786+
return self.stan_variable(attr)
787+
except ValueError as e:
788+
# pylint: disable=raise-missing-from
789+
raise AttributeError(*e.args)
790+
770791
def _validate_csv_files(self) -> Dict[str, Any]:
771792
"""
772793
Checks that Stan CSV output files for all chains are consistent
@@ -1130,7 +1151,7 @@ def draws_xr(
11301151

11311152
def stan_variable(
11321153
self,
1133-
var: Optional[str] = None,
1154+
var: str,
11341155
inc_warmup: bool = False,
11351156
) -> np.ndarray:
11361157
"""
@@ -1158,6 +1179,9 @@ def stan_variable(
11581179
and the sample consists of 4 chains with 1000 post-warmup draws,
11591180
this function will return a numpy.ndarray with shape (4000,3,3).
11601181
1182+
This functionaltiy is also available via a shortcut using ``.`` -
1183+
writing ``fit.a`` is a synonym for ``fit.stan_variable("a")``
1184+
11611185
:param var: variable name
11621186
11631187
:param inc_warmup: When ``True`` and the warmup draws are present in
@@ -1171,12 +1195,14 @@ def stan_variable(
11711195
CmdStanMLE.stan_variable
11721196
CmdStanVB.stan_variable
11731197
"""
1174-
if var is None:
1175-
raise ValueError('No variable name specified.')
11761198
model_var_names = self.mcmc_sample.metadata.stan_vars_cols.keys()
11771199
gq_var_names = self.metadata.stan_vars_cols.keys()
11781200
if not (var in model_var_names or var in gq_var_names):
1179-
raise ValueError('Unknown variable name: {}'.format(var))
1201+
raise ValueError(
1202+
f'Unknown variable name: {var}\n'
1203+
'Available variables are '
1204+
+ ", ".join(model_var_names | gq_var_names)
1205+
)
11801206
if var not in gq_var_names:
11811207
return self.mcmc_sample.stan_variable(var, inc_warmup=inc_warmup)
11821208
else: # is gq variable

cmdstanpy/stanfit/mle.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,14 @@ def __repr__(self) -> str:
5050
repr = '{} optimization failed to converge.'.format(repr)
5151
return repr
5252

53+
def __getattr__(self, attr: str) -> Union[np.ndarray, float]:
54+
"""Synonymous with ``fit.stan_variable(attr)"""
55+
try:
56+
return self.stan_variable(attr)
57+
except ValueError as e:
58+
# pylint: disable=raise-missing-from
59+
raise AttributeError(*e.args)
60+
5361
def _set_mle_attrs(self, sample_csv_0: str) -> None:
5462
meta = scan_optimize_csv(sample_csv_0, self._save_iterations)
5563
self._metadata = InferenceMetadata(meta)
@@ -155,7 +163,7 @@ def optimized_params_dict(self) -> Dict[str, float]:
155163

156164
def stan_variable(
157165
self,
158-
var: Optional[str] = None,
166+
var: str,
159167
*,
160168
inc_iterations: bool = False,
161169
warn: bool = True,
@@ -165,6 +173,9 @@ def stan_variable(
165173
for the named Stan program variable where the dimensions of the
166174
numpy.ndarray match the shape of the Stan program variable.
167175
176+
This functionaltiy is also available via a shortcut using ``.`` -
177+
writing ``fit.a`` is a synonym for ``fit.stan_variable("a")``
178+
168179
:param var: variable name
169180
170181
:param inc_iterations: When ``True`` and the intermediate estimates
@@ -179,10 +190,12 @@ def stan_variable(
179190
CmdStanVB.stan_variable
180191
CmdStanGQ.stan_variable
181192
"""
182-
if var is None:
183-
raise ValueError('no variable name specified.')
184193
if var not in self._metadata.stan_vars_dims:
185-
raise ValueError('unknown variable name: {}'.format(var))
194+
raise ValueError(
195+
f'Unknown variable name: {var}\n'
196+
'Available variables are '
197+
+ ", ".join(self._metadata.stan_vars_dims)
198+
)
186199
if warn and inc_iterations and not self._save_iterations:
187200
get_logger().warning(
188201
'Intermediate iterations not saved to CSV output file. '

cmdstanpy/stanfit/vb.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,14 @@ def __repr__(self) -> str:
4141
# TODO - diagnostic, profiling files
4242
return repr
4343

44+
def __getattr__(self, attr: str) -> Union[np.ndarray, float]:
45+
"""Synonymous with ``fit.stan_variable(attr)"""
46+
try:
47+
return self.stan_variable(attr)
48+
except ValueError as e:
49+
# pylint: disable=raise-missing-from
50+
raise AttributeError(*e.args)
51+
4452
def _set_variational_attrs(self, sample_csv_0: str) -> None:
4553
meta = scan_variational_csv(sample_csv_0)
4654
self._metadata = InferenceMetadata(meta)
@@ -103,14 +111,15 @@ def metadata(self) -> InferenceMetadata:
103111
"""
104112
return self._metadata
105113

106-
def stan_variable(
107-
self, var: Optional[str] = None
108-
) -> Union[np.ndarray, float]:
114+
def stan_variable(self, var: str) -> Union[np.ndarray, float]:
109115
"""
110116
Return a numpy.ndarray which contains the estimates for the
111117
for the named Stan program variable where the dimensions of the
112118
numpy.ndarray match the shape of the Stan program variable.
113119
120+
This functionaltiy is also available via a shortcut using ``.`` -
121+
writing ``fit.a`` is a synonym for ``fit.stan_variable("a")``
122+
114123
:param var: variable name
115124
116125
See Also
@@ -123,7 +132,11 @@ def stan_variable(
123132
if var is None:
124133
raise ValueError('No variable name specified.')
125134
if var not in self._metadata.stan_vars_dims:
126-
raise ValueError('Unknown variable name: {}'.format(var))
135+
raise ValueError(
136+
f'Unknown variable name: {var}\n'
137+
'Available variables are '
138+
+ ", ".join(self._metadata.stan_vars_dims)
139+
)
127140
col_idxs = list(self._metadata.stan_vars_cols[var])
128141
shape: Tuple[int, ...] = ()
129142
if len(col_idxs) > 1:

cmdstanpy/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -730,7 +730,7 @@ def scan_variational_csv(path: str) -> Dict[str, Any]:
730730
lineno += 1
731731
xs = line.split(',')
732732
variational_mean = [float(x) for x in xs]
733-
dict['variational_mean'] = variational_mean
733+
dict['variational_mean'] = np.array(variational_mean)
734734
dict['variational_sample'] = pd.read_csv(
735735
path,
736736
comment='#',

test/data/named_output.stan

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
data {
2+
int<lower=0> N;
3+
int<lower=0,upper=1> y[N];
4+
}
5+
parameters {
6+
real<lower=0,upper=1> theta;
7+
}
8+
model {
9+
theta ~ beta(1,1); // uniform prior on interval 0,1
10+
y ~ bernoulli(theta);
11+
}
12+
13+
generated quantities {
14+
// these should be accessible via .
15+
real a = 4.5;
16+
array[3] real b = {1, 2.5, 4.5};
17+
18+
// these should not override built in properties/funs
19+
real thin = 3.5;
20+
int draws = 0;
21+
int optimized_params_np = 0;
22+
int variational_params_np = 0;
23+
}

test/test_generate_quantities.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,26 @@ def test_complex_output(self):
444444
fit.draws_xr().z.isel(chain=0, draw=1).data[()], 3 + 4j
445445
)
446446

447+
def test_attrs(self):
448+
stan_bern = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
449+
model_bern = CmdStanModel(stan_file=stan_bern)
450+
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
451+
fit_sampling = model_bern.sample(chains=1, iter_sampling=10, data=jdata)
452+
453+
stan = os.path.join(DATAFILES_PATH, 'named_output.stan')
454+
model = CmdStanModel(stan_file=stan)
455+
fit = model.generate_quantities(data=jdata, mcmc_sample=fit_sampling)
456+
457+
self.assertEqual(fit.a[0], 4.5)
458+
self.assertEqual(fit.b.shape, (10, 3))
459+
self.assertEqual(fit.theta.shape, (10,))
460+
461+
fit.draws()
462+
self.assertEqual(fit.stan_variable('draws')[0], 0)
463+
464+
with self.assertRaisesRegex(AttributeError, 'Unknown variable name:'):
465+
dummy = fit.c
466+
447467

448468
if __name__ == '__main__':
449469
unittest.main()

test/test_optimize.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,24 @@ def test_complex_output(self):
609609
# make sure the name 'imag' isn't magic
610610
self.assertEqual(fit.stan_variable('imag').shape, (2,))
611611

612+
def test_attrs(self):
613+
stan = os.path.join(DATAFILES_PATH, 'named_output.stan')
614+
model = CmdStanModel(stan_file=stan)
615+
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
616+
fit = model.optimize(data=jdata)
617+
618+
self.assertEqual(fit.a, 4.5)
619+
self.assertEqual(fit.b.shape, (3,))
620+
self.assertIsInstance(fit.theta, float)
621+
622+
self.assertEqual(fit.stan_variable('thin'), 3.5)
623+
624+
self.assertIsInstance(fit.optimized_params_np, np.ndarray)
625+
self.assertEqual(fit.stan_variable('optimized_params_np'), 0)
626+
627+
with self.assertRaisesRegex(AttributeError, 'Unknown variable name:'):
628+
dummy = fit.c
629+
612630

613631
if __name__ == '__main__':
614632
unittest.main()

test/test_sample.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1776,6 +1776,25 @@ def test_complex_output(self):
17761776
fit.draws_xr().z.isel(chain=0, draw=1).data[()], 3 + 4j
17771777
)
17781778

1779+
def test_attrs(self):
1780+
stan = os.path.join(DATAFILES_PATH, 'named_output.stan')
1781+
model = CmdStanModel(stan_file=stan)
1782+
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
1783+
fit = model.sample(chains=1, iter_sampling=10, data=jdata)
1784+
1785+
self.assertEqual(fit.a[0], 4.5)
1786+
self.assertEqual(fit.b.shape, (10, 3))
1787+
self.assertEqual(fit.theta.shape, (10,))
1788+
1789+
self.assertEqual(fit.thin, 1)
1790+
self.assertEqual(fit.stan_variable('thin')[0], 3.5)
1791+
1792+
fit.draws()
1793+
self.assertEqual(fit.stan_variable('draws')[0], 0)
1794+
1795+
with self.assertRaisesRegex(AttributeError, 'Unknown variable name:'):
1796+
dummy = fit.c
1797+
17791798

17801799
if __name__ == '__main__':
17811800
unittest.main()

test/test_variational.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import unittest
88
from math import fabs
99

10+
import numpy as np
1011
from testfixtures import LogCapture
1112

1213
from cmdstanpy.cmdstan_args import CmdStanArgs, VariationalArgs
@@ -264,6 +265,29 @@ def test_complex_output(self):
264265
# make sure the name 'imag' isn't magic
265266
self.assertEqual(fit.stan_variable('imag').shape, (2,))
266267

268+
def test_attrs(self):
269+
stan = os.path.join(DATAFILES_PATH, 'named_output.stan')
270+
model = CmdStanModel(stan_file=stan)
271+
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
272+
fit = model.variational(
273+
data=jdata,
274+
require_converged=False,
275+
seed=12345,
276+
algorithm='meanfield',
277+
)
278+
279+
self.assertEqual(fit.a, 4.5)
280+
self.assertEqual(fit.b.shape, (3,))
281+
self.assertIsInstance(fit.theta, float)
282+
283+
self.assertEqual(fit.stan_variable('thin'), 3.5)
284+
285+
self.assertIsInstance(fit.variational_params_np, np.ndarray)
286+
self.assertEqual(fit.stan_variable('variational_params_np'), 0)
287+
288+
with self.assertRaisesRegex(AttributeError, 'Unknown variable name:'):
289+
dummy = fit.c
290+
267291

268292
if __name__ == '__main__':
269293
unittest.main()

0 commit comments

Comments
 (0)