Skip to content

Commit de2e73c

Browse files
authored
Merge pull request #741 from stan-dev/feat/738-pathfinder-threads
Add a num_threads helper argument to pathfinder()
2 parents 742b409 + 71d22e0 commit de2e73c

File tree

4 files changed

+59
-1
lines changed

4 files changed

+59
-1
lines changed

cmdstanpy/cmdstan_args.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -930,6 +930,7 @@ def validate(self) -> None:
930930
if not (
931931
isinstance(self.method_args, SamplerArgs)
932932
and self.method_args.num_chains > 1
933+
or isinstance(self.method_args, PathfinderArgs)
933934
):
934935
if not os.path.exists(self.inits):
935936
raise ValueError('no such file {}'.format(self.inits))

cmdstanpy/model.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1587,6 +1587,7 @@ def pathfinder(
15871587
refresh: Optional[int] = None,
15881588
time_fmt: str = "%Y%m%d%H%M%S",
15891589
timeout: Optional[float] = None,
1590+
num_threads: Optional[int] = None,
15901591
) -> CmdStanPathfinder:
15911592
"""
15921593
Run CmdStan's Pathfinder variational inference algorithm.
@@ -1689,6 +1690,10 @@ def pathfinder(
16891690
:param timeout: Duration at which Pathfinder times
16901691
out in seconds. Defaults to None.
16911692
1693+
:param num_threads: Number of threads to request for parallel execution.
1694+
A number other than ``1`` requires the model to have been compiled
1695+
with STAN_THREADS=True.
1696+
16921697
:return: A :class:`CmdStanPathfinder` object
16931698
16941699
References
@@ -1715,6 +1720,17 @@ def pathfinder(
17151720
"available for CmdStan versions 2.34 and later"
17161721
)
17171722

1723+
if num_threads is not None:
1724+
if (
1725+
num_threads != 1
1726+
and exe_info.get('STAN_THREADS', '').lower() != 'true'
1727+
):
1728+
raise ValueError(
1729+
"Model must be compiled with 'STAN_THREADS=true' to use"
1730+
" 'num_threads' argument"
1731+
)
1732+
os.environ['STAN_NUM_THREADS'] = str(num_threads)
1733+
17181734
if num_paths == 1:
17191735
if num_single_draws is None:
17201736
num_single_draws = draws

test/test_pathfinder.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
Tests for the Pathfinder method.
33
"""
44

5+
import contextlib
6+
from io import StringIO
57
from pathlib import Path
68

79
import numpy as np
@@ -129,6 +131,26 @@ def test_pathfinder_init_sampling():
129131
assert fit.draws().shape == (1000, 4, 9)
130132

131133

134+
def test_inits_for_pathfinder():
    """
    Run Pathfinder with one initial-value dict per path and check that
    an out-of-range initial value is reported on the console.
    """
    model = cmdstanpy.CmdStanModel(stan_file=DATAFILES_PATH / 'bernoulli.stan')
    data_path = str(DATAFILES_PATH / 'bernoulli.data.json')

    # First run: one init dict per path; only completion is checked,
    # the returned fit is not inspected.
    first_inits = [{"theta": 0.1}, {"theta": 0.9}]
    model.pathfinder(data_path, inits=first_inits, num_paths=2)

    # second path is initialized too large!
    second_inits = [{"theta": 0.1}, {"theta": 1.1}]
    console = StringIO()
    with contextlib.redirect_stdout(console):
        model.pathfinder(
            data_path,
            inits=second_inits,
            num_paths=2,
            show_console=True,
        )

    # With show_console=True the run's output goes to stdout, which was
    # redirected into ``console`` above.
    assert "Bounded variable is 1.1" in console.getvalue()
152+
153+
132154
def test_pathfinder_no_psis():
133155
stan = DATAFILES_PATH / 'bernoulli.stan'
134156
bern_model = cmdstanpy.CmdStanModel(stan_file=stan)
@@ -152,3 +174,20 @@ def test_pathfinder_no_lp_calc():
152174
n_lp_nan = np.sum(np.isnan(pathfinder.method_variables()['lp__']))
153175
assert n_lp_nan < 4000 # some lp still calculated during pathfinder
154176
assert n_lp_nan > 3000 # but most are not
177+
178+
179+
def test_pathfinder_threads():
    """
    Exercise the ``num_threads`` argument of ``pathfinder()``: a value of 1
    is accepted, a value of 4 raises ``ValueError`` mentioning STAN_THREADS,
    and a model force-compiled with STAN_THREADS runs with 4 threads.
    """
    stan_path = DATAFILES_PATH / 'bernoulli.stan'
    plain_model = cmdstanpy.CmdStanModel(stan_file=stan_path)
    data_path = str(DATAFILES_PATH / 'bernoulli.data.json')

    # A single thread must work on the model as compiled above.
    plain_model.pathfinder(data=data_path, num_threads=1)

    # Asking the same model for more than one thread must be rejected.
    with pytest.raises(ValueError, match="STAN_THREADS"):
        plain_model.pathfinder(data=data_path, num_threads=4)

    # Rebuild the executable with threading enabled, then retry.
    threaded_model = cmdstanpy.CmdStanModel(
        stan_file=stan_path,
        cpp_options={'STAN_THREADS': True},
        force_compile=True,
    )
    fit = threaded_model.pathfinder(data=data_path, num_threads=4)
    assert fit.draws().shape == (1000, 3)

test/test_sample.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
)
5656
def test_bernoulli_good(stanfile: str):
5757
stan = os.path.join(DATAFILES_PATH, stanfile)
58-
bern_model = CmdStanModel(stan_file=stan)
58+
bern_model = CmdStanModel(stan_file=stan, force_compile=True)
5959

6060
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
6161
bern_fit = bern_model.sample(
@@ -74,6 +74,8 @@ def test_bernoulli_good(stanfile: str):
7474

7575
for i in range(bern_fit.runset.chains):
7676
csv_file = bern_fit.runset.csv_files[i]
77+
# NB: This will fail if STAN_THREADS is enabled
78+
# due to sampling only producing 1 stdout file in that case
7779
stdout_file = bern_fit.runset.stdout_files[i]
7880
assert os.path.exists(csv_file)
7981
assert os.path.exists(stdout_file)

0 commit comments

Comments
 (0)