Skip to content

Commit e9c4695

Browse files
committed
Minimal log_prob functionality
Relies on PR-1108
1 parent 1c761ff commit e9c4695

File tree

2 files changed

+119
-1
lines changed

2 files changed

+119
-1
lines changed

cmdstanpy/model.py

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import shutil
88
import subprocess
99
import sys
10+
import tempfile
1011
from collections import OrderedDict
1112
from concurrent.futures import ThreadPoolExecutor
1213
from datetime import datetime
@@ -15,10 +16,16 @@
1516
from pathlib import Path
1617
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union
1718

19+
import pandas as pd
1820
import ujson as json
1921
from tqdm.auto import tqdm
2022

21-
from cmdstanpy import _CMDSTAN_REFRESH, _CMDSTAN_SAMPLING, _CMDSTAN_WARMUP
23+
from cmdstanpy import (
24+
_CMDSTAN_REFRESH,
25+
_CMDSTAN_SAMPLING,
26+
_CMDSTAN_WARMUP,
27+
_TMPDIR,
28+
)
2229
from cmdstanpy.cmdstan_args import (
2330
CmdStanArgs,
2431
GenerateQuantitiesArgs,
@@ -1536,6 +1543,77 @@ def variational(
15361543
vb = CmdStanVB(runset)
15371544
return vb
15381545

1546+
def log_prob(
    self,
    params: Union[Dict[str, Any], str, os.PathLike],
    data: Union[Mapping[str, Any], str, os.PathLike, None] = None,
) -> pd.DataFrame:
    """
    Calculate the log probability and gradient at the given parameter
    values.

    NOTE: This function is **NOT** an efficient way to evaluate the log
    density of the model. It should be used for diagnostics ONLY.
    Please, do not use this for other purposes such as testing new
    sampling algorithms!

    Each call launches the compiled model executable as a separate
    process and reads the result back from a CSV file.

    Parameters
    ----------
    :param params: Values for all parameters in the model, specified
        either as a dictionary with entries matching the parameter
        variables, or as the path of a data file in JSON or Rdump format.

        These should be given on the constrained (natural) scale.
    :param data: Values for all data variables in the model, specified
        either as a dictionary with entries matching the data variables,
        or as the path of a data file in JSON or Rdump format.

    :return: A pandas.DataFrame containing columns "lp_" and additional
        columns for the gradient values. These gradients will be for the
        unconstrained parameters of the model.

    :raises ValueError: If the model executable was built with a CmdStan
        version before 2.31, which lacks the ``log_prob`` method.
    :raises RuntimeError: If the ``log_prob`` subprocess exits with a
        non-zero return code. The captured stdout and stderr are logged
        at ERROR level before raising.
    """

    if cmdstan_version_before(2, 31, self.exe_info()):
        raise ValueError(
            "Method 'log_prob' not available for CmdStan versions "
            "before 2.31"
        )
    # NOTE(review): MaybeDictToFilePath presumably serializes dict inputs
    # to temporary files and yields file paths (or None) -- confirm
    # against its definition.
    with MaybeDictToFilePath(data, params) as (_data, _params):
        cmd = [
            str(self.exe_file),
            "log_prob",
            f"constrained_params={_params}",
        ]
        if _data is not None:
            cmd += ["data", f"file={_data}"]

        # Each call writes into its own fresh directory under _TMPDIR.
        output_dir = tempfile.mkdtemp(prefix=self.name, dir=_TMPDIR)

        output = os.path.join(output_dir, "output.csv")
        cmd += ["output", f"file={output}"]

        # The log density and gradient are written to a separate CSV
        # file, which is the one parsed below.
        log_p = os.path.join(output_dir, "log_p.csv")
        cmd += [f"log_prob_output_file={log_p}"]

        get_logger().debug("Cmd: %s", str(cmd))

        # check=False: the return code is inspected manually so that
        # stdout/stderr can be logged before raising.
        proc = subprocess.run(
            cmd, capture_output=True, check=False, text=True
        )
        if proc.returncode:
            get_logger().error(
                "'log_prob' command failed!\nstdout:%s\nstderr:%s",
                proc.stdout,
                proc.stderr,
            )
            raise RuntimeError(
                "Method 'log_prob' failed with return code "
                + str(proc.returncode)
            )

        result = pd.read_csv(log_p)
        return result
1616+
15391617
def _run_cmdstan(
15401618
self,
15411619
runset: RunSet,

test/test_log_prob.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
"""Tests for the `log_prob` method new in CmdStan 2.31.0"""
2+
3+
import logging
4+
import os
5+
from test import CustomTestCase
6+
7+
from testfixtures import LogCapture, StringComparison
8+
9+
from cmdstanpy.model import CmdStanModel
10+
from cmdstanpy.utils import EXTENSION
11+
12+
# Directory holding this test module; test data lives in its 'data' subfolder.
HERE = os.path.dirname(os.path.abspath(__file__))
DATAFILES_PATH = os.path.join(HERE, 'data')

# Bernoulli example model: base name, Stan program, JSON data, and the
# platform-specific compiled executable path.
BERN_BASENAME = 'bernoulli'
BERN_STAN = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
BERN_DATA = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
BERN_EXE = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
19+
20+
21+
class CmdStanLogProb(CustomTestCase):
    """Exercise ``CmdStanModel.log_prob`` on the Bernoulli example model."""

    def test_lp_good(self):
        """A valid parameter dict yields a frame with an ``lp_`` column."""
        model = CmdStanModel(stan_file=BERN_STAN)
        x = model.log_prob({"theta": 0.1}, data=BERN_DATA)
        assert "lp_" in x.columns

    def test_lp_bad(self):
        """A missing model parameter raises and logs the CmdStan error."""
        model = CmdStanModel(stan_file=BERN_STAN)

        with LogCapture(level=logging.ERROR) as log:
            # The pattern must match the message raised by log_prob, which
            # reads "... failed with return code <n>" (with a space).
            with self.assertRaisesRegex(
                RuntimeError, "failed with return code"
            ):
                model.log_prob({"not_here": 0.1}, data=BERN_DATA)

        log.check_present(
            (
                'cmdstanpy',
                'ERROR',
                StringComparison(r"(?s).*parameter theta not found.*"),
            )
        )

0 commit comments

Comments
 (0)