Commit a79f0c2

Add inits and explicit gradient test.

1 parent 1ae60a9 commit a79f0c2

2 files changed: +8 −1 lines changed

cmdstanpy/model.py

Lines changed: 3 additions & 1 deletion
@@ -2221,6 +2221,8 @@ def diagnose(
             * "model": Gradients evaluated using autodiff.
             * "finite_diff": Gradients evaluated using finite differences.
             * "error": Delta between autodiff and finite difference gradients.
+
+        Gradients are evaluated in the unconstrained space.
         """

         with temp_single_json(data) as _data, \
@@ -2237,7 +2239,7 @@ def diagnose(
         if _data is not None:
             cmd += ["data", f"file={_data}"]
         if _inits is not None:
-            cmd.append(f"inits={_inits}")
+            cmd.append(f"init={_inits}")

         output_dir = tempfile.mkdtemp(prefix=self.name, dir=_TMPDIR)

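The substantive fix in this file is the spelling of the CmdStan argument: the command-line option for initial values is `init`, not `inits`, so the old string was not a valid CmdStan argument. As a rough, illustrative sketch (the executable name and file paths are placeholders, not taken from the commit), the corrected code ends up assembling a command along these lines:

# Rough sketch, not from the commit: approximately the argument list that
# model.diagnose() builds for CmdStan's gradient diagnostic after the fix.
# The executable name and file paths below are placeholders.
cmd = [
    "./bernoulli",                        # compiled model executable
    "diagnose",                           # CmdStan gradient-diagnostic method
    "data", "file=/tmp/bernoulli_data.json",
    "init=/tmp/bernoulli_inits.json",     # `init`, not `inits` -- the fix above
]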
test/test_model.py

Lines changed: 5 additions & 0 deletions
@@ -610,6 +610,11 @@ def test_diagnose():
         "error",
     }

+    # Check gradients against the same value as in `log_prob`.
+    inits = {"theta": 0.34903938392023830482}
+    gradients = model.diagnose(data=BERN_DATA, inits=inits)
+    np.testing.assert_allclose(gradients.model.iloc[0], -1.18847)
+
     # Simulate bad gradients by using large finite difference.
     with pytest.raises(RuntimeError, match="may exceed the error threshold"):
         model.diagnose(data=BERN_DATA, epsilon=3)

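As a sanity check on the expected value in the new test: assuming BERN_DATA is the usual CmdStanPy Bernoulli example (N = 10 trials with 2 successes and a flat beta(1, 1) prior; the data are not shown in this diff), the gradient can be worked out by hand on the unconstrained scale. With theta = inv_logit(alpha), the log density plus the log-Jacobian gives a gradient of (s + 1) - (N + 2) * theta with respect to alpha, which matches the asserted -1.18847 at the test's initial value. A minimal sketch:

# Hand check of the expected gradient, assuming the standard Bernoulli example
# (N = 10 trials, s = 2 successes, flat beta(1, 1) prior); these data values
# are an assumption, not part of the diff.
N, s = 10, 2
theta = 0.34903938392023830482  # initial value used in the test

# Unconstrained log density (up to a constant), with theta = inv_logit(alpha):
#   lp(alpha) = (s + 1) * log(theta) + (N - s + 1) * log(1 - theta)
# where the extra +1 terms come from the log-Jacobian log(theta) + log(1 - theta).
# Differentiating via d(theta)/d(alpha) = theta * (1 - theta) gives:
grad_alpha = (s + 1) - (N + 2) * theta
print(grad_alpha)  # -1.18847..., the value asserted with np.testing.assert_allclose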