Skip to content

Commit 96dd965

Browse files
authored
Merge pull request #577 from stan-dev/feature/555-compute-diagnostics
Feature/555 compute diagnostics
2 parents 39ee446 + 4e6159d commit 96dd965

17 files changed

+534
-164
lines changed

cmdstanpy/cmdstan_args.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -828,7 +828,7 @@ def validate(self) -> None:
828828
'0 and 2**32-1, found {}.'.format(self.seed)
829829
)
830830
if isinstance(self.seed, int):
831-
if self.seed < 0 or self.seed > 2 ** 32 - 1:
831+
if self.seed < 0 or self.seed > 2**32 - 1:
832832
raise ValueError(
833833
'Argument "seed" must be an integer between '
834834
'0 and 2**32-1, found {}.'.format(self.seed)
@@ -847,7 +847,7 @@ def validate(self) -> None:
847847
)
848848
)
849849
for seed in self.seed:
850-
if seed < 0 or seed > 2 ** 32 - 1:
850+
if seed < 0 or seed > 2**32 - 1:
851851
raise ValueError(
852852
'Argument "seed" must be an integer value'
853853
' between 0 and 2**32-1,'

cmdstanpy/install_cxx_toolchain.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@
2727
from cmdstanpy.utils import pushd, validate_dir, wrap_url_progress_hook
2828

2929
EXTENSION = '.exe' if platform.system() == 'Windows' else ''
30-
IS_64BITS = sys.maxsize > 2 ** 32
30+
IS_64BITS = sys.maxsize > 2**32
3131

3232

3333
def usage() -> None:

cmdstanpy/stanfit/mcmc.py

Lines changed: 80 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,6 @@
4343
get_logger,
4444
scan_generated_quantities_csv,
4545
)
46-
4746
from .metadata import InferenceMetadata
4847
from .runset import RunSet
4948

