2929import pandas as pd
3030from tqdm .auto import tqdm
3131
32- from cmdstanpy import (
33- _CMDSTAN_REFRESH ,
34- _CMDSTAN_SAMPLING ,
35- _CMDSTAN_WARMUP ,
36- _TMPDIR ,
37- compilation ,
38- )
32+ from cmdstanpy import _CMDSTAN_SAMPLING , _CMDSTAN_WARMUP , _TMPDIR , compilation
3933from cmdstanpy .cmdstan_args import (
4034 CmdStanArgs ,
4135 GenerateQuantitiesArgs ,
@@ -1069,9 +1063,6 @@ def sample(
10691063 iter_total += _CMDSTAN_SAMPLING
10701064 else :
10711065 iter_total += iter_sampling
1072- if refresh is None :
1073- refresh = _CMDSTAN_REFRESH
1074- iter_total = iter_total // refresh + 2
10751066
10761067 progress_hook = self ._wrap_sampler_progress_hook (
10771068 chain_ids = chain_ids ,
@@ -2138,13 +2129,12 @@ def _wrap_sampler_progress_hook(
21382129 process, "Chain [id] Iteration" for multi-chain processing.
21392130 For the latter, manage array of pbars, update accordingly.
21402131 """
2141- pat = re .compile (r'Chain \[(\d* )\] ( Iteration.* )' )
2132+ chain_pat = re .compile (r'( Chain \[(\d+ )\] )? Iteration:\s+(\d+ )' )
21422133 pbars : Dict [int , tqdm ] = {
21432134 chain_id : tqdm (
21442135 total = total ,
2145- bar_format = "{desc} |{bar}| {elapsed} {postfix[0][value]}" ,
2146- postfix = [{"value" : "Status" }],
21472136 desc = f'chain { chain_id } ' ,
2137+ postfix = '(Warmup)' ,
21482138 colour = 'yellow' ,
21492139 )
21502140 for chain_id in chain_ids
@@ -2153,23 +2143,19 @@ def _wrap_sampler_progress_hook(
21532143 def progress_hook (line : str , idx : int ) -> None :
21542144 if line == "Done" :
21552145 for pbar in pbars .values ():
2156- pbar .postfix [ 0 ][ "value" ] = ' Sampling completed'
2146+ pbar .set_postfix_str ( '( Sampling completed)' )
21572147 pbar .update (total - pbar .n )
21582148 pbar .close ()
2159- else :
2160- match = pat .match (line )
2161- if match :
2162- idx = int (match .group (1 ))
2163- mline = match .group (2 ).strip ()
2164- elif line .startswith ("Iteration" ):
2165- mline = line
2166- idx = chain_ids [idx ]
2167- else :
2168- return
2169- if 'Sampling' in mline :
2170- pbars [idx ].colour = 'blue'
2171- pbars [idx ].update (1 )
2172- pbars [idx ].postfix [0 ]["value" ] = mline
2149+ elif (match := chain_pat .match (line )) is not None :
2150+ idx = int (match .group (2 ) or chain_ids [idx ])
2151+ current_iter = int (match .group (3 ))
2152+
2153+ pbar = pbars [idx ]
2154+ if pbar .colour == 'yellow' and 'Sampling' in line :
2155+ pbar .colour = 'blue'
2156+ pbar .set_postfix_str ('(Sampling)' )
2157+
2158+ pbar .update (current_iter - pbar .n )
21732159
21742160 return progress_hook
21752161
@@ -2225,8 +2211,7 @@ def diagnose(
22252211 Gradients are evaluated in the unconstrained space.
22262212 """
22272213
2228- with temp_single_json (data ) as _data , \
2229- temp_single_json (inits ) as _inits :
2214+ with temp_single_json (data ) as _data , temp_single_json (inits ) as _inits :
22302215 cmd = [
22312216 str (self .exe_file ),
22322217 "diagnose" ,
0 commit comments