1313from io import StringIO
1414from multiprocessing import cpu_count
1515from pathlib import Path
16+ import threading
1617from typing import Any , Callable , Dict , Iterable , List , Mapping , Optional , Union
1718
1819import ujson as json
@@ -593,6 +594,7 @@ def optimize(
593594 show_console : bool = False ,
594595 refresh : Optional [int ] = None ,
595596 time_fmt : str = "%Y%m%d%H%M%S" ,
597+ timeout : Optional [float ] = None ,
596598 ) -> CmdStanMLE :
597599 """
598600 Run the specified CmdStan optimize algorithm to produce a
@@ -692,6 +694,8 @@ def optimize(
692694 :meth:`~datetime.datetime.strftime` to decide the file names for
693695 output CSVs. Defaults to "%Y%m%d%H%M%S"
694696
697+ :param timeout: Duration at which optimization times out in seconds.
698+
695699 :return: CmdStanMLE object
696700 """
697701 optimize_args = OptimizeArgs (
@@ -723,7 +727,8 @@ def optimize(
723727 )
724728 dummy_chain_id = 0
725729 runset = RunSet (args = args , chains = 1 , time_fmt = time_fmt )
726- self ._run_cmdstan (runset , dummy_chain_id , show_console = show_console )
730+ self ._run_cmdstan (runset , dummy_chain_id , show_console = show_console ,
731+ timeout = timeout )
727732
728733 if not runset ._check_retcodes ():
729734 msg = 'Error during optimization: {}' .format (runset .get_err_msgs ())
@@ -767,6 +772,7 @@ def sample(
767772 show_console : bool = False ,
768773 refresh : Optional [int ] = None ,
769774 time_fmt : str = "%Y%m%d%H%M%S" ,
775+ timeout : Optional [float ] = None ,
770776 * ,
771777 force_one_process_per_chain : Optional [bool ] = None ,
772778 ) -> CmdStanMCMC :
@@ -964,6 +970,8 @@ def sample(
964970 model was compiled with STAN_THREADS=True, and utilize the
965971 parallel chain functionality if those conditions are met.
966972
973+ :param timeout: Duration at which sampling times out in seconds.
974+
967975 :return: CmdStanMCMC object
968976 """
969977 if fixed_param is None :
@@ -1139,6 +1147,7 @@ def sample(
11391147 show_progress = show_progress ,
11401148 show_console = show_console ,
11411149 progress_hook = progress_hook ,
1150+ timeout = timeout ,
11421151 )
11431152 if show_progress and progress_hook is not None :
11441153 progress_hook ("Done" , - 1 ) # -1 == all chains finished
@@ -1209,6 +1218,7 @@ def generate_quantities(
12091218 show_console : bool = False ,
12101219 refresh : Optional [int ] = None ,
12111220 time_fmt : str = "%Y%m%d%H%M%S" ,
1221+ timeout : Optional [float ] = None ,
12121222 ) -> CmdStanGQ :
12131223 """
12141224 Run CmdStan's generate_quantities method which runs the generated
@@ -1267,6 +1277,8 @@ def generate_quantities(
12671277 :meth:`~datetime.datetime.strftime` to decide the file names for
12681278 output CSVs. Defaults to "%Y%m%d%H%M%S"
12691279
1280+ :param timeout: Duration at which generation times out in seconds.
1281+
12701282 :return: CmdStanGQ object
12711283 """
12721284 if isinstance (mcmc_sample , CmdStanMCMC ):
@@ -1329,6 +1341,7 @@ def generate_quantities(
13291341 runset ,
13301342 i ,
13311343 show_console = show_console ,
1344+ timeout = timeout ,
13321345 )
13331346
13341347 errors = runset .get_err_msgs ()
@@ -1366,6 +1379,7 @@ def variational(
13661379 show_console : bool = False ,
13671380 refresh : Optional [int ] = None ,
13681381 time_fmt : str = "%Y%m%d%H%M%S" ,
1382+ timeout : Optional [float ] = None ,
13691383 ) -> CmdStanVB :
13701384 """
13711385 Run CmdStan's variational inference algorithm to approximate
@@ -1458,6 +1472,9 @@ def variational(
14581472 :meth:`~datetime.datetime.strftime` to decide the file names for
14591473 output CSVs. Defaults to "%Y%m%d%H%M%S"
14601474
1475+ :param timeout: Duration at which variational Bayesian inference times
1476+ out in seconds.
1477+
14611478 :return: CmdStanVB object
14621479 """
14631480 variational_args = VariationalArgs (
@@ -1491,7 +1508,8 @@ def variational(
14911508
14921509 dummy_chain_id = 0
14931510 runset = RunSet (args = args , chains = 1 , time_fmt = time_fmt )
1494- self ._run_cmdstan (runset , dummy_chain_id , show_console = show_console )
1511+ self ._run_cmdstan (runset , dummy_chain_id , show_console = show_console ,
1512+ timeout = timeout )
14951513
14961514 # treat failure to converge as failure
14971515 transcript_file = runset .stdout_files [dummy_chain_id ]
@@ -1541,6 +1559,7 @@ def _run_cmdstan(
15411559 show_progress : bool = False ,
15421560 show_console : bool = False ,
15431561 progress_hook : Optional [Callable [[str , int ], None ]] = None ,
1562+ timeout : Optional [float ] = None ,
15441563 ) -> None :
15451564 """
15461565 Helper function which encapsulates call to CmdStan.
@@ -1577,6 +1596,12 @@ def _run_cmdstan(
15771596 env = os .environ ,
15781597 universal_newlines = True ,
15791598 )
1599+ if timeout :
1600+ timer = threading .Timer (timeout , proc .terminate )
1601+ timer .setDaemon (True )
1602+ timer .start ()
1603+ else :
1604+ timer = None
15801605 while proc .poll () is None :
15811606 if proc .stdout is not None :
15821607 line = proc .stdout .readline ()
@@ -1589,6 +1614,8 @@ def _run_cmdstan(
15891614
15901615 stdout , _ = proc .communicate ()
15911616 retcode = proc .returncode
1617+ if timer and retcode == - 15 :
1618+ retcode = 60
15921619 runset ._set_retcode (idx , retcode )
15931620
15941621 if stdout :
0 commit comments