
Commit e6ac0c1

Add diagnose method to CmdStanModel.
1 parent 9a59a1a commit e6ac0c1

File tree

2 files changed: +132 -0 lines changed


cmdstanpy/model.py

Lines changed: 104 additions & 0 deletions
@@ -1,5 +1,6 @@
 """CmdStanModel"""

+import io
 import os
 import platform
 import re
@@ -2203,3 +2204,106 @@ def progress_hook(line: str, idx: int) -> None:
             pbars[idx].postfix[0]["value"] = mline

         return progress_hook
+
+    def diagnose(
+        self,
+        inits: Union[Dict[str, Any], str, os.PathLike, None] = None,
+        data: Union[Mapping[str, Any], str, os.PathLike, None] = None,
+        *,
+        epsilon: Optional[float] = None,
+        error: Optional[float] = None,
+        require_gradients_ok: bool = True,
+        sig_figs: Optional[int] = None,
+    ) -> pd.DataFrame:
+        """
+        Run diagnostics to calculate the gradients at the specified parameter
+        values and compare them with gradients calculated by finite
+        differences.
+
+        :param inits: Specifies how the sampler initializes parameter values.
+            Initialization is either uniform random on a range centered on 0,
+            exactly 0, or a dictionary or file of initial values for some or
+            all parameters in the model. The default initialization behavior
+            initializes all parameter values on the range [-2, 2] on the
+            *unconstrained* support. The following value types are allowed:
+
+            * Single number n > 0 - initialization range is [-n, n].
+            * 0 - all parameters are initialized to 0.
+            * dictionary - pairs parameter name : initial value.
+            * string - pathname to a JSON or Rdump data file.
+
+        :param data: Values for all data variables in the model, specified
+            either as a dictionary with entries matching the data variables,
+            or as the path of a data file in JSON or Rdump format.
+
+        :param sig_figs: Numerical precision used for output CSV and text
+            files. Must be an integer between 1 and 18. If unspecified, the
+            default precision for the system file I/O is used; the usual
+            value is 6.
+
+        :param epsilon: Step size for finite difference gradients.
+
+        :param error: Absolute error threshold for comparing autodiff and
+            finite difference gradients.
+
+        :param require_gradients_ok: Whether or not to raise an error if Stan
+            reports that the difference between autodiff and finite
+            difference gradients exceeds the error threshold.
+
+        :return: A pandas.DataFrame containing the columns
+
+            * "param_idx": increasing parameter index.
+            * "value": parameter value.
+            * "model": gradient evaluated using autodiff.
+            * "finite_diff": gradient evaluated using finite differences.
+            * "error": delta between autodiff and finite difference gradients.
+        """
+        with temp_single_json(data) as _data, \
+                temp_single_json(inits) as _inits:
+            cmd = [
+                str(self.exe_file),
+                "diagnose",
+                "test=gradient",
+            ]
+            if epsilon is not None:
+                cmd.append(f"epsilon={epsilon}")
+            if error is not None:
+                cmd.append(f"error={error}")
+            if _data is not None:
+                cmd += ["data", f"file={_data}"]
+            if _inits is not None:
+                cmd.append(f"init={_inits}")
+
+            output_dir = tempfile.mkdtemp(prefix=self.name, dir=_TMPDIR)
+
+            output = os.path.join(output_dir, "output.csv")
+            cmd += ["output", f"file={output}"]
+            if sig_figs is not None:
+                cmd.append(f"sig_figs={sig_figs}")
+
+            get_logger().debug("Cmd: %s", str(cmd))
+
+            proc = subprocess.run(
+                cmd, capture_output=True, check=False, text=True
+            )
+            if proc.returncode:
+                if require_gradients_ok:
+                    raise RuntimeError(
+                        "The difference between autodiff and finite "
+                        "difference gradients may exceed the error "
+                        "threshold. If you would like to inspect the "
+                        "output, re-call with `require_gradients_ok=False`."
+                    )
+                get_logger().warning(
+                    "The difference between autodiff and finite difference "
+                    "gradients may exceed the error threshold. Proceeding "
+                    "because `require_gradients_ok` is set to `False`."
+                )
+
+        # Read the output and keep the last chunk, separated by a line that
+        # contains only a single '#' character.
+        with open(output) as handle:
+            text = handle.read()
+        *_, table = re.split(r"#\s*\n", text)
+        table = (
+            re.sub(r"^#\s*", "", table, flags=re.M)
+            .replace("param idx", "param_idx")
+            .replace("finite diff", "finite_diff")
+        )
+        return pd.read_csv(io.StringIO(table), sep=r"\s+")
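
For orientation, here is a minimal usage sketch of the diagnose method added above. The file names (bernoulli.stan, bernoulli.data.json) and argument values are hypothetical placeholders, not part of this commit:

    # Sketch only: file names and numeric settings are placeholders.
    from cmdstanpy import CmdStanModel

    model = CmdStanModel(stan_file="bernoulli.stan")

    # Compare autodiff gradients against finite differences at a random init.
    gradients = model.diagnose(data="bernoulli.data.json")
    print(gradients[["param_idx", "value", "model", "finite_diff", "error"]])

    # Keep the table even if Stan flags discrepancies above the threshold.
    gradients = model.diagnose(
        data="bernoulli.data.json",
        epsilon=1e-4,  # finite difference step size
        error=1e-6,    # absolute error threshold
        require_gradients_ok=False,
    )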

test/test_model.py

Lines changed: 28 additions & 0 deletions
@@ -12,6 +12,7 @@
 from typing import List
 from unittest.mock import MagicMock, patch

+import numpy as np
 import pytest

 from cmdstanpy.model import CmdStanModel
@@ -593,3 +594,30 @@ def test_format_old_version() -> None:
     model.format(max_line_length=88)

     model.format(canonicalize=True)
+
+
+def test_diagnose():
+    # Check the gradients.
+    model = CmdStanModel(stan_file=BERN_STAN)
+    gradients = model.diagnose(data=BERN_DATA)
+
+    # Check we have the right columns.
+    assert set(gradients) == {
+        "param_idx",
+        "value",
+        "model",
+        "finite_diff",
+        "error",
+    }
+
+    # Simulate bad gradients by using a large finite difference step.
+    with pytest.raises(RuntimeError, match="may exceed the error threshold"):
+        model.diagnose(data=BERN_DATA, epsilon=3)
+
+    # Check we get the results if we set require_gradients_ok=False.
+    gradients = model.diagnose(
+        data=BERN_DATA,
+        epsilon=3,
+        require_gradients_ok=False,
+    )
+    assert np.abs(gradients["error"]).max() > 1e-3
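
As background for the threshold the test exercises, here is a small standalone sketch, independent of CmdStan and of this commit, of comparing an analytic gradient (the "model" column) with a central finite difference estimate (the "finite_diff" column); the standard-normal log density is purely an illustrative assumption:

    import numpy as np

    def log_density(theta: np.ndarray) -> float:
        # Toy standard-normal log density (up to a constant), for illustration only.
        return -0.5 * float(theta @ theta)

    def analytic_grad(theta: np.ndarray) -> np.ndarray:
        # Exact gradient, standing in for Stan's autodiff gradient.
        return -theta

    def finite_diff_grad(theta: np.ndarray, epsilon: float = 1e-6) -> np.ndarray:
        # Central finite differences, analogous to CmdStan's test=gradient check.
        grad = np.empty_like(theta)
        for i in range(theta.size):
            step = np.zeros_like(theta)
            step[i] = epsilon
            grad[i] = (log_density(theta + step) - log_density(theta - step)) / (2 * epsilon)
        return grad

    theta = np.array([0.3, -1.2])
    error = analytic_grad(theta) - finite_diff_grad(theta)
    assert np.abs(error).max() < 1e-6  # a small step keeps the discrepancy small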
