Skip to content

Commit f072187

Browse files
committed
Add timeout parameter.
1 parent b43587d commit f072187

File tree

7 files changed

+86
-2
lines changed

7 files changed

+86
-2
lines changed

cmdstanpy/model.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from io import StringIO
1414
from multiprocessing import cpu_count
1515
from pathlib import Path
16+
import threading
1617
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union
1718

1819
import ujson as json
@@ -593,6 +594,7 @@ def optimize(
593594
show_console: bool = False,
594595
refresh: Optional[int] = None,
595596
time_fmt: str = "%Y%m%d%H%M%S",
597+
timeout: Optional[float] = None,
596598
) -> CmdStanMLE:
597599
"""
598600
Run the specified CmdStan optimize algorithm to produce a
@@ -692,6 +694,8 @@ def optimize(
692694
:meth:`~datetime.datetime.strftime` to decide the file names for
693695
output CSVs. Defaults to "%Y%m%d%H%M%S"
694696
697+
:param timeout: Duration at which optimization times out in seconds.
698+
695699
:return: CmdStanMLE object
696700
"""
697701
optimize_args = OptimizeArgs(
@@ -723,7 +727,8 @@ def optimize(
723727
)
724728
dummy_chain_id = 0
725729
runset = RunSet(args=args, chains=1, time_fmt=time_fmt)
726-
self._run_cmdstan(runset, dummy_chain_id, show_console=show_console)
730+
self._run_cmdstan(runset, dummy_chain_id, show_console=show_console,
731+
timeout=timeout)
727732

728733
if not runset._check_retcodes():
729734
msg = 'Error during optimization: {}'.format(runset.get_err_msgs())
@@ -767,6 +772,7 @@ def sample(
767772
show_console: bool = False,
768773
refresh: Optional[int] = None,
769774
time_fmt: str = "%Y%m%d%H%M%S",
775+
timeout: Optional[float] = None,
770776
*,
771777
force_one_process_per_chain: Optional[bool] = None,
772778
) -> CmdStanMCMC:
@@ -964,6 +970,8 @@ def sample(
964970
model was compiled with STAN_THREADS=True, and utilize the
965971
parallel chain functionality if those conditions are met.
966972
973+
:param timeout: Duration at which sampling times out in seconds.
974+
967975
:return: CmdStanMCMC object
968976
"""
969977
if fixed_param is None:
@@ -1139,6 +1147,7 @@ def sample(
11391147
show_progress=show_progress,
11401148
show_console=show_console,
11411149
progress_hook=progress_hook,
1150+
timeout=timeout,
11421151
)
11431152
if show_progress and progress_hook is not None:
11441153
progress_hook("Done", -1) # -1 == all chains finished
@@ -1209,6 +1218,7 @@ def generate_quantities(
12091218
show_console: bool = False,
12101219
refresh: Optional[int] = None,
12111220
time_fmt: str = "%Y%m%d%H%M%S",
1221+
timeout: Optional[float] = None,
12121222
) -> CmdStanGQ:
12131223
"""
12141224
Run CmdStan's generate_quantities method which runs the generated
@@ -1267,6 +1277,8 @@ def generate_quantities(
12671277
:meth:`~datetime.datetime.strftime` to decide the file names for
12681278
output CSVs. Defaults to "%Y%m%d%H%M%S"
12691279
1280+
:param timeout: Duration at which generation times out in seconds.
1281+
12701282
:return: CmdStanGQ object
12711283
"""
12721284
if isinstance(mcmc_sample, CmdStanMCMC):
@@ -1329,6 +1341,7 @@ def generate_quantities(
13291341
runset,
13301342
i,
13311343
show_console=show_console,
1344+
timeout=timeout,
13321345
)
13331346

13341347
errors = runset.get_err_msgs()
@@ -1366,6 +1379,7 @@ def variational(
13661379
show_console: bool = False,
13671380
refresh: Optional[int] = None,
13681381
time_fmt: str = "%Y%m%d%H%M%S",
1382+
timeout: Optional[float] = None,
13691383
) -> CmdStanVB:
13701384
"""
13711385
Run CmdStan's variational inference algorithm to approximate
@@ -1458,6 +1472,9 @@ def variational(
14581472
:meth:`~datetime.datetime.strftime` to decide the file names for
14591473
output CSVs. Defaults to "%Y%m%d%H%M%S"
14601474
1475+
:param timeout: Duration at which variational Bayesian inference times
1476+
out in seconds.
1477+
14611478
:return: CmdStanVB object
14621479
"""
14631480
variational_args = VariationalArgs(
@@ -1491,7 +1508,8 @@ def variational(
14911508

14921509
dummy_chain_id = 0
14931510
runset = RunSet(args=args, chains=1, time_fmt=time_fmt)
1494-
self._run_cmdstan(runset, dummy_chain_id, show_console=show_console)
1511+
self._run_cmdstan(runset, dummy_chain_id, show_console=show_console,
1512+
timeout=timeout)
14951513

14961514
# treat failure to converge as failure
14971515
transcript_file = runset.stdout_files[dummy_chain_id]
@@ -1541,6 +1559,7 @@ def _run_cmdstan(
15411559
show_progress: bool = False,
15421560
show_console: bool = False,
15431561
progress_hook: Optional[Callable[[str, int], None]] = None,
1562+
timeout: Optional[float] = None,
15441563
) -> None:
15451564
"""
15461565
Helper function which encapsulates call to CmdStan.
@@ -1577,6 +1596,12 @@ def _run_cmdstan(
15771596
env=os.environ,
15781597
universal_newlines=True,
15791598
)
1599+
if timeout:
1600+
timer = threading.Timer(timeout, proc.terminate)
1601+
timer.setDaemon(True)
1602+
timer.start()
1603+
else:
1604+
timer = None
15801605
while proc.poll() is None:
15811606
if proc.stdout is not None:
15821607
line = proc.stdout.readline()
@@ -1589,6 +1614,8 @@ def _run_cmdstan(
15891614

15901615
stdout, _ = proc.communicate()
15911616
retcode = proc.returncode
1617+
if timer and retcode == -15:
1618+
retcode = 60
15921619
runset._set_retcode(idx, retcode)
15931620

15941621
if stdout:

cmdstanpy/stanfit/runset.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,8 @@ def get_err_msgs(self) -> str:
234234
"""Checks console messages for each CmdStan run."""
235235
msgs = []
236236
for i in range(self._num_procs):
237+
if self._retcodes[i] == 60:
238+
msgs.append("processing timed out")
237239
if (
238240
os.path.exists(self._stdout_files[i])
239241
and os.stat(self._stdout_files[i]).st_size > 0

test/data/timeout.stan

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
data {
2+
// Indicator for endless looping.
3+
int loop;
4+
}
5+
6+
transformed data {
7+
// Maybe loop forever so the model times out.
8+
real y = 1;
9+
while(loop && y) {
10+
y += 1;
11+
}
12+
}
13+
14+
parameters {
15+
real x;
16+
}
17+
18+
model {
19+
// A nice model so we can get a fit for the `generated_quantities` call.
20+
x ~ normal(0, 1);
21+
}

test/test_generate_quantities.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,16 @@ def test_attrs(self):
476476
with self.assertRaisesRegex(AttributeError, 'Unknown variable name:'):
477477
dummy = fit.c
478478

479+
def test_timeout(self):
480+
stan = os.path.join(DATAFILES_PATH, 'timeout.stan')
481+
timeout_model = CmdStanModel(stan_file=stan)
482+
fit = timeout_model.sample(data={'loop': 0}, chains=1, iter_sampling=10)
483+
self.assertRaisesRegex(
484+
RuntimeError, 'processing timed out',
485+
timeout_model.generate_quantities, timeout=0.1,
486+
mcmc_sample=fit, data={'loop': 1},
487+
)
488+
479489

480490
if __name__ == '__main__':
481491
unittest.main()

test/test_optimize.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,14 @@ def test_attrs(self):
634634
with self.assertRaisesRegex(AttributeError, 'Unknown variable name:'):
635635
dummy = fit.c
636636

637+
def test_timeout(self):
638+
stan = os.path.join(DATAFILES_PATH, 'timeout.stan')
639+
timeout_model = CmdStanModel(stan_file=stan)
640+
self.assertRaisesRegex(
641+
RuntimeError, 'processing timed out', timeout_model.optimize,
642+
data={'loop': 1}, timeout=0.1,
643+
)
644+
637645

638646
if __name__ == '__main__':
639647
unittest.main()

test/test_sample.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1913,6 +1913,14 @@ def test_diagnostics(self):
19131913
self.assertEqual(fit.max_treedepths, None)
19141914
self.assertEqual(fit.divergences, None)
19151915

1916+
def test_timeout(self):
1917+
stan = os.path.join(DATAFILES_PATH, 'timeout.stan')
1918+
timeout_model = CmdStanModel(stan_file=stan)
1919+
self.assertRaisesRegex(
1920+
RuntimeError, 'processing timed out', timeout_model.sample,
1921+
timeout=0.1, chains=1, data={'loop': 1},
1922+
)
1923+
19161924

19171925
if __name__ == '__main__':
19181926
unittest.main()

test/test_variational.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,14 @@ def test_attrs(self):
291291
with self.assertRaisesRegex(AttributeError, 'Unknown variable name:'):
292292
dummy = fit.c
293293

294+
def test_timeout(self):
295+
stan = os.path.join(DATAFILES_PATH, 'timeout.stan')
296+
timeout_model = CmdStanModel(stan_file=stan)
297+
self.assertRaisesRegex(
298+
RuntimeError, 'processing timed out', timeout_model.variational,
299+
timeout=0.1, data={'loop': 1}, show_console=True,
300+
)
301+
294302

295303
if __name__ == '__main__':
296304
unittest.main()

0 commit comments

Comments
 (0)