Skip to content

Commit dc1b939

Browse files
committed
Add new Pathfinder arguments
1 parent 078c43a commit dc1b939

File tree

5 files changed

+90
-8
lines changed

5 files changed

+90
-8
lines changed

cmdstanpy/cmdstan_args.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,8 @@ def __init__(
541541
num_draws: Optional[int] = None,
542542
num_elbo_draws: Optional[int] = None,
543543
save_single_paths: bool = False,
544+
psis_resample: bool = True,
545+
calculate_lp: bool = True,
544546
) -> None:
545547
self.init_alpha = init_alpha
546548
self.tol_obj = tol_obj
@@ -557,6 +559,8 @@ def __init__(
557559
self.num_elbo_draws = num_elbo_draws
558560

559561
self.save_single_paths = save_single_paths
562+
self.psis_resample = psis_resample
563+
self.calculate_lp = calculate_lp
560564

561565
def validate(self, _chains: Optional[int] = None) -> None:
562566
"""
@@ -609,6 +613,12 @@ def compose(self, _idx: int, cmd: List[str]) -> List[str]:
609613
if self.save_single_paths:
610614
cmd.append('save_single_paths=1')
611615

616+
if not self.psis_resample:
617+
cmd.append('psis_resample=0')
618+
619+
if not self.calculate_lp:
620+
cmd.append('calculate_lp=0')
621+
612622
return cmd
613623

614624

cmdstanpy/model.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1612,6 +1612,8 @@ def pathfinder(
16121612
draws: Optional[int] = None,
16131613
num_single_draws: Optional[int] = None,
16141614
num_elbo_draws: Optional[int] = None,
1615+
psis_resample: bool = True,
1616+
calculate_lp: bool = True,
16151617
# arguments standard to all methods
16161618
seed: Optional[int] = None,
16171619
inits: Union[Dict[str, float], float, str, os.PathLike, None] = None,
@@ -1645,6 +1647,14 @@ def pathfinder(
16451647
16461648
:param num_elbo_draws: Number of Monte Carlo draws to evaluate ELBO.
16471649
1650+
:param psis_resample: Whether or not to use Pareto Smoothed Importance
1651+
Sampling on the result of the individual Pathfinders. If False, the
1652+
result contains all of the draws from every individual path.
1653+
1654+
:param calculate_lp: Whether or not to calculate the log probability
1655+
for approximate draws. If False, this also implies that
1656+
``psis_resample`` will be set to False.
1657+
16481658
:param seed: The seed for random number generator. Must be an integer
16491659
between 0 and 2^32 - 1. If unspecified,
16501660
:func:`numpy.random.default_rng` is used to generate a seed.
@@ -1726,12 +1736,22 @@ def pathfinder(
17261736
Research, 23(306), 1–49. Retrieved from
17271737
http://jmlr.org/papers/v23/21-0889.html
17281738
"""
1729-
if cmdstan_version_before(2, 33, self.exe_info()):
1739+
1740+
exe_info = self.exe_info()
1741+
if cmdstan_version_before(2, 33, exe_info):
17301742
raise ValueError(
17311743
"Method 'pathfinder' not available for CmdStan versions "
17321744
"before 2.33"
17331745
)
17341746

1747+
if (not psis_resample or not calculate_lp) and cmdstan_version_before(
1748+
2, 34, exe_info
1749+
):
1750+
raise ValueError(
1751+
"Arguments 'psis_resample' and 'calculate_lp' are only "
1752+
"available for CmdStan versions 2.34 and later"
1753+
)
1754+
17351755
if num_paths == 1:
17361756
if num_single_draws is None:
17371757
num_single_draws = draws
@@ -1754,6 +1774,8 @@ def pathfinder(
17541774
max_lbfgs_iters=max_lbfgs_iters,
17551775
num_draws=num_single_draws,
17561776
num_elbo_draws=num_elbo_draws,
1777+
psis_resample=psis_resample,
1778+
calculate_lp=calculate_lp,
17571779
)
17581780

17591781
with temp_single_json(data) as _data, temp_inits(inits) as _inits:

cmdstanpy/stanfit/pathfinder.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,18 @@ def column_names(self) -> Tuple[str, ...]:
206206
"""
207207
return self._metadata.cmdstan_config['column_names'] # type: ignore
208208

209+
@property
210+
def is_resampled(self) -> bool:
211+
"""
212+
Returns True if the draws were resampled from several Pathfinder
213+
approximations, False otherwise.
214+
"""
215+
return ( # type: ignore
216+
self._metadata.cmdstan_config.get("num_paths", 4) > 1
217+
and self._metadata.cmdstan_config.get('psis_resample', 1) == 1
218+
and self._metadata.cmdstan_config.get('calculate_lp', 1) == 1
219+
)
220+
209221
def save_csvfiles(self, dir: Optional[str] = None) -> None:
210222
"""
211223
Move output CSV files to specified directory. If files were

test/test_log_prob.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,21 @@
2121
BERN_BASENAME = 'bernoulli'
2222

2323

24-
@pytest.mark.parametrize("sig_figs, expected, expected_unadjusted", [
25-
(11, ["-7.0214667713","-1.188472607"], ["-5.5395901199", "-1.4903938392"]),
26-
(3, ["-7.02", "-1.19"], ["-5.54", "-1.49"]),
27-
(None, ["-7.02147", "-1.18847"], ["-5.53959", "-1.49039"])
28-
])
29-
def test_lp_good(sig_figs: Optional[int], expected: List[str],
30-
expected_unadjusted: List[str]) -> None:
24+
@pytest.mark.parametrize(
25+
"sig_figs, expected, expected_unadjusted",
26+
[
27+
(
28+
11,
29+
["-7.0214667713", "-1.188472607"],
30+
["-5.5395901199", "-1.4903938392"],
31+
),
32+
(3, ["-7.02", "-1.19"], ["-5.54", "-1.49"]),
33+
(None, ["-7.02147", "-1.18847"], ["-5.53959", "-1.49039"]),
34+
],
35+
)
36+
def test_lp_good(
37+
sig_figs: Optional[int], expected: List[str], expected_unadjusted: List[str]
38+
) -> None:
3139
model = CmdStanModel(stan_file=BERN_STAN)
3240
params = {"theta": 0.34903938392023830482}
3341
out = model.log_prob(params, data=BERN_DATA, sig_figs=sig_figs)

test/test_pathfinder.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from pathlib import Path
66

7+
import numpy as np
78
import pytest
89

910
import cmdstanpy
@@ -31,6 +32,8 @@ def test_pathfinder_outputs():
3132
assert theta.shape == (draws,)
3233
assert 0.23 < theta.mean() < 0.27
3334

35+
assert pathfinder.is_resampled
36+
3437
assert pathfinder.draws().shape == (draws, 3)
3538

3639

@@ -58,6 +61,8 @@ def test_single_pathfinder():
5861
draws=draws,
5962
)
6063

