Skip to content

Commit d1aeceb

Browse files
authored
Merge pull request #791 from amas0/add-new-create-inits
Add new `create_inits()` methods to other stanfit classes
2 parents f4e233b + d6d5bd3 commit d1aeceb

File tree

10 files changed

+505
-179
lines changed

10 files changed

+505
-179
lines changed

cmdstanpy/stanfit/laplace.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Container for the result of running a laplace approximation.
2+
Container for the result of running a laplace approximation.
33
"""
44

55
from typing import (
@@ -52,6 +52,39 @@ def __init__(self, runset: RunSet, mode: CmdStanMLE) -> None:
5252
config = scan_generic_csv(runset.csv_files[0])
5353
self._metadata = InferenceMetadata(config)
5454

55+
def create_inits(
56+
self, seed: Optional[int] = None, chains: int = 4
57+
) -> Union[List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]:
58+
"""
59+
Create initial values for the parameters of the model
60+
by randomly selecting draws from the Laplace approximation.
61+
62+
:param seed: Used for random selection, defaults to None
63+
:param chains: Number of initial values to return, defaults to 4
64+
:return: The initial values for the parameters of the model.
65+
66+
If ``chains`` is 1, a dictionary is returned, otherwise a list
67+
of dictionaries is returned, in the format expected for the
68+
``inits`` argument of :meth:`CmdStanModel.sample`.
69+
"""
70+
self._assemble_draws()
71+
rng = np.random.default_rng(seed)
72+
idxs = rng.choice(self._draws.shape[0], size=chains, replace=False)
73+
if chains == 1:
74+
draw = self._draws[idxs[0]]
75+
return {
76+
name: var.extract_reshape(draw)
77+
for name, var in self._metadata.stan_vars.items()
78+
}
79+
else:
80+
return [
81+
{
82+
name: var.extract_reshape(self._draws[idx])
83+
for name, var in self._metadata.stan_vars.items()
84+
}
85+
for idx in idxs
86+
]
87+
5588
def _assemble_draws(self) -> None:
5689
if self._draws.shape != (0,):
5790
return

cmdstanpy/stanfit/mcmc.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,46 @@ def __init__(
105105
if not self._is_fixed_param:
106106
self._check_sampler_diagnostics()
107107

108+
def create_inits(
109+
self, seed: Optional[int] = None, chains: int = 4
110+
) -> Union[List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]:
111+
"""
112+
Create initial values for the parameters of the model by
113+
randomly selecting draws from the MCMC samples. If the samples
114+
contain draws from multiple chains, each draw will be from
115+
a different chain, if possible. Otherwise the chain is randomly
116+
selected.
117+
118+
:param seed: Used for random selection, defaults to None
119+
:param chains: Number of initial values to return, defaults to 4
120+
:return: The initial values for the parameters of the model.
121+
122+
If ``chains`` is 1, a dictionary is returned, otherwise a list
123+
of dictionaries is returned, in the format expected for the
124+
``inits`` argument of :meth:`CmdStanModel.sample`.
125+
"""
126+
self._assemble_draws()
127+
rng = np.random.default_rng(seed)
128+
n_draws, n_chains = self._draws.shape[:2]
129+
draw_idxs = rng.choice(n_draws, size=chains, replace=False)
130+
chain_idxs = rng.choice(
131+
n_chains, size=chains, replace=(n_chains <= chains)
132+
)
133+
if chains == 1:
134+
draw = self._draws[draw_idxs[0], chain_idxs[0]]
135+
return {
136+
name: var.extract_reshape(draw)
137+
for name, var in self._metadata.stan_vars.items()
138+
}
139+
else:
140+
return [
141+
{
142+
name: var.extract_reshape(self._draws[d, i])
143+
for name, var in self._metadata.stan_vars.items()
144+
}
145+
for d, i in zip(draw_idxs, chain_idxs)
146+
]
147+
108148
def __repr__(self) -> str:
109149
repr = 'CmdStanMCMC: model={} chains={}{}'.format(
110150
self.runset.model,
@@ -685,7 +725,7 @@ def draws_xr(
685725
)
686726
if inc_warmup and not self._save_warmup:
687727
get_logger().warning(
688-
"Draws from warmup iterations not available,"
728+
'Draws from warmup iterations not available,'
689729
' must run sampler with "save_warmup=True".'
690730
)
691731
if vars is None:

cmdstanpy/stanfit/mle.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,30 @@ def __init__(self, runset: RunSet) -> None:
3636
self._save_iterations: bool = optimize_args.save_iterations
3737
self._set_mle_attrs(runset.csv_files[0])
3838

39+
def create_inits(
40+
self, seed: Optional[int] = None, chains: int = 4
41+
) -> Dict[str, np.ndarray]:
42+
"""
43+
Create initial values for the parameters of the model
44+
from the MLE.
45+
46+
:param seed: Unused. Kept for compatibility with other
47+
create_inits methods.
48+
:param chains: Unused. Kept for compatibility with other
49+
create_inits methods.
50+
:return: The initial values for the parameters of the model.
51+
52+
Returns a dictionary of MLE estimates in the format expected
53+
for the ``inits`` argument of :meth:`CmdStanModel.sample`.
54+
When running multi-chain sampling, all chains will be initialized
55+
at the same points.
56+
"""
57+
# pylint: disable=unused-argument
58+
59+
return {
60+
name: np.array(val) for name, val in self.stan_variables().items()
61+
}
62+
3963
def __repr__(self) -> str:
4064
repr = 'CmdStanMLE: model={}{}'.format(
4165
self.runset.model, self.runset._args.method_args.compose(0, cmd=[])

cmdstanpy/stanfit/pathfinder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def create_inits(
4545
4646
If ``chains`` is 1, a dictionary is returned, otherwise a list
4747
of dictionaries is returned, in the format expected for the
48-
``inits`` argument. of :meth:`CmdStanModel.sample`.
48+
``inits`` argument of :meth:`CmdStanModel.sample`.
4949
"""
5050
self._assemble_draws()
5151
rng = np.random.default_rng(seed)

