4343 get_logger ,
4444 scan_generated_quantities_csv ,
4545)
46-
4746from .metadata import InferenceMetadata
4847from .runset import RunSet
4948
@@ -78,31 +77,34 @@ def __init__(
7877 assert isinstance (
7978 sampler_args , SamplerArgs
8079 ) # make the typechecker happy
81- iter_sampling = sampler_args .iter_sampling
82- if iter_sampling is None :
83- self ._iter_sampling = _CMDSTAN_SAMPLING
84- else :
85- self ._iter_sampling = iter_sampling
86- iter_warmup = sampler_args .iter_warmup
87- if iter_warmup is None :
88- self ._iter_warmup = _CMDSTAN_WARMUP
89- else :
90- self ._iter_warmup = iter_warmup
91- thin = sampler_args .thin
92- if thin is None :
93- self ._thin : int = _CMDSTAN_THIN
94- else :
95- self ._thin = thin
80+ self ._iter_sampling : int = _CMDSTAN_SAMPLING
81+ if sampler_args .iter_sampling is not None :
82+ self ._iter_sampling = sampler_args .iter_sampling
83+ self ._iter_warmup : int = _CMDSTAN_WARMUP
84+ if sampler_args .iter_warmup is not None :
85+ self ._iter_warmup = sampler_args .iter_warmup
86+ self ._thin : int = _CMDSTAN_THIN
87+ if sampler_args .thin is not None :
88+ self ._thin = sampler_args .thin
9689 self ._is_fixed_param = sampler_args .fixed_param
9790 self ._save_warmup = sampler_args .save_warmup
9891 self ._sig_figs = runset ._args .sig_figs
92+
9993 # info from CSV values, instantiated lazily
94+ self ._draws : np .ndarray = np .array (())
95+ # only valid when not is_fixed_param
10096 self ._metric : np .ndarray = np .array (())
10197 self ._step_size : np .ndarray = np .array (())
102- self ._draws : np .ndarray = np .array (())
98+ self ._divergences : np .ndarray = np .zeros (self .runset .chains , dtype = int )
99+ self ._max_treedepths : np .ndarray = np .zeros (
100+ self .runset .chains , dtype = int
101+ )
102+
103103 # info from CSV initial comments and header
104104 config = self ._validate_csv_files ()
105105 self ._metadata : InferenceMetadata = InferenceMetadata (config )
106+ if not self ._is_fixed_param :
107+ self ._check_sampler_diagnostics ()
106108
107109 def __repr__ (self ) -> str :
108110 repr = 'CmdStanMCMC: model={} chains={}{}' .format (
@@ -171,13 +173,15 @@ def column_names(self) -> Tuple[str, ...]:
171173 @property
172174 def metric_type (self ) -> Optional [str ]:
173175 """
174- Metric type used for adaptation, either 'diag_e' or 'dense_e'.
176+ Metric type used for adaptation, either 'diag_e' or 'dense_e', according
177+ to CmdStan arg 'metric'.
175178 When sampler algorithm 'fixed_param' is specified, metric_type is None.
176179 """
177- if self ._is_fixed_param :
178- return None
179- # cmdstan arg name
180- return self ._metadata .cmdstan_config ['metric' ] # type: ignore
180+ return (
181+ self ._metadata .cmdstan_config ['metric' ]
182+ if not self ._is_fixed_param
183+ else None
184+ )
181185
182186 @property
183187 def metric (self ) -> Optional [np .ndarray ]:
@@ -192,8 +196,7 @@ def metric(self) -> Optional[np.ndarray]:
192196 'Unit diagnonal metric, inverse mass matrix size unknown.'
193197 )
194198 return None
195- if self ._draws .shape == (0 ,):
196- self ._assemble_draws ()
199+ self ._assemble_draws ()
197200 return self ._metric
198201
199202 @property
@@ -202,11 +205,8 @@ def step_size(self) -> Optional[np.ndarray]:
202205 Step size used by sampler for each chain.
203206 When sampler algorithm 'fixed_param' is specified, step size is None.
204207 """
205- if self ._is_fixed_param :
206- return None
207- if self ._step_size .shape == (0 ,):
208- self ._assemble_draws ()
209- return self ._step_size
208+ self ._assemble_draws ()
209+ return self ._step_size if not self ._is_fixed_param else None
210210
211211 @property
212212 def thin (self ) -> int :
@@ -215,6 +215,23 @@ def thin(self) -> int:
215215 """
216216 return self ._thin
217217
218+ @property
219+ def divergences (self ) -> Optional [np .ndarray ]:
220+ """
221+ Per-chain total number of post-warmup divergent iterations.
222+ When sampler algorithm 'fixed_param' is specified, returns None.
223+ """
224+ return self ._divergences if not self ._is_fixed_param else None
225+
226+ @property
227+ def max_treedepths (self ) -> Optional [np .ndarray ]:
228+ """
229+ Per-chain total number of post-warmup iterations where the NUTS sampler
230+ reached the maximum allowed treedepth.
231+ When sampler algorithm 'fixed_param' is specified, returns None.
232+ """
233+ return self ._max_treedepths if not self ._is_fixed_param else None
234+
218235 def draws (
219236 self , * , inc_warmup : bool = False , concat_chains : bool = False
220237 ) -> np .ndarray :
@@ -263,6 +280,7 @@ def _validate_csv_files(self) -> Dict[str, Any]:
263280 Checks that Stan CSV output files for all chains are consistent
264281 and returns dict containing config and column names.
265282
283+ Tabulates sampling iters which are divergent or at max treedepth
266284 Raises exception when inconsistencies detected.
267285 """
268286 dzero = {}
@@ -276,6 +294,9 @@ def _validate_csv_files(self) -> Dict[str, Any]:
276294 save_warmup = self ._save_warmup ,
277295 thin = self ._thin ,
278296 )
297+ if not self ._is_fixed_param :
298+ self ._divergences [i ] = dzero ['ct_divergences' ]
299+ self ._max_treedepths [i ] = dzero ['ct_max_treedepth' ]
279300 else :
280301 drest = check_sampler_csv (
281302 path = self .runset .csv_files [i ],
@@ -312,13 +333,43 @@ def _validate_csv_files(self) -> Dict[str, Any]:
312333 drest [key ],
313334 )
314335 )
336+ if not self ._is_fixed_param :
337+ self ._divergences [i ] = drest ['ct_divergences' ]
338+ self ._max_treedepths [i ] = drest ['ct_max_treedepth' ]
315339 return dzero
316340
341+ def _check_sampler_diagnostics (self ) -> None :
342+ """
343+ Warn if any iterations ended in divergences or hit maxtreedepth.
344+ """
345+ if np .any (self ._divergences ) or np .any (self ._max_treedepths ):
346+ diagnostics = ['Some chains may have failed to converge.' ]
347+ ct_iters = self .metadata .cmdstan_config ['num_samples' ]
348+ for i in range (self .runset ._chains ):
349+ if self ._divergences [i ] > 0 :
350+ diagnostics .append (
351+ f'Chain { i + 1 } had { self ._divergences [i ]} '
352+ 'divergent transitions '
353+ f'({ ((self ._divergences [i ]/ ct_iters )* 100 ):.1f} %)'
354+ )
355+ if self ._max_treedepths [i ] > 0 :
356+ diagnostics .append (
357+ f'Chain { i + 1 } had { self ._max_treedepths [i ]} '
358+ 'iterations at max treedepth '
359+ f'({ ((self ._max_treedepths [i ]/ ct_iters )* 100 ):.1f} %)'
360+ )
361+ diagnostics .append (
362+ 'Use function "diagnose()" to see further information.'
363+ )
364+ get_logger ().warning ('\n \t ' .join (diagnostics ))
365+
317366 def _assemble_draws (self ) -> None :
318367 """
319368 Allocates and populates the step size, metric, and sample arrays
320369 by parsing the validated stan_csv files.
321370 """
371+ if self ._draws .shape != (0 ,):
372+ return
322373 num_draws = self .num_draws_sampling
323374 sampling_iter_start = 0
324375 if self ._save_warmup :
0 commit comments