Skip to content

Commit c2bab85

Browse files
authored
Merge pull request #621 from tillahoffmann/timeout
Add timeout parameter.
2 parents 1219274 + a71a36d commit c2bab85

File tree

7 files changed

+112
-5
lines changed

7 files changed

+112
-5
lines changed

cmdstanpy/model.py

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from io import StringIO
1414
from multiprocessing import cpu_count
1515
from pathlib import Path
16+
import threading
1617
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union
1718

1819
import ujson as json
@@ -568,6 +569,7 @@ def optimize(
568569
show_console: bool = False,
569570
refresh: Optional[int] = None,
570571
time_fmt: str = "%Y%m%d%H%M%S",
572+
timeout: Optional[float] = None,
571573
) -> CmdStanMLE:
572574
"""
573575
Run the specified CmdStan optimize algorithm to produce a
@@ -667,6 +669,8 @@ def optimize(
667669
:meth:`~datetime.datetime.strftime` to decide the file names for
668670
output CSVs. Defaults to "%Y%m%d%H%M%S"
669671
672+
:param timeout: Duration at which optimization times out in seconds.
673+
670674
:return: CmdStanMLE object
671675
"""
672676
optimize_args = OptimizeArgs(
@@ -698,7 +702,13 @@ def optimize(
698702
)
699703
dummy_chain_id = 0
700704
runset = RunSet(args=args, chains=1, time_fmt=time_fmt)
701-
self._run_cmdstan(runset, dummy_chain_id, show_console=show_console)
705+
self._run_cmdstan(
706+
runset,
707+
dummy_chain_id,
708+
show_console=show_console,
709+
timeout=timeout,
710+
)
711+
runset.raise_for_timeouts()
702712

703713
if not runset._check_retcodes():
704714
msg = "Error during optimization! Command '{}' failed: {}".format(
@@ -744,6 +754,7 @@ def sample(
744754
show_console: bool = False,
745755
refresh: Optional[int] = None,
746756
time_fmt: str = "%Y%m%d%H%M%S",
757+
timeout: Optional[float] = None,
747758
*,
748759
force_one_process_per_chain: Optional[bool] = None,
749760
) -> CmdStanMCMC:
@@ -941,6 +952,8 @@ def sample(
941952
model was compiled with STAN_THREADS=True, and utilize the
942953
parallel chain functionality if those conditions are met.
943954
955+
:param timeout: Duration at which sampling times out in seconds.
956+
944957
:return: CmdStanMCMC object
945958
"""
946959
if fixed_param is None:
@@ -1116,6 +1129,7 @@ def sample(
11161129
show_progress=show_progress,
11171130
show_console=show_console,
11181131
progress_hook=progress_hook,
1132+
timeout=timeout,
11191133
)
11201134
if show_progress and progress_hook is not None:
11211135
progress_hook("Done", -1) # -1 == all chains finished
@@ -1131,6 +1145,8 @@ def sample(
11311145
sys.stdout.write('\n')
11321146
get_logger().info('CmdStan done processing.')
11331147

1148+
runset.raise_for_timeouts()
1149+
11341150
get_logger().debug('runset\n%s', repr(runset))
11351151

11361152
# hack needed to parse CSV files if model has no params
@@ -1186,6 +1202,7 @@ def generate_quantities(
11861202
show_console: bool = False,
11871203
refresh: Optional[int] = None,
11881204
time_fmt: str = "%Y%m%d%H%M%S",
1205+
timeout: Optional[float] = None,
11891206
) -> CmdStanGQ:
11901207
"""
11911208
Run CmdStan's generate_quantities method which runs the generated
@@ -1244,6 +1261,8 @@ def generate_quantities(
12441261
:meth:`~datetime.datetime.strftime` to decide the file names for
12451262
output CSVs. Defaults to "%Y%m%d%H%M%S"
12461263
1264+
:param timeout: Duration at which generation times out in seconds.
1265+
12471266
:return: CmdStanGQ object
12481267
"""
12491268
if isinstance(mcmc_sample, CmdStanMCMC):
@@ -1306,8 +1325,10 @@ def generate_quantities(
13061325
runset,
13071326
i,
13081327
show_console=show_console,
1328+
timeout=timeout,
13091329
)
13101330

1331+
runset.raise_for_timeouts()
13111332
errors = runset.get_err_msgs()
13121333
if errors:
13131334
msg = (
@@ -1343,6 +1364,7 @@ def variational(
13431364
show_console: bool = False,
13441365
refresh: Optional[int] = None,
13451366
time_fmt: str = "%Y%m%d%H%M%S",
1367+
timeout: Optional[float] = None,
13461368
) -> CmdStanVB:
13471369
"""
13481370
Run CmdStan's variational inference algorithm to approximate
@@ -1435,6 +1457,9 @@ def variational(
14351457
:meth:`~datetime.datetime.strftime` to decide the file names for
14361458
output CSVs. Defaults to "%Y%m%d%H%M%S"
14371459
1460+
:param timeout: Duration at which variational Bayesian inference times
1461+
out in seconds.
1462+
14381463
:return: CmdStanVB object
14391464
"""
14401465
variational_args = VariationalArgs(
@@ -1468,7 +1493,13 @@ def variational(
14681493

14691494
dummy_chain_id = 0
14701495
runset = RunSet(args=args, chains=1, time_fmt=time_fmt)
1471-
self._run_cmdstan(runset, dummy_chain_id, show_console=show_console)
1496+
self._run_cmdstan(
1497+
runset,
1498+
dummy_chain_id,
1499+
show_console=show_console,
1500+
timeout=timeout,
1501+
)
1502+
runset.raise_for_timeouts()
14721503

14731504
# treat failure to converge as failure
14741505
transcript_file = runset.stdout_files[dummy_chain_id]
@@ -1504,9 +1535,8 @@ def variational(
15041535
'current value is {}.'.format(grad_samples)
15051536
)
15061537
else:
1507-
msg = (
1508-
'Variational algorithm failed.\n '
1509-
'Console output:\n{}'.format(contents)
1538+
msg = 'Error during variational inference: {}'.format(
1539+
runset.get_err_msgs()
15101540
)
15111541
raise RuntimeError(msg)
15121542
# pylint: disable=invalid-name
@@ -1520,6 +1550,7 @@ def _run_cmdstan(
15201550
show_progress: bool = False,
15211551
show_console: bool = False,
15221552
progress_hook: Optional[Callable[[str, int], None]] = None,
1553+
timeout: Optional[float] = None,
15231554
) -> None:
15241555
"""
15251556
Helper function which encapsulates call to CmdStan.
@@ -1556,6 +1587,20 @@ def _run_cmdstan(
15561587
env=os.environ,
15571588
universal_newlines=True,
15581589
)
1590+
if timeout:
1591+
1592+
def _timer_target() -> None:
1593+
# Abort if the process has already terminated.
1594+
if proc.poll() is not None:
1595+
return
1596+
proc.terminate()
1597+
runset._set_timeout_flag(idx, True)
1598+
1599+
timer = threading.Timer(timeout, _timer_target)
1600+
timer.daemon = True
1601+
timer.start()
1602+
else:
1603+
timer = None
15591604
while proc.poll() is None:
15601605
if proc.stdout is not None:
15611606
line = proc.stdout.readline()
@@ -1569,6 +1614,8 @@ def _run_cmdstan(
15691614
stdout, _ = proc.communicate()
15701615
retcode = proc.returncode
15711616
runset._set_retcode(idx, retcode)
1617+
if timer:
1618+
timer.cancel()
15721619

15731620
if stdout:
15741621
fd_out.write(stdout)

cmdstanpy/stanfit/runset.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def __init__(
4444
else:
4545
self._num_procs = 1
4646
self._retcodes = [-1 for _ in range(self._num_procs)]
47+
self._timeout_flags = [False for _ in range(self._num_procs)]
4748
if chain_ids is None:
4849
chain_ids = [i + 1 for i in range(chains)]
4950
self._chain_ids = chain_ids
@@ -230,6 +231,10 @@ def _set_retcode(self, idx: int, val: int) -> None:
230231
"""Set retcode at process[idx] to val."""
231232
self._retcodes[idx] = val
232233

234+
def _set_timeout_flag(self, idx: int, val: bool) -> None:
235+
"""Set timeout_flag at process[idx] to val."""
236+
self._timeout_flags[idx] = val
237+
233238
def get_err_msgs(self) -> str:
234239
"""Checks console messages for each CmdStan run."""
235240
msgs = []
@@ -294,3 +299,10 @@ def save_csvfiles(self, dir: Optional[str] = None) -> None:
294299
raise ValueError(
295300
'Cannot save to file: {}'.format(to_path)
296301
) from e
302+
303+
def raise_for_timeouts(self) -> None:
304+
if any(self._timeout_flags):
305+
raise TimeoutError(
306+
f"{sum(self._timeout_flags)} of {self.num_procs} processes "
307+
"timed out"
308+
)

test/data/timeout.stan

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
data {
2+
// Indicator for endless looping.
3+
int loop;
4+
}
5+
6+
transformed data {
7+
// Maybe loop forever so the model times out.
8+
real y = 1;
9+
while(loop && y) {
10+
y += 1;
11+
}
12+
}
13+
14+
parameters {
15+
real x;
16+
}
17+
18+
model {
19+
// A nice model so we can get a fit for the `generated_quantities` call.
20+
x ~ normal(0, 1);
21+
}

test/test_generate_quantities.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,15 @@ def test_attrs(self):
476476
with self.assertRaisesRegex(AttributeError, 'Unknown variable name:'):
477477
dummy = fit.c
478478

479+
def test_timeout(self):
480+
stan = os.path.join(DATAFILES_PATH, 'timeout.stan')
481+
timeout_model = CmdStanModel(stan_file=stan)
482+
fit = timeout_model.sample(data={'loop': 0}, chains=1, iter_sampling=10)
483+
with self.assertRaises(TimeoutError):
484+
timeout_model.generate_quantities(
485+
timeout=0.1, mcmc_sample=fit, data={'loop': 1}
486+
)
487+
479488

480489
if __name__ == '__main__':
481490
unittest.main()

test/test_optimize.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,12 @@ def test_attrs(self):
643643
with self.assertRaisesRegex(AttributeError, 'Unknown variable name:'):
644644
dummy = fit.c
645645

646+
def test_timeout(self):
647+
stan = os.path.join(DATAFILES_PATH, 'timeout.stan')
648+
timeout_model = CmdStanModel(stan_file=stan)
649+
with self.assertRaises(TimeoutError):
650+
timeout_model.optimize(data={'loop': 1}, timeout=0.1)
651+
646652

647653
if __name__ == '__main__':
648654
unittest.main()

test/test_sample.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1913,6 +1913,12 @@ def test_diagnostics(self):
19131913
self.assertEqual(fit.max_treedepths, None)
19141914
self.assertEqual(fit.divergences, None)
19151915

1916+
def test_timeout(self):
1917+
stan = os.path.join(DATAFILES_PATH, 'timeout.stan')
1918+
timeout_model = CmdStanModel(stan_file=stan)
1919+
with self.assertRaises(TimeoutError):
1920+
timeout_model.sample(timeout=0.1, chains=1, data={'loop': 1})
1921+
19161922

19171923
if __name__ == '__main__':
19181924
unittest.main()

test/test_variational.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,12 @@ def test_attrs(self):
291291
with self.assertRaisesRegex(AttributeError, 'Unknown variable name:'):
292292
dummy = fit.c
293293

294+
def test_timeout(self):
295+
stan = os.path.join(DATAFILES_PATH, 'timeout.stan')
296+
timeout_model = CmdStanModel(stan_file=stan)
297+
with self.assertRaises(TimeoutError):
298+
timeout_model.variational(timeout=0.1, data={'loop': 1})
299+
294300

295301
if __name__ == '__main__':
296302
unittest.main()

0 commit comments

Comments
 (0)