cmdstanpy/stanfit/vb.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Container for the results of running autodiff variational inference"""
22

33
from collections import OrderedDict
4-
from typing import Dict, Optional, Tuple, Union
4+
from typing import Dict, List, Optional, Tuple, Union
55

66
import numpy as np
77
import pandas as pd
@@ -30,6 +30,41 @@ def __init__(self, runset: RunSet) -> None:
3030
self.runset = runset
3131
self._set_variational_attrs(runset.csv_files[0])
3232

33+
def create_inits(
34+
self, seed: Optional[int] = None, chains: int = 4
35+
) -> Union[List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]:
36+
"""
37+
Create initial values for the parameters of the model
38+
by randomly selecting draws from the variational approximation
39+
draws.
40+
41+
:param seed: Used for random selection, defaults to None
42+
:param chains: Number of initial values to return, defaults to 4
43+
:return: The initial values for the parameters of the model.
44+
45+
If ``chains`` is 1, a dictionary is returned, otherwise a list
46+
of dictionaries is returned, in the format expected for the
47+
``inits`` argument of :meth:`CmdStanModel.sample`.
48+
"""
49+
rng = np.random.default_rng(seed)
50+
idxs = rng.choice(
51+
self.variational_sample.shape[0], size=chains, replace=False
52+
)
53+
if chains == 1:
54+
draw = self.variational_sample[idxs[0]]
55+
return {
56+
name: var.extract_reshape(draw)
57+
for name, var in self._metadata.stan_vars.items()
58+
}
59+
else:
60+
return [
61+
{
62+
name: var.extract_reshape(self.variational_sample[idx])
63+
for name, var in self._metadata.stan_vars.items()
64+
}
65+
for idx in idxs
66+
]
67+
3368
def __repr__(self) -> str:
3469
repr = 'CmdStanVB: model={}{}'.format(
3570
self.runset.model, self.runset._args.method_args.compose(0, cmd=[])

0 commit comments

Comments
 (0)