Skip to content

Commit 48b5a7e

Browse files
committed
Add diagnose method to CmdStanModel.
1 parent 54e67a1 commit 48b5a7e

File tree

2 files changed

+132
-0
lines changed

2 files changed

+132
-0
lines changed

cmdstanpy/model.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""CmdStanModel"""
22

3+
import io
34
import os
45
import platform
56
import re
@@ -2171,3 +2172,106 @@ def progress_hook(line: str, idx: int) -> None:
21712172
pbars[idx].postfix[0]["value"] = mline
21722173

21732174
return progress_hook
2175+
2176+
def diagnose(
2177+
self,
2178+
inits: Union[Dict[str, Any], str, os.PathLike, None] = None,
2179+
data: Union[Mapping[str, Any], str, os.PathLike, None] = None,
2180+
*,
2181+
epsilon: Optional[float] = None,
2182+
error: Optional[float] = None,
2183+
require_gradients_ok: bool = True,
2184+
sig_figs: Optional[int] = None,
2185+
) -> pd.DataFrame:
2186+
"""
2187+
Run diagnostics to calculate the gradients at the specified parameter
2188+
values and compare them with gradients calculated by finite differences.
2189+
2190+
:param inits: Specifies how the sampler initializes parameter values.
2191+
Initialization is either uniform random on a range centered on 0,
2192+
exactly 0, or a dictionary or file of initial values for some or
2193+
all parameters in the model. The default initialization behavior
2194+
will initialize all parameter values on range [-2, 2] on the
2195+
*unconstrained* support. The following value types are allowed:
2196+
* Single number, n > 0 - initialization range is [-n, n].
2197+
* 0 - all parameters are initialized to 0.
2198+
* dictionary - pairs parameter name : initial value.
2199+
* string - pathname to a JSON or Rdump data file.
2200+
2201+
:param data: Values for all data variables in the model, specified
2202+
either as a dictionary with entries matching the data variables,
2203+
or as the path of a data file in JSON or Rdump format.
2204+
2205+
:param sig_figs: Numerical precision used for output CSV and text files.
2206+
Must be an integer between 1 and 18. If unspecified, the default
2207+
precision for the system file I/O is used; the usual value is 6.
2208+
2209+
:param epsilon: Step size for finite difference gradients.
2210+
2211+
:param error: Absolute error threshold for comparing autodiff and finite
2212+
difference gradients.
2213+
2214+
:param require_gradients_ok: Whether or not to raise an error if Stan
2215+
reports that the difference between autodiff gradients and finite
2216+
difference gradients exceed the error threshold.
2217+
2218+
:return: A pandas.DataFrame containing columns
2219+
* "param_idx": increasing parameter index.
2220+
* "value": Parameter value.
2221+
* "model": Gradients evaluated using autodiff.
2222+
* "finite_diff": Gradients evaluated using finite differences.
2223+
* "error": Delta between autodiff and finite difference gradients.
2224+
"""
2225+
2226+
with temp_single_json(data) as _data, \
2227+
temp_single_json(inits) as _inits:
2228+
cmd = [
2229+
str(self.exe_file),
2230+
"diagnose",
2231+
"test=gradient",
2232+
]
2233+
if epsilon is not None:
2234+
cmd.append(f"epsilon={epsilon}")
2235+
if error is not None:
2236+
cmd.append(f"epsilon={error}")
2237+
if _data is not None:
2238+
cmd += ["data", f"file={_data}"]
2239+
if _inits is not None:
2240+
cmd.append(f"inits={_inits}")
2241+
2242+
output_dir = tempfile.mkdtemp(prefix=self.name, dir=_TMPDIR)
2243+
2244+
output = os.path.join(output_dir, "output.csv")
2245+
cmd += ["output", f"file={output}"]
2246+
if sig_figs is not None:
2247+
cmd.append(f"sig_figs={sig_figs}")
2248+
2249+
get_logger().debug("Cmd: %s", str(cmd))
2250+
2251+
proc = subprocess.run(
2252+
cmd, capture_output=True, check=False, text=True
2253+
)
2254+
if proc.returncode:
2255+
if require_gradients_ok:
2256+
raise RuntimeError(
2257+
"The difference between autodiff and finite difference "
2258+
"gradients may exceed the error threshold. If you "
2259+
"would like to inspect the output, re-call with "
2260+
"`require_gradients_ok=False`."
2261+
)
2262+
get_logger().warning(
2263+
"The difference between autodiff and finite difference "
2264+
"gradients may exceed the error threshold. Proceeding "
2265+
"because `require_gradients_ok` is set to `False`."
2266+
)
2267+
2268+
# Read the text and get the last chunk separated by a single # char.
2269+
with open(output) as handle:
2270+
text = handle.read()
2271+
*_, table = re.split(r"#\s*\n", text)
2272+
table = (
2273+
re.sub(r"^#\s*", "", table, flags=re.M)
2274+
.replace("param idx", "param_idx")
2275+
.replace("finite diff", "finite_diff")
2276+
)
2277+
return pd.read_csv(io.StringIO(table), sep=r"\s+")

test/test_model.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from typing import List
1313
from unittest.mock import MagicMock, patch
1414

15+
import numpy as np
1516
import pytest
1617

1718
from cmdstanpy.model import CmdStanModel
@@ -596,3 +597,30 @@ def test_format_old_version() -> None:
596597
model.format(max_line_length=88)
597598

598599
model.format(canonicalize=True)
600+
601+
602+
def test_diagnose():
603+
# Check the gradients.
604+
model = CmdStanModel(stan_file=BERN_STAN)
605+
gradients = model.diagnose(data=BERN_DATA)
606+
607+
# Check we have the right columns.
608+
assert set(gradients) == {
609+
"param_idx",
610+
"value",
611+
"model",
612+
"finite_diff",
613+
"error",
614+
}
615+
616+
# Simulate bad gradients by using large finite difference.
617+
with pytest.raises(RuntimeError, match="may exceed the error threshold"):
618+
model.diagnose(data=BERN_DATA, epsilon=3)
619+
620+
# Check we get the results if we set require_gradients_ok=False.
621+
gradients = model.diagnose(
622+
data=BERN_DATA,
623+
epsilon=3,
624+
require_gradients_ok=False,
625+
)
626+
assert np.abs(gradients["error"]).max() > 1e-3

0 commit comments

Comments
 (0)