88
99from cmdstanpy .cmdstan_args import Method
1010from cmdstanpy .utils import stancsv
11- from cmdstanpy .utils .logging import get_logger
1211
1312from .metadata import InferenceMetadata
1413from .runset import RunSet
@@ -100,7 +99,7 @@ def __repr__(self) -> str:
10099 # TODO - diagnostic, profiling files
101100 return repr
102101
103- def __getattr__ (self , attr : str ) -> Union [ np .ndarray , float ] :
102+ def __getattr__ (self , attr : str ) -> np .ndarray :
104103 """Synonymous with ``fit.stan_variable(attr)"""
105104 if attr .startswith ("_" ):
106105 raise AttributeError (f"Unknown variable name { attr } " )
@@ -163,9 +162,7 @@ def metadata(self) -> InferenceMetadata:
163162 """
164163 return self ._metadata
165164
166- def stan_variable (
167- self , var : str , * , mean : Optional [bool ] = None
168- ) -> Union [np .ndarray , float ]:
165+ def stan_variable (self , var : str , * , mean : bool = False ) -> np .ndarray :
169166 """
170167 Return a numpy.ndarray which contains the estimates for the
171168 for the named Stan program variable where the dimensions of the
@@ -188,8 +185,7 @@ def stan_variable(
188185 :param var: variable name
189186
190187 :param mean: if True, return the variational mean. Otherwise,
191- return the variational sample. The default behavior will
192- change in a future release to return the variational sample.
188+ return the variational sample. Defaults to False.
193189
194190 See Also
195191 --------
@@ -200,16 +196,7 @@ def stan_variable(
200196 CmdStanGQ.stan_variable
201197 CmdStanLaplace.stan_variable
202198 """
203- # TODO(2.0): remove None case, make default `False`
204- if mean is None :
205- get_logger ().warning (
206- "The default behavior of CmdStanVB.stan_variable() "
207- "will change in a future release to return the "
208- "variational sample, rather than the mean.\n "
209- "To maintain the current behavior, pass the argument "
210- "mean=True"
211- )
212- mean = True
199+
213200 if mean :
214201 draws = self ._variational_mean
215202 else :
@@ -219,16 +206,7 @@ def stan_variable(
219206 out : np .ndarray = self ._metadata .stan_vars [var ].extract_reshape (
220207 draws
221208 )
222- # TODO(2.0): remove
223- if out .shape == () or out .shape == (1 ,):
224- if mean :
225- get_logger ().warning (
226- "The default behavior of "
227- "CmdStanVB.stan_variable(mean=True) will change in a "
228- "future release to always return a numpy.ndarray, even "
229- "for scalar variables."
230- )
231- return out .item () # type: ignore
209+
232210 return out
233211 except KeyError :
234212 # pylint: disable=raise-missing-from
@@ -238,9 +216,7 @@ def stan_variable(
238216 + ", " .join (self ._metadata .stan_vars .keys ())
239217 )
240218
241- def stan_variables (
242- self , * , mean : Optional [bool ] = None
243- ) -> dict [str , Union [np .ndarray , float ]]:
219+ def stan_variables (self , * , mean : bool = False ) -> dict [str , np .ndarray ]:
244220 """
245221 Return a dictionary mapping Stan program variables names
246222 to the corresponding numpy.ndarray containing the inferred values.
0 commit comments