Skip to content

Commit 3add172

Browse files
committed
Remove deprecated stan_variable behavior for advi
1 parent f6439fb commit 3add172

File tree

2 files changed

+11
-41
lines changed

2 files changed

+11
-41
lines changed

cmdstanpy/stanfit/vb.py

Lines changed: 6 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from cmdstanpy.cmdstan_args import Method
1010
from cmdstanpy.utils import stancsv
11-
from cmdstanpy.utils.logging import get_logger
1211

1312
from .metadata import InferenceMetadata
1413
from .runset import RunSet
@@ -100,7 +99,7 @@ def __repr__(self) -> str:
10099
# TODO - diagnostic, profiling files
101100
return repr
102101

103-
def __getattr__(self, attr: str) -> Union[np.ndarray, float]:
102+
def __getattr__(self, attr: str) -> np.ndarray:
104103
"""Synonymous with ``fit.stan_variable(attr)"""
105104
if attr.startswith("_"):
106105
raise AttributeError(f"Unknown variable name {attr}")
@@ -163,9 +162,7 @@ def metadata(self) -> InferenceMetadata:
163162
"""
164163
return self._metadata
165164

166-
def stan_variable(
167-
self, var: str, *, mean: Optional[bool] = None
168-
) -> Union[np.ndarray, float]:
165+
def stan_variable(self, var: str, *, mean: bool = False) -> np.ndarray:
169166
"""
170167
Return a numpy.ndarray which contains the estimates for the
171168
for the named Stan program variable where the dimensions of the
@@ -188,8 +185,7 @@ def stan_variable(
188185
:param var: variable name
189186
190187
:param mean: if True, return the variational mean. Otherwise,
191-
return the variational sample. The default behavior will
192-
change in a future release to return the variational sample.
188+
return the variational sample. Defaults to False.
193189
194190
See Also
195191
--------
@@ -200,16 +196,7 @@ def stan_variable(
200196
CmdStanGQ.stan_variable
201197
CmdStanLaplace.stan_variable
202198
"""
203-
# TODO(2.0): remove None case, make default `False`
204-
if mean is None:
205-
get_logger().warning(
206-
"The default behavior of CmdStanVB.stan_variable() "
207-
"will change in a future release to return the "
208-
"variational sample, rather than the mean.\n"
209-
"To maintain the current behavior, pass the argument "
210-
"mean=True"
211-
)
212-
mean = True
199+
213200
if mean:
214201
draws = self._variational_mean
215202
else:
@@ -219,16 +206,7 @@ def stan_variable(
219206
out: np.ndarray = self._metadata.stan_vars[var].extract_reshape(
220207
draws
221208
)
222-
# TODO(2.0): remove
223-
if out.shape == () or out.shape == (1,):
224-
if mean:
225-
get_logger().warning(
226-
"The default behavior of "
227-
"CmdStanVB.stan_variable(mean=True) will change in a "
228-
"future release to always return a numpy.ndarray, even "
229-
"for scalar variables."
230-
)
231-
return out.item() # type: ignore
209+
232210
return out
233211
except KeyError:
234212
# pylint: disable=raise-missing-from
@@ -238,9 +216,7 @@ def stan_variable(
238216
+ ", ".join(self._metadata.stan_vars.keys())
239217
)
240218

241-
def stan_variables(
242-
self, *, mean: Optional[bool] = None
243-
) -> dict[str, Union[np.ndarray, float]]:
219+
def stan_variables(self, *, mean: bool = False) -> dict[str, np.ndarray]:
244220
"""
245221
Return a dictionary mapping Stan program variables names
246222
to the corresponding numpy.ndarray containing the inferred values.

test/test_variational.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -282,9 +282,8 @@ def test_complex_output() -> None:
282282
)
283283

284284
assert fit.stan_variable('zs', mean=False).shape == (1000, 2, 3)
285-
# TODO(2.0): change
286-
np.testing.assert_equal(fit.z, 3 + 4j)
287-
# np.testing.assert_equal(fit.z, np.repeat(3 + 4j, 1000))
285+
np.testing.assert_equal(fit.stan_variable("z", mean=True), 3 + 4j)
286+
np.testing.assert_equal(fit.z, np.repeat(3 + 4j, 1000))
288287

289288
np.testing.assert_allclose(
290289
fit.stan_variable('zs', mean=False)[0],
@@ -309,14 +308,9 @@ def test_attrs() -> None:
309308
algorithm='meanfield',
310309
)
311310

312-
# TODO(2.0): swap tests
313-
np.testing.assert_equal(fit.a, 4.5)
314-
assert fit.b.shape == (3,)
315-
assert isinstance(fit.theta, float)
316-
317-
# np.testing.assert_equal(fit.a, np.repeat(4.5, 1000))
318-
# assert fit.b.shape == (1000, 3)
319-
# assert fit.theta.shape == (1000,)
311+
np.testing.assert_equal(fit.a, np.repeat(4.5, 1000))
312+
assert fit.b.shape == (1000, 3)
313+
assert fit.theta.shape == (1000,)
320314

321315
assert fit.stan_variable('thin', mean=True) == 3.5
322316

0 commit comments

Comments
 (0)