1313from io import StringIO
1414from multiprocessing import cpu_count
1515from pathlib import Path
16+ import threading
1617from typing import Any , Callable , Dict , Iterable , List , Mapping , Optional , Union
1718
1819import ujson as json
@@ -568,6 +569,7 @@ def optimize(
568569 show_console : bool = False ,
569570 refresh : Optional [int ] = None ,
570571 time_fmt : str = "%Y%m%d%H%M%S" ,
572+ timeout : Optional [float ] = None ,
571573 ) -> CmdStanMLE :
572574 """
573575 Run the specified CmdStan optimize algorithm to produce a
@@ -667,6 +669,8 @@ def optimize(
667669 :meth:`~datetime.datetime.strftime` to decide the file names for
668670 output CSVs. Defaults to "%Y%m%d%H%M%S"
669671
672+ :param timeout: Duration at which optimization times out in seconds.
673+
670674 :return: CmdStanMLE object
671675 """
672676 optimize_args = OptimizeArgs (
@@ -698,7 +702,13 @@ def optimize(
698702 )
699703 dummy_chain_id = 0
700704 runset = RunSet (args = args , chains = 1 , time_fmt = time_fmt )
701- self ._run_cmdstan (runset , dummy_chain_id , show_console = show_console )
705+ self ._run_cmdstan (
706+ runset ,
707+ dummy_chain_id ,
708+ show_console = show_console ,
709+ timeout = timeout ,
710+ )
711+ runset .raise_for_timeouts ()
702712
703713 if not runset ._check_retcodes ():
704714 msg = "Error during optimization! Command '{}' failed: {}" .format (
@@ -744,6 +754,7 @@ def sample(
744754 show_console : bool = False ,
745755 refresh : Optional [int ] = None ,
746756 time_fmt : str = "%Y%m%d%H%M%S" ,
757+ timeout : Optional [float ] = None ,
747758 * ,
748759 force_one_process_per_chain : Optional [bool ] = None ,
749760 ) -> CmdStanMCMC :
@@ -941,6 +952,8 @@ def sample(
941952 model was compiled with STAN_THREADS=True, and utilize the
942953 parallel chain functionality if those conditions are met.
943954
955+ :param timeout: Duration at which sampling times out in seconds.
956+
944957 :return: CmdStanMCMC object
945958 """
946959 if fixed_param is None :
@@ -1116,6 +1129,7 @@ def sample(
11161129 show_progress = show_progress ,
11171130 show_console = show_console ,
11181131 progress_hook = progress_hook ,
1132+ timeout = timeout ,
11191133 )
11201134 if show_progress and progress_hook is not None :
11211135 progress_hook ("Done" , - 1 ) # -1 == all chains finished
@@ -1131,6 +1145,8 @@ def sample(
11311145 sys .stdout .write ('\n ' )
11321146 get_logger ().info ('CmdStan done processing.' )
11331147
1148+ runset .raise_for_timeouts ()
1149+
11341150 get_logger ().debug ('runset\n %s' , repr (runset ))
11351151
11361152 # hack needed to parse CSV files if model has no params
@@ -1186,6 +1202,7 @@ def generate_quantities(
11861202 show_console : bool = False ,
11871203 refresh : Optional [int ] = None ,
11881204 time_fmt : str = "%Y%m%d%H%M%S" ,
1205+ timeout : Optional [float ] = None ,
11891206 ) -> CmdStanGQ :
11901207 """
11911208 Run CmdStan's generate_quantities method which runs the generated
@@ -1244,6 +1261,8 @@ def generate_quantities(
12441261 :meth:`~datetime.datetime.strftime` to decide the file names for
12451262 output CSVs. Defaults to "%Y%m%d%H%M%S"
12461263
1264+ :param timeout: Duration at which generation times out in seconds.
1265+
12471266 :return: CmdStanGQ object
12481267 """
12491268 if isinstance (mcmc_sample , CmdStanMCMC ):
@@ -1306,8 +1325,10 @@ def generate_quantities(
13061325 runset ,
13071326 i ,
13081327 show_console = show_console ,
1328+ timeout = timeout ,
13091329 )
13101330
1331+ runset .raise_for_timeouts ()
13111332 errors = runset .get_err_msgs ()
13121333 if errors :
13131334 msg = (
@@ -1343,6 +1364,7 @@ def variational(
13431364 show_console : bool = False ,
13441365 refresh : Optional [int ] = None ,
13451366 time_fmt : str = "%Y%m%d%H%M%S" ,
1367+ timeout : Optional [float ] = None ,
13461368 ) -> CmdStanVB :
13471369 """
13481370 Run CmdStan's variational inference algorithm to approximate
@@ -1435,6 +1457,9 @@ def variational(
14351457 :meth:`~datetime.datetime.strftime` to decide the file names for
14361458 output CSVs. Defaults to "%Y%m%d%H%M%S"
14371459
1460+ :param timeout: Duration at which variational Bayesian inference times
1461+ out in seconds.
1462+
14381463 :return: CmdStanVB object
14391464 """
14401465 variational_args = VariationalArgs (
@@ -1468,7 +1493,13 @@ def variational(
14681493
14691494 dummy_chain_id = 0
14701495 runset = RunSet (args = args , chains = 1 , time_fmt = time_fmt )
1471- self ._run_cmdstan (runset , dummy_chain_id , show_console = show_console )
1496+ self ._run_cmdstan (
1497+ runset ,
1498+ dummy_chain_id ,
1499+ show_console = show_console ,
1500+ timeout = timeout ,
1501+ )
1502+ runset .raise_for_timeouts ()
14721503
14731504 # treat failure to converge as failure
14741505 transcript_file = runset .stdout_files [dummy_chain_id ]
@@ -1504,9 +1535,8 @@ def variational(
15041535 'current value is {}.' .format (grad_samples )
15051536 )
15061537 else :
1507- msg = (
1508- 'Variational algorithm failed.\n '
1509- 'Console output:\n {}' .format (contents )
1538+ msg = 'Error during variational inference: {}' .format (
1539+ runset .get_err_msgs ()
15101540 )
15111541 raise RuntimeError (msg )
15121542 # pylint: disable=invalid-name
@@ -1520,6 +1550,7 @@ def _run_cmdstan(
15201550 show_progress : bool = False ,
15211551 show_console : bool = False ,
15221552 progress_hook : Optional [Callable [[str , int ], None ]] = None ,
1553+ timeout : Optional [float ] = None ,
15231554 ) -> None :
15241555 """
15251556 Helper function which encapsulates call to CmdStan.
@@ -1556,6 +1587,20 @@ def _run_cmdstan(
15561587 env = os .environ ,
15571588 universal_newlines = True ,
15581589 )
1590+ if timeout :
1591+
1592+ def _timer_target () -> None :
1593+ # Abort if the process has already terminated.
1594+ if proc .poll () is not None :
1595+ return
1596+ proc .terminate ()
1597+ runset ._set_timeout_flag (idx , True )
1598+
1599+ timer = threading .Timer (timeout , _timer_target )
1600+ timer .daemon = True
1601+ timer .start ()
1602+ else :
1603+ timer = None
15591604 while proc .poll () is None :
15601605 if proc .stdout is not None :
15611606 line = proc .stdout .readline ()
@@ -1569,6 +1614,8 @@ def _run_cmdstan(
15691614 stdout , _ = proc .communicate ()
15701615 retcode = proc .returncode
15711616 runset ._set_retcode (idx , retcode )
1617+ if timer :
1618+ timer .cancel ()
15721619
15731620 if stdout :
15741621 fd_out .write (stdout )
0 commit comments