Skip to content

Commit 78a7fef

Browse files
authored
Merge pull request #637 from stan-dev/feature/log-prob
[CmdStan 2.31] Add `log_prob` function to model class
2 parents bf2ab47 + 71444c3 commit 78a7fef

File tree

2 files changed

+120
-1
lines changed

2 files changed

+120
-1
lines changed

cmdstanpy/model.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import shutil
99
import subprocess
1010
import sys
11+
import tempfile
1112
import threading
1213
from collections import OrderedDict
1314
from concurrent.futures import ThreadPoolExecutor
@@ -17,9 +18,15 @@
1718
from pathlib import Path
1819
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union
1920

21+
import pandas as pd
2022
from tqdm.auto import tqdm
2123

22-
from cmdstanpy import _CMDSTAN_REFRESH, _CMDSTAN_SAMPLING, _CMDSTAN_WARMUP
24+
from cmdstanpy import (
25+
_CMDSTAN_REFRESH,
26+
_CMDSTAN_SAMPLING,
27+
_CMDSTAN_WARMUP,
28+
_TMPDIR,
29+
)
2330
from cmdstanpy.cmdstan_args import (
2431
CmdStanArgs,
2532
GenerateQuantitiesArgs,
@@ -1543,6 +1550,74 @@ def variational(
15431550
vb = CmdStanVB(runset)
15441551
return vb
15451552

1553+
def log_prob(
    self,
    params: Union[Dict[str, Any], str, os.PathLike],
    data: Union[Mapping[str, Any], str, os.PathLike, None] = None,
) -> pd.DataFrame:
    """
    Calculate the log probability and gradient at the given parameter
    values.

    NOTE: This function is **NOT** an efficient way to evaluate the log
    density of the model. It should be used for diagnostics ONLY.
    Please, do not use this for other purposes such as testing new
    sampling algorithms!

    :param params: Values for all parameters in the model, specified
        either as a dictionary with entries matching the parameter
        variables, or as the path of a data file in JSON or Rdump format.

        These should be given on the constrained (natural) scale.

    :param data: Values for all data variables in the model, specified
        either as a dictionary with entries matching the data variables,
        or as the path of a data file in JSON or Rdump format.

    :return: A pandas.DataFrame containing columns "lp_" and additional
        columns for the gradient values. These gradients will be for the
        unconstrained parameters of the model.

    :raises ValueError: If the version check against ``self.exe_info()``
        reports a CmdStan version before 2.31.
    :raises RuntimeError: If the ``log_prob`` command exits with a
        non-zero return code. The captured stdout and stderr are logged
        at ERROR level before the exception is raised.
    """
    if cmdstan_version_before(2, 31, self.exe_info()):
        raise ValueError(
            "Method 'log_prob' not available for CmdStan versions "
            "before 2.31"
        )

    # Dictionary inputs are turned into file paths for the duration of
    # the ``with`` block; path inputs are passed through.
    with MaybeDictToFilePath(data, params) as (_data, _params):
        cmd = [
            str(self.exe_file),
            "log_prob",
            f"constrained_params={_params}",
        ]
        if _data is not None:
            cmd += ["data", f"file={_data}"]

        # Every call writes its CSV into a fresh directory under _TMPDIR.
        output_dir = tempfile.mkdtemp(prefix=self.name, dir=_TMPDIR)
        output = os.path.join(output_dir, "output.csv")
        cmd += ["output", f"file={output}"]

        get_logger().debug("Cmd: %s", str(cmd))

        # check=False: the return code is inspected below so that the
        # captured output can be logged before raising.
        proc = subprocess.run(
            cmd, capture_output=True, check=False, text=True
        )
        if proc.returncode:
            get_logger().error(
                "'log_prob' command failed!\nstdout:%s\nstderr:%s",
                proc.stdout,
                proc.stderr,
            )
            raise RuntimeError(
                "Method 'log_prob' failed with return code "
                + str(proc.returncode)
            )

        # Lines starting with '#' in the output CSV are skipped.
        result = pd.read_csv(output, comment="#")
        return result
1620+
15461621
def _run_cmdstan(
15471622
self,
15481623
runset: RunSet,

test/test_log_prob.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
"""Tests for the `log_prob` method new in CmdStan 2.31.0"""
2+
3+
import logging
4+
import os
5+
import re
6+
from test import check_present
7+
8+
import pytest
9+
10+
from cmdstanpy.model import CmdStanModel
11+
from cmdstanpy.utils import EXTENSION
12+
13+
# Fixture locations are resolved relative to this test module.
HERE = os.path.dirname(os.path.abspath(__file__))
DATAFILES_PATH = os.path.join(HERE, 'data')

# Bernoulli example files: Stan program, JSON data, and the
# EXTENSION-suffixed 'bernoulli' path.
BERN_BASENAME = 'bernoulli'
BERN_STAN = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
BERN_DATA = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
BERN_EXE = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
20+
21+
22+
def test_lp_good() -> None:
    """Check that log_prob with a valid parameter yields an "lp_" column."""
    bern_model = CmdStanModel(stan_file=BERN_STAN)
    lp_frame = bern_model.log_prob({"theta": 0.1}, data=BERN_DATA)
    assert "lp_" in lp_frame.columns
26+
27+
28+
def test_lp_bad(
    caplog: pytest.LogCaptureFixture,
) -> None:
    """
    Check that log_prob with an unknown parameter name raises
    RuntimeError and logs the missing-parameter message at ERROR level.
    """
    bern_model = CmdStanModel(stan_file=BERN_STAN)
    bad_params = {"not_here": 0.1}

    # Capture ERROR-level records while the failing call is made.
    with caplog.at_level(logging.ERROR), pytest.raises(
        RuntimeError, match="failed with return code"
    ):
        bern_model.log_prob(bad_params, data=BERN_DATA)

    expected_record = (
        'cmdstanpy',
        'ERROR',
        re.compile(r"(?s).*parameter theta not found.*"),
    )
    check_present(caplog, expected_record)

0 commit comments

Comments
 (0)