Skip to content

Commit c9f1b05

Browse files
committed
Raise TimeoutError.
1 parent a4a7f21 commit c9f1b05

File tree

6 files changed

+31
-10
lines changed

6 files changed

+31
-10
lines changed

cmdstanpy/model.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,7 @@ def optimize(
733733
show_console=show_console,
734734
timeout=timeout,
735735
)
736+
runset.raise_for_timeouts()
736737

737738
if not runset._check_retcodes():
738739
msg = 'Error during optimization: {}'.format(runset.get_err_msgs())
@@ -1167,6 +1168,8 @@ def sample(
11671168
sys.stdout.write('\n')
11681169
get_logger().info('CmdStan done processing.')
11691170

1171+
runset.raise_for_timeouts()
1172+
11701173
get_logger().debug('runset\n%s', repr(runset))
11711174

11721175
# hack needed to parse CSV files if model has no params
@@ -1348,6 +1351,7 @@ def generate_quantities(
13481351
timeout=timeout,
13491352
)
13501353

1354+
runset.raise_for_timeouts()
13511355
errors = runset.get_err_msgs()
13521356
if errors:
13531357
msg = (
@@ -1518,6 +1522,7 @@ def variational(
15181522
show_console=show_console,
15191523
timeout=timeout,
15201524
)
1525+
runset.raise_for_timeouts()
15211526

15221527
# treat failure to converge as failure
15231528
transcript_file = runset.stdout_files[dummy_chain_id]
@@ -1606,7 +1611,15 @@ def _run_cmdstan(
16061611
universal_newlines=True,
16071612
)
16081613
if timeout:
1609-
timer = threading.Timer(timeout, proc.terminate)
1614+
1615+
def _timer_target():
1616+
# Abort if the process has already terminated.
1617+
if proc.poll() is not None:
1618+
return
1619+
proc.terminate()
1620+
runset._set_timeout_flag(idx, True)
1621+
1622+
timer = threading.Timer(timeout, _timer_target)
16101623
timer.setDaemon(True)
16111624
timer.start()
16121625
else:
@@ -1623,11 +1636,9 @@ def _run_cmdstan(
16231636

16241637
stdout, _ = proc.communicate()
16251638
retcode = proc.returncode
1639+
runset._set_retcode(idx, retcode)
16261640
if timer:
16271641
timer.cancel()
1628-
if retcode == -15:
1629-
retcode = 60
1630-
runset._set_retcode(idx, retcode)
16311642

16321643
if stdout:
16331644
fd_out.write(stdout)

cmdstanpy/stanfit/runset.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def __init__(
4444
else:
4545
self._num_procs = 1
4646
self._retcodes = [-1 for _ in range(self._num_procs)]
47+
self._timeout_flags = [False for _ in range(self._num_procs)]
4748
if chain_ids is None:
4849
chain_ids = [i + 1 for i in range(chains)]
4950
self._chain_ids = chain_ids
@@ -230,12 +231,14 @@ def _set_retcode(self, idx: int, val: int) -> None:
230231
"""Set retcode at process[idx] to val."""
231232
self._retcodes[idx] = val
232233

234+
def _set_timeout_flag(self, idx: int, val: bool) -> None:
235+
"""Set timeout_flag at process[idx] to val."""
236+
self._timeout_flags[idx] = val
237+
233238
def get_err_msgs(self) -> str:
234239
"""Checks console messages for each CmdStan run."""
235240
msgs = []
236241
for i in range(self._num_procs):
237-
if self._retcodes[i] == 60:
238-
msgs.append("processing timed out")
239242
if (
240243
os.path.exists(self._stdout_files[i])
241244
and os.stat(self._stdout_files[i]).st_size > 0
@@ -296,3 +299,10 @@ def save_csvfiles(self, dir: Optional[str] = None) -> None:
296299
raise ValueError(
297300
'Cannot save to file: {}'.format(to_path)
298301
) from e
302+
303+
def raise_for_timeouts(self) -> None:
304+
if any(self._timeout_flags):
305+
raise TimeoutError(
306+
f"{sum(self._timeout_flags)} of {self.num_procs} processes "
307+
"timed out"
308+
)

test/test_generate_quantities.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,7 @@ def test_timeout(self):
480480
stan = os.path.join(DATAFILES_PATH, 'timeout.stan')
481481
timeout_model = CmdStanModel(stan_file=stan)
482482
fit = timeout_model.sample(data={'loop': 0}, chains=1, iter_sampling=10)
483-
with self.assertRaisesRegex(RuntimeError, 'processing timed out'):
483+
with self.assertRaises(TimeoutError):
484484
timeout_model.generate_quantities(
485485
timeout=0.1, mcmc_sample=fit, data={'loop': 1}
486486
)

test/test_optimize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -637,7 +637,7 @@ def test_attrs(self):
637637
def test_timeout(self):
638638
stan = os.path.join(DATAFILES_PATH, 'timeout.stan')
639639
timeout_model = CmdStanModel(stan_file=stan)
640-
with self.assertRaisesRegex(RuntimeError, 'processing timed out'):
640+
with self.assertRaises(TimeoutError):
641641
timeout_model.optimize(data={'loop': 1}, timeout=0.1)
642642

643643

test/test_sample.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1916,7 +1916,7 @@ def test_diagnostics(self):
19161916
def test_timeout(self):
19171917
stan = os.path.join(DATAFILES_PATH, 'timeout.stan')
19181918
timeout_model = CmdStanModel(stan_file=stan)
1919-
with self.assertRaisesRegex(RuntimeError, 'processing timed out'):
1919+
with self.assertRaises(TimeoutError):
19201920
timeout_model.sample(timeout=0.1, chains=1, data={'loop': 1})
19211921

19221922

test/test_variational.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def test_attrs(self):
294294
def test_timeout(self):
295295
stan = os.path.join(DATAFILES_PATH, 'timeout.stan')
296296
timeout_model = CmdStanModel(stan_file=stan)
297-
with self.assertRaisesRegex(RuntimeError, 'processing timed out'):
297+
with self.assertRaises(TimeoutError):
298298
timeout_model.variational(timeout=0.1, data={'loop': 1})
299299

300300

0 commit comments

Comments
 (0)