|
8 | 8 | import shutil |
9 | 9 | import subprocess |
10 | 10 | import sys |
| 11 | +import tempfile |
11 | 12 | import threading |
12 | 13 | from collections import OrderedDict |
13 | 14 | from concurrent.futures import ThreadPoolExecutor |
|
17 | 18 | from pathlib import Path |
18 | 19 | from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union |
19 | 20 |
|
| 21 | +import pandas as pd |
20 | 22 | from tqdm.auto import tqdm |
21 | 23 |
|
22 | | -from cmdstanpy import _CMDSTAN_REFRESH, _CMDSTAN_SAMPLING, _CMDSTAN_WARMUP |
| 24 | +from cmdstanpy import ( |
| 25 | + _CMDSTAN_REFRESH, |
| 26 | + _CMDSTAN_SAMPLING, |
| 27 | + _CMDSTAN_WARMUP, |
| 28 | + _TMPDIR, |
| 29 | +) |
23 | 30 | from cmdstanpy.cmdstan_args import ( |
24 | 31 | CmdStanArgs, |
25 | 32 | GenerateQuantitiesArgs, |
@@ -1543,6 +1550,74 @@ def variational( |
1543 | 1550 | vb = CmdStanVB(runset) |
1544 | 1551 | return vb |
1545 | 1552 |
|
| 1553 | + def log_prob( |
| 1554 | + self, |
| 1555 | + params: Union[Dict[str, Any], str, os.PathLike], |
| 1556 | + data: Union[Mapping[str, Any], str, os.PathLike, None] = None, |
| 1557 | + ) -> pd.DataFrame: |
| 1558 | + """ |
| 1559 | + Calculate the log probability and gradient at the given parameter |
| 1560 | + values. |
| 1561 | +
|
| 1562 | + NOTE: This function is **NOT** an efficient way to evaluate the log |
| 1563 | + density of the model. It should be used for diagnostics ONLY. |
| 1564 | + Please, do not use this for other purposes such as testing new |
| 1565 | + sampling algorithms! |
| 1566 | +
|
| 1567 | + Parameters |
| 1568 | + ---------- |
| 1569 | + :param params: Values for all parameters in the model, specified |
| 1570 | + either as a dictionary with entries matching the parameter |
| 1571 | + variables, or as the path of a data file in JSON or Rdump format. |
| 1572 | +
|
| 1573 | + These should be given on the constrained (natural) scale. |
| 1574 | + :param data: Values for all data variables in the model, specified |
| 1575 | + either as a dictionary with entries matching the data variables, |
| 1576 | + or as the path of a data file in JSON or Rdump format. |
| 1577 | +
|
| 1578 | + :return: A pandas.DataFrame containing columns "lp_" and additional |
| 1579 | + columns for the gradient values. These gradients will be for the |
| 1580 | + unconstrained parameters of the model. |
| 1581 | + """ |
| 1582 | + |
| 1583 | + if cmdstan_version_before(2, 31, self.exe_info()): |
| 1584 | + raise ValueError( |
| 1585 | + "Method 'log_prob' not available for CmdStan versions " |
| 1586 | + "before 2.31" |
| 1587 | + ) |
| 1588 | + with MaybeDictToFilePath(data, params) as (_data, _params): |
| 1589 | + cmd = [ |
| 1590 | + str(self.exe_file), |
| 1591 | + "log_prob", |
| 1592 | + f"constrained_params={_params}", |
| 1593 | + ] |
| 1594 | + if _data is not None: |
| 1595 | + cmd += ["data", f"file={_data}"] |
| 1596 | + |
| 1597 | + output_dir = tempfile.mkdtemp(prefix=self.name, dir=_TMPDIR) |
| 1598 | + |
| 1599 | + output = os.path.join(output_dir, "output.csv") |
| 1600 | + cmd += ["output", f"file={output}"] |
| 1601 | + |
| 1602 | + get_logger().debug("Cmd: %s", str(cmd)) |
| 1603 | + |
| 1604 | + proc = subprocess.run( |
| 1605 | + cmd, capture_output=True, check=False, text=True |
| 1606 | + ) |
| 1607 | + if proc.returncode: |
| 1608 | + get_logger().error( |
| 1609 | + "'log_prob' command failed!\nstdout:%s\nstderr:%s", |
| 1610 | + proc.stdout, |
| 1611 | + proc.stderr, |
| 1612 | + ) |
| 1613 | + raise RuntimeError( |
| 1614 | + "Method 'log_prob' failed with return code " |
| 1615 | + + str(proc.returncode) |
| 1616 | + ) |
| 1617 | + |
| 1618 | + result = pd.read_csv(output, comment="#") |
| 1619 | + return result |
| 1620 | + |
1546 | 1621 | def _run_cmdstan( |
1547 | 1622 | self, |
1548 | 1623 | runset: RunSet, |
|
0 commit comments