Skip to content

Commit e974a13

Browse files
committed
Add diagnose method to CmdStanModel.
1 parent 9a59a1a commit e974a13

File tree

5 files changed

+281
-0
lines changed

5 files changed

+281
-0
lines changed

cmdstanpy/cmdstan_args.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class Method(Enum):
3131
VARIATIONAL = auto()
3232
LAPLACE = auto()
3333
PATHFINDER = auto()
34+
DIAGNOSE = auto()
3435

3536
def __repr__(self) -> str:
3637
return '<%s.%s>' % (self.__class__.__name__, self.name)
@@ -736,6 +737,47 @@ def compose(self, idx: int, cmd: List[str]) -> List[str]:
736737
return cmd
737738

738739

740+
class DiagnoseArgs:
741+
"""Arguments needed for diagnostics method."""
742+
743+
def __init__(
744+
self,
745+
test: Optional[str] = None,
746+
epsilon: Optional[float] = None,
747+
error: Optional[float] = None,
748+
) -> None:
749+
self.test = test
750+
self.epsilon = epsilon
751+
self.error = error
752+
753+
def validate(
754+
self, chains: Optional[int] = None # pylint: disable=unused-argument
755+
) -> None:
756+
"""
757+
Check argument correctness and consistency.
758+
"""
759+
if self.test is not None and self.test != "gradient":
760+
raise ValueError("Only testing gradient is supported.")
761+
762+
positive_float(self.epsilon, 'epsilon')
763+
positive_float(self.error, 'error')
764+
765+
# pylint: disable=unused-argument
766+
def compose(self, idx: int, cmd: List[str]) -> List[str]:
767+
"""
768+
Compose CmdStan command for method-specific non-default arguments.
769+
"""
770+
cmd.append('method=diagnose')
771+
cmd.append('test=gradient')
772+
if self.test:
773+
cmd.append(f'test={self.test}')
774+
if self.epsilon is not None:
775+
cmd.append(f'epsilon={self.epsilon}')
776+
if self.error is not None:
777+
cmd.append(f'error={self.error}')
778+
return cmd
779+
780+
739781
class CmdStanArgs:
740782
"""
741783
Container for CmdStan command line arguments.
@@ -755,6 +797,7 @@ def __init__(
755797
VariationalArgs,
756798
LaplaceArgs,
757799
PathfinderArgs,
800+
DiagnoseArgs,
758801
],
759802
data: Union[Mapping[str, Any], str, None] = None,
760803
seed: Union[int, List[int], None] = None,
@@ -790,6 +833,8 @@ def __init__(
790833
self.method = Method.LAPLACE
791834
elif isinstance(method_args, PathfinderArgs):
792835
self.method = Method.PATHFINDER
836+
elif isinstance(method_args, DiagnoseArgs):
837+
self.method = Method.DIAGNOSE
793838
else:
794839
raise ValueError(
795840
'Unsupported method args type: {}'.format(type(method_args))

cmdstanpy/model.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
)
3939
from cmdstanpy.cmdstan_args import (
4040
CmdStanArgs,
41+
DiagnoseArgs,
4142
GenerateQuantitiesArgs,
4243
LaplaceArgs,
4344
Method,
@@ -53,6 +54,7 @@
5354
CmdStanMLE,
5455
CmdStanPathfinder,
5556
CmdStanVB,
57+
CmdStanDiagnose,
5658
RunSet,
5759
from_csv,
5860
)
@@ -2203,3 +2205,117 @@ def progress_hook(line: str, idx: int) -> None:
22032205
pbars[idx].postfix[0]["value"] = mline
22042206

22052207
return progress_hook
2208+
2209+
def diagnose(
2210+
self,
2211+
data: Union[Mapping[str, Any], str, os.PathLike, None] = None,
2212+
seed: Optional[int] = None,
2213+
inits: Optional[float] = None,
2214+
output_dir: OptionalPath = None,
2215+
sig_figs: Optional[int] = None,
2216+
show_console: bool = False,
2217+
time_fmt: str = "%Y%m%d%H%M%S",
2218+
timeout: Optional[float] = None,
2219+
epsilon: Optional[float] = None,
2220+
error: Optional[float] = None,
2221+
require_gradients_ok: bool = True,
2222+
) -> CmdStanDiagnose:
2223+
"""
2224+
Run diagnostics to calculate the gradients of the initial state and
2225+
compare them with gradients calculated by finite differences.
2226+
Discrepancies between the two indicate that there is a problem with the
2227+
model or initial states or else there is a bug in Stan.
2228+
2229+
:param data: Values for all data variables in the model, specified
2230+
either as a dictionary with entries matching the data variables,
2231+
or as the path of a data file in JSON or Rdump format.
2232+
2233+
:param seed: The seed for random number generator. Must be an integer
2234+
between 0 and 2^32 - 1. If unspecified,
2235+
:func:`numpy.random.default_rng` is used to generate a seed.
2236+
2237+
:param inits: Specifies how the sampler initializes parameter values.
2238+
Initialization is either uniform random on a range centered on 0,
2239+
exactly 0, or a dictionary or file of initial values for some or
2240+
all parameters in the model. The default initialization behavior
2241+
will initialize all parameter values on range [-2, 2] on the
2242+
*unconstrained* support. The following value types are allowed:
2243+
2244+
* Single number, n > 0 - initialization range is [-n, n].
2245+
* 0 - all parameters are initialized to 0.
2246+
* dictionary - pairs parameter name : initial value.
2247+
* string - pathname to a JSON or Rdump data file.
2248+
2249+
:param output_dir: Name of the directory to which CmdStan output
2250+
files are written. If unspecified, output files will be written
2251+
to a temporary directory which is deleted upon session exit.
2252+
2253+
:param sig_figs: Numerical precision used for output CSV and text files.
2254+
Must be an integer between 1 and 18. If unspecified, the default
2255+
precision for the system file I/O is used; the usual value is 6.
2256+
Introduced in CmdStan-2.25.
2257+
2258+
:param show_console: If ``True``, stream CmdStan messages sent to stdout
2259+
and stderr to the console. Default is ``False``.
2260+
2261+
:param time_fmt: A format string passed to
2262+
:meth:`~datetime.datetime.strftime` to decide the file names for
2263+
output CSVs. Defaults to "%Y%m%d%H%M%S"
2264+
2265+
:param timeout: Duration at which the diagnostic command times out in
2266+
seconds. Defaults to None.
2267+
2268+
:param epsilon: Step size for finite difference gradients.
2269+
2270+
:param error: Absolute error threshold for comparing autodiff and finite
2271+
difference gradients.
2272+
2273+
:param require_gradients_ok: Whether or not to raise an error if Stan
2274+
reports that the difference between autodiff gradients and finite
2275+
difference gradients exceed the error threshold.
2276+
2277+
:return: A :class:`CmdStanDiagnose` object.
2278+
"""
2279+
diagnose_args = DiagnoseArgs(
2280+
epsilon=epsilon,
2281+
error=error,
2282+
)
2283+
2284+
with temp_single_json(data) as _data, temp_inits(inits) as _inits:
2285+
args = CmdStanArgs(
2286+
model_name=self._name,
2287+
model_exe=self._exe_file,
2288+
chain_ids=None,
2289+
method_args=diagnose_args,
2290+
data=_data,
2291+
seed=seed,
2292+
inits=_inits,
2293+
output_dir=output_dir,
2294+
sig_figs=sig_figs,
2295+
)
2296+
2297+
dummy_chain_id = 0
2298+
runset = RunSet(args=args, time_fmt=time_fmt)
2299+
self._run_cmdstan(
2300+
runset,
2301+
dummy_chain_id,
2302+
show_console=show_console,
2303+
timeout=timeout,
2304+
)
2305+
runset.raise_for_timeouts()
2306+
2307+
if not runset._check_retcodes():
2308+
if require_gradients_ok:
2309+
raise RuntimeError(
2310+
"The difference between autodiff and finite difference "
2311+
"gradients may exceed the error threshold. If you would "
2312+
"like to inspect the output, re-call with "
2313+
"`require_gradients_ok=False`."
2314+
)
2315+
get_logger().warning(
2316+
"The difference between autodiff and finite difference "
2317+
"gradients may exceed the error threshold. Proceeding because "
2318+
"`require_gradients_ok` is set to `False`."
2319+
)
2320+
2321+
return CmdStanDiagnose(runset)

cmdstanpy/stanfit/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from .pathfinder import CmdStanPathfinder
2323
from .runset import RunSet
2424
from .vb import CmdStanVB
25+
from .diagnose import CmdStanDiagnose
2526

2627
__all__ = [
2728
"RunSet",
@@ -32,6 +33,7 @@
3233
"CmdStanGQ",
3334
"CmdStanLaplace",
3435
"CmdStanPathfinder",
36+
"CmdStanDiagnose",
3537
]
3638

3739

cmdstanpy/stanfit/diagnose.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
"""Container for the result of running model diagnostics."""
2+
import io
3+
import re
4+
from typing import Dict, List, Optional
5+
6+
import pandas as pd
7+
8+
from ..cmdstan_args import Method
9+
from ..utils import scan_config
10+
11+
from .runset import RunSet
12+
13+
14+
class CmdStanDiagnose:
15+
"""
16+
Container for outputs from CmdStan diagnostics. Created by
17+
:meth:`CmdStanModel.diagnose`.
18+
"""
19+
20+
def __init__(self, runset: RunSet) -> None:
21+
if not runset.method == Method.DIAGNOSE:
22+
raise ValueError(
23+
"Wrong runset method, expecting diagnose runset, found method "
24+
f"{runset.method}."
25+
)
26+
self.runset = runset
27+
self.gradients_ok = runset._check_retcodes()
28+
29+
# Split the csv into header and gradient table parts.
30+
with open(self.runset.csv_files[0]) as handle:
31+
text = handle.read()
32+
header, table = re.split(r"#\s+Log probability=.*", text, re.M)
33+
self.config: Dict = {}
34+
scan_config(io.StringIO(header), self.config, 0)
35+
36+
# Remove comment characters, leading whitespace, and empty lines.
37+
lines: List[str] = []
38+
for line in table.splitlines():
39+
line = re.sub(r"^#\s+", "", line)
40+
if not line:
41+
continue
42+
# If this is the first line, remove whitespace from column names.
43+
if not lines:
44+
line = (
45+
line
46+
.replace("param idx", "param_idx")
47+
.replace("finite diff", "finite_diff")
48+
)
49+
lines.append(line)
50+
self._gradients = pd.read_csv(io.StringIO("\n".join(lines)), sep=r"\s+")
51+
52+
def __repr__(self) -> str:
53+
cmd = self.runset._args.method_args.compose(0, cmd=[])
54+
lines = [
55+
f"CmdStanDiagnose: model={self.runset.model}{cmd}",
56+
f"\tcsv_file: {self.runset.csv_files[0]}",
57+
f"\toutput_file: {self.runset.stdout_files[0]}",
58+
]
59+
if not self.gradients_ok:
60+
lines.append(
61+
"Warning: autodiff and finite difference gradients may not "
62+
"agree."
63+
)
64+
return "\n".join(lines)
65+
66+
@property
67+
def gradients(self) -> pd.DataFrame:
68+
"""
69+
Dataframe of parameter names, autodiff gradients, finite difference
70+
gradients and the delta between the two gradient estimates.
71+
"""
72+
return self._gradients
73+
74+
def save_csvfiles(self, dir: Optional[str] = None) -> None:
75+
"""
76+
Move output CSV files to specified directory. If files were
77+
written to the temporary session directory, clean filename.
78+
E.g., save 'bernoulli-201912081451-1-5nm6as7u.csv' as
79+
'bernoulli-201912081451-1.csv'.
80+
81+
:param dir: directory path
82+
83+
See Also
84+
--------
85+
stanfit.RunSet.save_csvfiles
86+
cmdstanpy.from_csv
87+
"""
88+
self.runset.save_csvfiles(dir)

test/test_model.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from typing import List
1313
from unittest.mock import MagicMock, patch
1414

15+
import numpy as np
1516
import pytest
1617

1718
from cmdstanpy.model import CmdStanModel
@@ -593,3 +594,32 @@ def test_format_old_version() -> None:
593594
model.format(max_line_length=88)
594595

595596
model.format(canonicalize=True)
597+
598+
599+
def test_diagnose():
600+
# Check the gradients.
601+
model = CmdStanModel(stan_file=BERN_STAN)
602+
result = model.diagnose(data=BERN_DATA)
603+
604+
# Check we have the right columns.
605+
assert set(result.gradients) == {
606+
"param_idx",
607+
"value",
608+
"model",
609+
"finite_diff",
610+
"error",
611+
}
612+
assert result.gradients_ok
613+
614+
# Simulate bad gradients by using large finite difference.
615+
with pytest.raises(RuntimeError, match="may exceed the error threshold"):
616+
model.diagnose(data=BERN_DATA, epsilon=3)
617+
618+
# Check we get the results if we set require_gradients_ok=False.
619+
result = model.diagnose(
620+
data=BERN_DATA,
621+
epsilon=3,
622+
require_gradients_ok=False,
623+
)
624+
assert np.abs(result.gradients["error"]).max() > 1e-3
625+
assert not result.gradients_ok

0 commit comments

Comments
 (0)