@@ -78,31 +77,34 @@ def __init__(
7877
assert isinstance(
7978
sampler_args, SamplerArgs
8079
) # make the typechecker happy
81-
iter_sampling = sampler_args.iter_sampling
82-
if iter_sampling is None:
83-
self._iter_sampling = _CMDSTAN_SAMPLING
84-
else:
85-
self._iter_sampling = iter_sampling
86-
iter_warmup = sampler_args.iter_warmup
87-
if iter_warmup is None:
88-
self._iter_warmup = _CMDSTAN_WARMUP
89-
else:
90-
self._iter_warmup = iter_warmup
91-
thin = sampler_args.thin
92-
if thin is None:
93-
self._thin: int = _CMDSTAN_THIN
94-
else:
95-
self._thin = thin
80+
self._iter_sampling: int = _CMDSTAN_SAMPLING
81+
if sampler_args.iter_sampling is not None:
82+
self._iter_sampling = sampler_args.iter_sampling
83+
self._iter_warmup: int = _CMDSTAN_WARMUP
84+
if sampler_args.iter_warmup is not None:
85+
self._iter_warmup = sampler_args.iter_warmup
86+
self._thin: int = _CMDSTAN_THIN
87+
if sampler_args.thin is not None:
88+
self._thin = sampler_args.thin
9689
self._is_fixed_param = sampler_args.fixed_param
9790
self._save_warmup = sampler_args.save_warmup
9891
self._sig_figs = runset._args.sig_figs
92+
9993
# info from CSV values, instantiated lazily
94+
self._draws: np.ndarray = np.array(())
95+
# only valid when not is_fixed_param
10096
self._metric: np.ndarray = np.array(())
10197
self._step_size: np.ndarray = np.array(())
102-
self._draws: np.ndarray = np.array(())
98+
self._divergences: np.ndarray = np.zeros(self.runset.chains, dtype=int)
99+
self._max_treedepths: np.ndarray = np.zeros(
100+
self.runset.chains, dtype=int
101+
)
102+
103103
# info from CSV initial comments and header
104104
config = self._validate_csv_files()
105105
self._metadata: InferenceMetadata = InferenceMetadata(config)
106+
if not self._is_fixed_param:
107+
self._check_sampler_diagnostics()
106108

107109
def __repr__(self) -> str:
108110
repr = 'CmdStanMCMC: model={} chains={}{}'.format(
@@ -171,13 +173,15 @@ def column_names(self) -> Tuple[str, ...]:
171173
@property
172174
def metric_type(self) -> Optional[str]:
173175
"""
174-
Metric type used for adaptation, either 'diag_e' or 'dense_e'.
176+
Metric type used for adaptation, either 'diag_e' or 'dense_e', according
177+
to CmdStan arg 'metric'.
175178
When sampler algorithm 'fixed_param' is specified, metric_type is None.
176179
"""
177-
if self._is_fixed_param:
178-
return None
179-
# cmdstan arg name
180-
return self._metadata.cmdstan_config['metric'] # type: ignore
180+
return (
181+
self._metadata.cmdstan_config['metric']
182+
if not self._is_fixed_param
183+
else None
184+
)
181185

182186
@property
183187
def metric(self) -> Optional[np.ndarray]:
@@ -192,8 +196,7 @@ def metric(self) -> Optional[np.ndarray]:
192196
'Unit diagnonal metric, inverse mass matrix size unknown.'
193197
)
194198
return None
195-
if self._draws.shape == (0,):
196-
self._assemble_draws()
199+
self._assemble_draws()
197200
return self._metric
198201

199202
@property
@@ -202,11 +205,8 @@ def step_size(self) -> Optional[np.ndarray]:
202205
Step size used by sampler for each chain.
203206
When sampler algorithm 'fixed_param' is specified, step size is None.
204207
"""
205-
if self._is_fixed_param:
206-
return None
207-
if self._step_size.shape == (0,):
208-
self._assemble_draws()
209-
return self._step_size
208+
self._assemble_draws()
209+
return self._step_size if not self._is_fixed_param else None
210210

211211
@property
212212
def thin(self) -> int:
@@ -215,6 +215,23 @@ def thin(self) -> int:
215215
"""
216216
return self._thin
217217

218+
@property
219+
def divergences(self) -> Optional[np.ndarray]:
220+
"""
221+
Per-chain total number of post-warmup divergent iterations.
222+
When sampler algorithm 'fixed_param' is specified, returns None.
223+
"""
224+
return self._divergences if not self._is_fixed_param else None
225+
226+
@property
227+
def max_treedepths(self) -> Optional[np.ndarray]:
228+
"""
229+
Per-chain total number of post-warmup iterations where the NUTS sampler
230+
reached the maximum allowed treedepth.
231+
When sampler algorithm 'fixed_param' is specified, returns None.
232+
"""
233+
return self._max_treedepths if not self._is_fixed_param else None
234+
218235
def draws(
219236
self, *, inc_warmup: bool = False, concat_chains: bool = False
220237
) -> np.ndarray:
@@ -263,6 +280,7 @@ def _validate_csv_files(self) -> Dict[str, Any]:
263280
Checks that Stan CSV output files for all chains are consistent
264281
and returns dict containing config and column names.
265282
283+
Tabulates sampling iters which are divergent or at max treedepth
266284
Raises exception when inconsistencies detected.
267285
"""
268286
dzero = {}
@@ -276,6 +294,9 @@ def _validate_csv_files(self) -> Dict[str, Any]:
276294
save_warmup=self._save_warmup,
277295
thin=self._thin,
278296
)
297+
if not self._is_fixed_param:
298+
self._divergences[i] = dzero['ct_divergences']
299+
self._max_treedepths[i] = dzero['ct_max_treedepth']
279300
else:
280301
drest = check_sampler_csv(
281302
path=self.runset.csv_files[i],
@@ -312,13 +333,43 @@ def _validate_csv_files(self) -> Dict[str, Any]:
312333
drest[key],
313334
)
314335
)
336+
if not self._is_fixed_param:
337+
self._divergences[i] = drest['ct_divergences']
338+
self._max_treedepths[i] = drest['ct_max_treedepth']
315339
return dzero
316340

341+
def _check_sampler_diagnostics(self) -> None:
342+
"""
343+
Warn if any iterations ended in divergences or hit maxtreedepth.
344+
"""
345+
if np.any(self._divergences) or np.any(self._max_treedepths):
346+
diagnostics = ['Some chains may have failed to converge.']
347+
ct_iters = self.metadata.cmdstan_config['num_samples']
348+
for i in range(self.runset._chains):
349+
if self._divergences[i] > 0:
350+
diagnostics.append(
351+
f'Chain {i + 1} had {self._divergences[i]} '
352+
'divergent transitions '
353+
f'({((self._divergences[i]/ct_iters)*100):.1f}%)'
354+
)
355+
if self._max_treedepths[i] > 0:
356+
diagnostics.append(
357+
f'Chain {i + 1} had {self._max_treedepths[i]} '
358+
'iterations at max treedepth '
359+
f'({((self._max_treedepths[i]/ct_iters)*100):.1f}%)'
360+
)
361+
diagnostics.append(
362+
'Use function "diagnose()" to see further information.'
363+
)
364+
get_logger().warning('\n\t'.join(diagnostics))
365+
317366
def _assemble_draws(self) -> None:
318367
"""
319368
Allocates and populates the step size, metric, and sample arrays
320369
by parsing the validated stan_csv files.
321370
"""
371+
if self._draws.shape != (0,):
372+
return
322373
num_draws = self.num_draws_sampling
323374
sampling_iter_start = 0
324375
if self._save_warmup:

cmdstanpy/utils.py

Lines changed: 30 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -68,7 +68,12 @@ def get_logger() -> logging.Logger:
6868
# add a default handler to the logger to INFO and higher
6969
handler = logging.StreamHandler()
7070
handler.setLevel(logging.INFO)
71-
handler.setFormatter(logging.Formatter(logging.BASIC_FORMAT))
71+
handler.setFormatter(
72+
logging.Formatter(
73+
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
74+
"%H:%M:%S",
75+
)
76+
)
7277
logger.addHandler(handler)
7378
return logger
7479

@@ -172,7 +177,7 @@ def cmdstan_path() -> str:
172177
Validate, then return CmdStan directory path.
173178
"""
174179
cmdstan = ''
175-
if 'CMDSTAN' in os.environ:
180+
if 'CMDSTAN' in os.environ and len(os.environ['CMDSTAN']) > 0:
176181
cmdstan = os.environ['CMDSTAN']
177182
else:
178183
cmdstan_dir = os.path.expanduser(os.path.join('~', _DOT_CMDSTAN))
@@ -291,7 +296,7 @@ def cxx_toolchain_path(
291296
if os.path.exists(os.path.join(toolchain_root, 'mingw64')):
292297
compiler_path = os.path.join(
293298
toolchain_root,
294-
'mingw64' if (sys.maxsize > 2 ** 32) else 'mingw32',
299+
'mingw64' if (sys.maxsize > 2**32) else 'mingw32',
295300
'bin',
296301
)
297302
if os.path.exists(compiler_path):
@@ -315,7 +320,7 @@ def cxx_toolchain_path(
315320
elif os.path.exists(os.path.join(toolchain_root, 'mingw_64')):
316321
compiler_path = os.path.join(
317322
toolchain_root,
318-
'mingw_64' if (sys.maxsize > 2 ** 32) else 'mingw_32',
323+
'mingw_64' if (sys.maxsize > 2**32) else 'mingw_32',
319324
'bin',
320325
)
321326
if os.path.exists(compiler_path):
@@ -367,7 +372,7 @@ def cxx_toolchain_path(
367372
if version not in ('35', '3.5', '3'):
368373
compiler_path = os.path.join(
369374
toolchain_root,
370-
'mingw64' if (sys.maxsize > 2 ** 32) else 'mingw32',
375+
'mingw64' if (sys.maxsize > 2**32) else 'mingw32',
371376
'bin',
372377
)
373378
if os.path.exists(compiler_path):
@@ -392,7 +397,7 @@ def cxx_toolchain_path(
392397
else:
393398
compiler_path = os.path.join(
394399
toolchain_root,
395-
'mingw_64' if (sys.maxsize > 2 ** 32) else 'mingw_32',
400+
'mingw_64' if (sys.maxsize > 2**32) else 'mingw_32',
396401
'bin',
397402
)
398403
if os.path.exists(compiler_path):
@@ -649,7 +654,7 @@ def scan_sampler_csv(path: str, is_fixed_param: bool = False) -> Dict[str, Any]:
649654
if not is_fixed_param:
650655
lineno = scan_warmup_iters(fd, dict, lineno)
651656
lineno = scan_hmc_params(fd, dict, lineno)
652-
lineno = scan_sampling_iters(fd, dict, lineno)
657+
lineno = scan_sampling_iters(fd, dict, lineno, is_fixed_param)
653658
except ValueError as e:
654659
raise ValueError("Error in reading csv file: " + path) from e
655660
return dict
@@ -952,13 +957,21 @@ def scan_hmc_params(
952957

953958

954959
def scan_sampling_iters(
955-
fd: TextIO, config_dict: Dict[str, Any], lineno: int
960+
fd: TextIO, config_dict: Dict[str, Any], lineno: int, is_fixed_param: bool
956961
) -> int:
957962
"""
958963
Parse sampling iteration, save number of iterations to config_dict.
964+
Also save number of divergences, max_treedepth hits
959965
"""
960966
draws_found = 0
961967
num_cols = len(config_dict['column_names'])
968+
if not is_fixed_param:
969+
idx_divergent = config_dict['column_names'].index('divergent__')
970+
idx_treedepth = config_dict['column_names'].index('treedepth__')
971+
max_treedepth = config_dict['max_depth']
972+
ct_divergences = 0
973+
ct_max_treedepth = 0
974+
962975
cur_pos = fd.tell()
963976
line = fd.readline().strip()
964977
while len(line) > 0 and not line.startswith('#'):
@@ -976,8 +989,16 @@ def scan_sampling_iters(
976989
)
977990
cur_pos = fd.tell()
978991
line = fd.readline().strip()
979-
config_dict['draws_sampling'] = draws_found
992+
if not is_fixed_param:
993+
ct_divergences += int(data[idx_divergent]) # type: ignore
994+
if int(data[idx_treedepth]) == max_treedepth: # type: ignore
995+
ct_max_treedepth += 1
996+
980997
fd.seek(cur_pos)
998+
config_dict['draws_sampling'] = draws_found
999+
if not is_fixed_param:
1000+
config_dict['ct_divergences'] = ct_divergences
1001+
config_dict['ct_max_treedepth'] = ct_max_treedepth
9811002
return lineno
9821003

9831004

0 commit comments

Comments
 (0)