Skip to content

Commit c0338c9

Browse files
committed
Add a num_threads helper argument to pathfinder()
1 parent b420952 commit c0338c9

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

cmdstanpy/model.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1635,6 +1635,7 @@ def pathfinder(
16351635
refresh: Optional[int] = None,
16361636
time_fmt: str = "%Y%m%d%H%M%S",
16371637
timeout: Optional[float] = None,
1638+
num_threads: Optional[int] = None,
16381639
) -> CmdStanPathfinder:
16391640
"""
16401641
Run CmdStan's Pathfinder variational inference algorithm.
@@ -1737,6 +1738,10 @@ def pathfinder(
17371738
:param timeout: Duration at which Pathfinder times
17381739
out in seconds. Defaults to None.
17391740
1741+
:param num_threads: Number of threads to request for parallel execution.
1742+
A number other than ``1`` requires the model to have been compiled
1743+
with STAN_THREADS=True.
1744+
17401745
:return: A :class:`CmdStanPathfinder` object
17411746
17421747
References
@@ -1763,6 +1768,17 @@ def pathfinder(
17631768
"available for CmdStan versions 2.34 and later"
17641769
)
17651770

1771+
if num_threads is not None:
1772+
if (
1773+
num_threads != 1
1774+
and exe_info.get('STAN_THREADS', '').lower() != 'true'
1775+
):
1776+
raise ValueError(
1777+
"Model must be compiled with 'STAN_THREADS=true' to use"
1778+
" 'num_threads' argument"
1779+
)
1780+
os.environ['STAN_NUM_THREADS'] = str(num_threads)
1781+
17661782
if num_paths == 1:
17671783
if num_single_draws is None:
17681784
num_single_draws = draws

test/test_pathfinder.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,20 @@ def test_pathfinder_no_lp_calc():
152152
n_lp_nan = np.sum(np.isnan(pathfinder.method_variables()['lp__']))
153153
assert n_lp_nan < 4000 # some lp still calculated during pathfinder
154154
assert n_lp_nan > 3000 # but most are not
155+
156+
157+
def test_pathfinder_threads():
158+
stan = DATAFILES_PATH / 'bernoulli.stan'
159+
bern_model = cmdstanpy.CmdStanModel(stan_file=stan)
160+
jdata = str(DATAFILES_PATH / 'bernoulli.data.json')
161+
162+
bern_model.pathfinder(data=jdata, num_threads=1)
163+
164+
with pytest.raises(ValueError, match="STAN_THREADS"):
165+
bern_model.pathfinder(data=jdata, num_threads=4)
166+
167+
bern_model = cmdstanpy.CmdStanModel(
168+
stan_file=stan, cpp_options={'STAN_THREADS': True}, force_compile=True
169+
)
170+
pathfinder = bern_model.pathfinder(data=jdata, num_threads=4)
171+
assert pathfinder.draws().shape == (1000, 3)

0 commit comments

Comments
 (0)