64+
assert not pathfinder.is_resampled
65+
6166
theta = pathfinder.theta
6267
assert theta.shape == (draws,)
6368

@@ -122,3 +127,28 @@ def test_pathfinder_init_sampling():
122127

123128
assert fit.chains == 4
124129
assert fit.draws().shape == (1000, 4, 9)
130+
131+
132+
def test_pathfinder_no_psis():
133+
stan = DATAFILES_PATH / 'bernoulli.stan'
134+
bern_model = cmdstanpy.CmdStanModel(stan_file=stan)
135+
jdata = str(DATAFILES_PATH / 'bernoulli.data.json')
136+
137+
pathfinder = bern_model.pathfinder(data=jdata, psis_resample=False)
138+
139+
assert not pathfinder.is_resampled
140+
assert pathfinder.draws().shape == (4000, 3)
141+
142+
143+
def test_pathfinder_no_lp_calc():
144+
stan = DATAFILES_PATH / 'bernoulli.stan'
145+
bern_model = cmdstanpy.CmdStanModel(stan_file=stan)
146+
jdata = str(DATAFILES_PATH / 'bernoulli.data.json')
147+
148+
pathfinder = bern_model.pathfinder(data=jdata, calculate_lp=False)
149+
150+
assert not pathfinder.is_resampled
151+
assert pathfinder.draws().shape == (4000, 3)
152+
n_lp_nan = np.sum(np.isnan(pathfinder.method_variables()['lp__']))
153+
assert n_lp_nan < 4000  # some lp__ values are still calculated during Pathfinder
154+
assert n_lp_nan > 3000 # but most are not

0 commit comments

Comments
 (0)