Skip to content

Commit 650d2bb

Browse files
authored
Merge pull request #734 from tillahoffmann/diagnose
Add `diagnose` method to `CmdStanModel`.
2 parents 54e67a1 + 2b483b2 commit 650d2bb

File tree

2 files changed

+149
-0
lines changed

2 files changed

+149
-0
lines changed

cmdstanpy/model.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""CmdStanModel"""
22

3+
import io
34
import os
45
import platform
56
import re
@@ -2171,3 +2172,118 @@ def progress_hook(line: str, idx: int) -> None:
21712172
pbars[idx].postfix[0]["value"] = mline
21722173

21732174
return progress_hook
2175+
2176+
def diagnose(
    self,
    inits: Union[Dict[str, Any], str, os.PathLike, None] = None,
    data: Union[Mapping[str, Any], str, os.PathLike, None] = None,
    *,
    epsilon: Optional[float] = None,
    error: Optional[float] = None,
    require_gradients_ok: bool = True,
    sig_figs: Optional[int] = None,
) -> pd.DataFrame:
    """
    Run diagnostics to calculate the gradients at the specified parameter
    values and compare them with gradients calculated by finite differences.

    :param inits: Specifies how the sampler initializes parameter values.
        Initialization is either uniform random on a range centered on 0,
        exactly 0, or a dictionary or file of initial values for some or
        all parameters in the model. The default initialization behavior
        will initialize all parameter values on range [-2, 2] on the
        *unconstrained* support. The following value types are allowed:

        * Single number, n > 0 - initialization range is [-n, n].
        * 0 - all parameters are initialized to 0.
        * dictionary - pairs parameter name : initial value.
        * string - pathname to a JSON or Rdump data file.

    :param data: Values for all data variables in the model, specified
        either as a dictionary with entries matching the data variables,
        or as the path of a data file in JSON or Rdump format.

    :param sig_figs: Numerical precision used for output CSV and text files.
        Must be an integer between 1 and 18. If unspecified, the default
        precision for the system file I/O is used; the usual value is 6.

    :param epsilon: Step size for finite difference gradients.

    :param error: Absolute error threshold for comparing autodiff and finite
        difference gradients.

    :param require_gradients_ok: Whether or not to raise an error if Stan
        reports that the difference between autodiff gradients and finite
        difference gradients exceed the error threshold.

    :return: A pandas.DataFrame containing columns

        * "param_idx": increasing parameter index.
        * "value": Parameter value.
        * "model": Gradients evaluated using autodiff.
        * "finite_diff": Gradients evaluated using finite differences.
        * "error": Delta between autodiff and finite difference gradients.

        Gradients are evaluated in the unconstrained space.
    """

    # Materialize dict-valued `data`/`inits` as temporary JSON files so the
    # CmdStan executable can read them; path-valued inputs pass through.
    with temp_single_json(data) as _data, \
            temp_single_json(inits) as _inits:
        cmd = [
            str(self.exe_file),
            "diagnose",
            "test=gradient",
        ]
        if epsilon is not None:
            cmd.append(f"epsilon={epsilon}")
        if error is not None:
            # BUG FIX: previously appended f"epsilon={error}", which sent
            # the error threshold as a (duplicate) step size and never set
            # CmdStan's `error=` argument at all.
            cmd.append(f"error={error}")
        if _data is not None:
            cmd += ["data", f"file={_data}"]
        if _inits is not None:
            cmd.append(f"init={_inits}")

        output_dir = tempfile.mkdtemp(prefix=self.name, dir=_TMPDIR)

        output = os.path.join(output_dir, "output.csv")
        cmd += ["output", f"file={output}"]
        if sig_figs is not None:
            cmd.append(f"sig_figs={sig_figs}")

        get_logger().debug("Cmd: %s", str(cmd))

        # `check=False`: a nonzero exit code also signals a gradient-check
        # failure, which we want to report (or tolerate) ourselves below.
        proc = subprocess.run(
            cmd, capture_output=True, check=False, text=True
        )
        if proc.returncode:
            get_logger().error(
                "'diagnose' command failed!\nstdout:%s\nstderr:%s",
                proc.stdout,
                proc.stderr,
            )
            if require_gradients_ok:
                raise RuntimeError(
                    "The difference between autodiff and finite difference "
                    "gradients may exceed the error threshold. If you "
                    "would like to inspect the output, re-call with "
                    "`require_gradients_ok=False`."
                )
            get_logger().warning(
                "The difference between autodiff and finite difference "
                "gradients may exceed the error threshold. Proceeding "
                "because `require_gradients_ok` is set to `False`."
            )

        # Read the text and get the last chunk separated by a single # char.
        try:
            with open(output) as handle:
                text = handle.read()
        except FileNotFoundError as exc:
            raise RuntimeError(
                "Output of 'diagnose' command does not exist."
            ) from exc
        *_, table = re.split(r"#\s*\n", text)
        # Strip the leading '#' comment markers and normalize the column
        # headers to valid identifiers before parsing as whitespace-CSV.
        table = (
            re.sub(r"^#\s*", "", table, flags=re.M)
            .replace("param idx", "param_idx")
            .replace("finite diff", "finite_diff")
        )
        return pd.read_csv(io.StringIO(table), sep=r"\s+")

test/test_model.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from typing import List
1313
from unittest.mock import MagicMock, patch
1414

15+
import numpy as np
1516
import pytest
1617

1718
from cmdstanpy.model import CmdStanModel
@@ -596,3 +597,35 @@ def test_format_old_version() -> None:
596597
model.format(max_line_length=88)
597598

598599
model.format(canonicalize=True)
600+
601+
602+
def test_diagnose():
    model = CmdStanModel(stan_file=BERN_STAN)

    # The returned frame must expose exactly the documented columns.
    result = model.diagnose(data=BERN_DATA)
    expected_columns = {"param_idx", "value", "model", "finite_diff", "error"}
    assert set(result) == expected_columns

    # The autodiff gradient at a fixed init matches the `log_prob` reference.
    result = model.diagnose(
        data=BERN_DATA, inits={"theta": 0.34903938392023830482}
    )
    np.testing.assert_allclose(result.model.iloc[0], -1.18847)

    # An oversized finite-difference step makes the gradient check fail.
    with pytest.raises(RuntimeError, match="may exceed the error threshold"):
        model.diagnose(data=BERN_DATA, epsilon=3)

    # With require_gradients_ok=False the table is returned despite the
    # failure, and the reported deltas are visibly large.
    result = model.diagnose(
        data=BERN_DATA,
        epsilon=3,
        require_gradients_ok=False,
    )
    assert np.abs(result["error"]).max() > 1e-3

0 commit comments

Comments
 (0)