Skip to content

Commit a2da636

Browse files
authored
Merge pull request #844 from amas0/enable-save-metric
Enable `save_metric=1` and sources MCMC metric info from new JSON file
2 parents 6845e3f + 915eef5 commit a2da636

21 files changed

+518
-363
lines changed

cmdstanpy/cmdstan_args.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,10 @@ def compose(self, idx: int, cmd: list[str]) -> list[str]:
273273
cmd.append(f'window={self.adapt_metric_window}')
274274
if self.adapt_step_size is not None:
275275
cmd.append('term_buffer={}'.format(self.adapt_step_size))
276+
if self.adapt_engaged:
277+
cmd.append('save_metric=1')
278+
# End adapt subsection
279+
276280
if self.num_chains > 1:
277281
cmd.append('num_chains={}'.format(self.num_chains))
278282

cmdstanpy/stanfit/gq.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -423,18 +423,18 @@ def draws_xr(
423423

424424
@overload
425425
def draws_xr(
426-
self: "CmdStanGQ[CmdStanMCMC]",
426+
self: CmdStanGQ[CmdStanMCMC],
427427
vars: str | list[str] | None = None,
428428
inc_warmup: bool = False,
429429
inc_sample: bool = False,
430-
) -> "xr.Dataset": ...
430+
) -> xr.Dataset: ...
431431

432432
def draws_xr(
433433
self,
434434
vars: str | list[str] | None = None,
435435
inc_warmup: bool = False,
436436
inc_sample: bool = False,
437-
) -> "xr.Dataset":
437+
) -> xr.Dataset:
438438
"""
439439
Returns the generated quantities draws as a xarray Dataset.
440440

cmdstanpy/stanfit/laplace.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
Container for the result of running a laplace approximation.
33
"""
44

5+
from __future__ import annotations
6+
57
from typing import Any, Hashable, MutableMapping
68

79
import numpy as np
@@ -197,7 +199,7 @@ def draws_pd(
197199
def draws_xr(
198200
self,
199201
vars: str | list[str] | None = None,
200-
) -> "xr.Dataset":
202+
) -> xr.Dataset:
201203
"""
202204
Returns the sampler draws as a xarray Dataset.
203205

cmdstanpy/stanfit/mcmc.py

Lines changed: 50 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
Container for the result of running the sample (MCMC) method
33
"""
44

5+
from __future__ import annotations
6+
57
import math
68
import os
79
from io import StringIO
@@ -31,7 +33,7 @@
3133
stancsv,
3234
)
3335

34-
from .metadata import InferenceMetadata
36+
from .metadata import InferenceMetadata, MetricInfo
3537
from .runset import RunSet
3638

3739

@@ -81,6 +83,7 @@ def __init__(
8183
# info from CSV values, instantiated lazily
8284
self._draws: np.ndarray = np.array(())
8385
# only valid when not is_fixed_param
86+
self._metric_type: str | None = None
8487
self._metric: np.ndarray = np.array(())
8588
self._step_size: np.ndarray = np.array(())
8689
self._divergences: np.ndarray = np.zeros(self.runset.chains, dtype=int)
@@ -92,6 +95,8 @@ def __init__(
9295
# info from CSV header and initial and final comment blocks
9396
config = self._validate_csv_files()
9497
self._metadata: InferenceMetadata = InferenceMetadata(config)
98+
self._chain_metric_info: list[MetricInfo] = []
99+
95100
if not self._is_fixed_param:
96101
self._check_sampler_diagnostics()
97102

@@ -216,11 +221,13 @@ def metric_type(self) -> str | None:
216221
to CmdStan arg 'metric'.
217222
When sampler algorithm 'fixed_param' is specified, metric_type is None.
218223
"""
219-
return (
220-
self._metadata.cmdstan_config['metric']
221-
if not self._is_fixed_param
222-
else None
223-
)
224+
if self._is_fixed_param:
225+
return None
226+
227+
if self._metric_type is None:
228+
self._parse_metric_info()
229+
230+
return self._metric_type
224231

225232
@property
226233
def inv_metric(self) -> np.ndarray | None:
@@ -230,10 +237,15 @@ def inv_metric(self) -> np.ndarray | None:
230237
a ``nchains x nparams x nparams`` array when metric_type is 'dense_e',
231238
or ``None`` when metric_type is 'unit_e' or algorithm is 'fixed_param'.
232239
"""
233-
if self._is_fixed_param or self.metric_type == 'unit_e':
240+
if self._is_fixed_param:
241+
return None
242+
243+
if self._metric_type is None:
244+
self._parse_metric_info()
245+
246+
if self.metric_type == 'unit_e':
234247
return None
235248

236-
self._assemble_draws()
237249
return self._metric
238250

239251
@property
@@ -242,8 +254,13 @@ def step_size(self) -> np.ndarray | None:
242254
Step size used by sampler for each chain.
243255
When sampler algorithm 'fixed_param' is specified, step size is None.
244256
"""
245-
self._assemble_draws()
246-
return self._step_size if not self._is_fixed_param else None
257+
if self._is_fixed_param:
258+
return None
259+
260+
if self._metric_type is None:
261+
self._parse_metric_info()
262+
263+
return self._step_size
247264

248265
@property
249266
def thin(self) -> int:
@@ -382,6 +399,27 @@ def _validate_csv_files(self) -> dict[str, Any]:
382399
self._max_treedepths[i] = drest['ct_max_treedepth']
383400
return dzero
384401

402+
def _parse_metric_info(self) -> None:
403+
"""Extracts metric type, inv_metric, and step size information from the
404+
parsed metric JSONs."""
405+
self._chain_metric_info = []
406+
for mf in self.runset.metric_files:
407+
with open(mf) as f:
408+
self._chain_metric_info.append(
409+
MetricInfo.model_validate_json(f.read())
410+
)
411+
412+
metric_types = {cmi.metric_type for cmi in self._chain_metric_info}
413+
if len(metric_types) != 1:
414+
raise ValueError("Inconsistent metric types found across chains")
415+
self._metric_type = self._chain_metric_info[0].metric_type
416+
self._metric = np.asarray(
417+
[cmi.inv_metric for cmi in self._chain_metric_info]
418+
)
419+
self._step_size = np.asarray(
420+
[cmi.stepsize for cmi in self._chain_metric_info]
421+
)
422+
385423
def _check_sampler_diagnostics(self) -> None:
386424
"""
387425
Warn if any iterations ended in divergences or hit maxtreedepth.
@@ -424,13 +462,11 @@ def _assemble_draws(self) -> None:
424462
dtype=np.float64,
425463
order='F',
426464
)
427-
self._step_size = np.empty(self.chains, dtype=np.float64)
428465

429-
mass_matrix_per_chain = []
430466
for chain in range(self.chains):
431467
try:
432468
(
433-
comments,
469+
_,
434470
header,
435471
draws,
436472
) = stancsv.parse_comments_header_and_draws(
@@ -443,20 +479,11 @@ def _assemble_draws(self) -> None:
443479
draws_np = np.empty((0, n_cols))
444480

445481
self._draws[:, chain, :] = draws_np
446-
if not self._is_fixed_param:
447-
(
448-
self._step_size[chain],
449-
mass_matrix,
450-
) = stancsv.parse_hmc_adaptation_lines(comments)
451-
mass_matrix_per_chain.append(mass_matrix)
452482
except Exception as exc:
453483
raise ValueError(
454484
f"Parsing output from {self.runset.csv_files[chain]} failed"
455485
) from exc
456486

457-
if all(mm is not None for mm in mass_matrix_per_chain):
458-
self._metric = np.array(mass_matrix_per_chain)
459-
460487
assert self._draws is not None
461488

462489
def summary(
@@ -652,7 +679,7 @@ def draws_pd(
652679

653680
def draws_xr(
654681
self, vars: str | list[str] | None = None, inc_warmup: bool = False
655-
) -> "xr.Dataset":
682+
) -> xr.Dataset:
656683
"""
657684
Returns the sampler draws as a xarray Dataset.
658685

cmdstanpy/stanfit/metadata.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
"""Container for metadata parsed from the output of a CmdStan run"""
22

3+
from __future__ import annotations
4+
35
import copy
6+
import math
47
import os
5-
from typing import Any, Iterator
8+
from typing import Any, Iterator, Literal
69

710
import stanio
11+
from pydantic import BaseModel, field_validator, model_validator
812

913
from cmdstanpy.utils import stancsv
1014

@@ -34,7 +38,7 @@ def __init__(
3438
@classmethod
3539
def from_csv(
3640
cls, stan_csv: str | os.PathLike | Iterator[bytes]
37-
) -> 'InferenceMetadata':
41+
) -> InferenceMetadata:
3842
try:
3943
comments, header, _ = stancsv.parse_comments_header_and_draws(
4044
stan_csv
@@ -79,3 +83,45 @@ def stan_vars(self) -> dict[str, stanio.Variable]:
7983
These are the user-defined variables in the Stan program.
8084
"""
8185
return self._stan_vars
86+
87+
88+
class MetricInfo(BaseModel):
89+
"""Structured representation of HMC-NUTS metric information,
90+
as output by CmdStan"""
91+
92+
stepsize: float
93+
metric_type: Literal["diag_e", "dense_e", "unit_e"]
94+
inv_metric: list[float] | list[list[float]]
95+
96+
@field_validator("stepsize")
97+
@classmethod
98+
def validate_stepsize(cls, v: float) -> float:
99+
if not math.isnan(v) and v <= 0:
100+
raise ValueError("stepsize must be greater than 0 or NaN")
101+
return v
102+
103+
@model_validator(mode="after")
104+
def validate_inv_metric_shape(self) -> MetricInfo:
105+
if not self.inv_metric: # Empty inv_metric, e.g. from no parameters
106+
return self
107+
108+
is_1d = isinstance(self.inv_metric[0], float)
109+
110+
if self.metric_type in ("diag_e", "unit_e") and not is_1d:
111+
raise ValueError(
112+
"inv_metric must be 1D for diag_e and unit_e metric type"
113+
)
114+
if self.metric_type == "dense_e":
115+
if is_1d:
116+
raise ValueError("Dense inv_metric must be 2D")
117+
118+
if any(not row for row in self.inv_metric):
119+
raise ValueError("Dense inv_metric cannot contain empty rows")
120+
121+
n_rows = len(self.inv_metric)
122+
if not all(
123+
len(row) == n_rows for row in self.inv_metric # type: ignore
124+
):
125+
raise ValueError("Dense inv_metric must be square")
126+
127+
return self

cmdstanpy/stanfit/runset.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def __init__(
5757
self._stdout_files, self._profile_files = [], []
5858
self._csv_files, self._diagnostic_files = [], []
5959
self._config_files = []
60+
self._metric_files = []
6061

6162
# per-process output files
6263
if one_process_per_chain and chains > 1:
@@ -87,6 +88,10 @@ def __init__(
8788
# per-chain output files
8889
if chains == 1:
8990
self._csv_files = [self.gen_file_name(".csv")]
91+
if args.method == Method.SAMPLE:
92+
self._metric_files = [
93+
self.gen_file_name(".json", extra="metric")
94+
]
9095
if args.save_latent_dynamics:
9196
self._diagnostic_files = [
9297
self.gen_file_name(".csv", extra="diagnostic")
@@ -95,6 +100,20 @@ def __init__(
95100
self._csv_files = [
96101
self.gen_file_name(".csv", id=id) for id in self._chain_ids
97102
]
103+
if args.method == Method.SAMPLE:
104+
if one_process_per_chain:
105+
self._metric_files = [
106+
os.path.join(
107+
self._outdir,
108+
f"{self._base_outfile}_{id}_metric.json",
109+
)
110+
for id in self._chain_ids
111+
]
112+
else:
113+
self._metric_files = [
114+
self.gen_file_name(".json", extra="metric", id=id)
115+
for id in self._chain_ids
116+
]
98117
if args.save_latent_dynamics:
99118
self._diagnostic_files = [
100119
self.gen_file_name(".csv", extra="diagnostic", id=id)
@@ -222,6 +241,11 @@ def profile_files(self) -> list[str]:
222241
"""List of paths to CmdStan profiler files."""
223242
return self._profile_files
224243

244+
@property
245+
def metric_files(self) -> list[str]:
246+
"""List of paths to CmdStan NUTS-HMC sampler metric files."""
247+
return self._metric_files
248+
225249
def gen_file_name(
226250
self, suffix: str, *, extra: str = "", id: int | None = None
227251
) -> str:

cmdstanpy/utils/logging.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
CmdStanPy logging
33
"""
44

5+
from __future__ import annotations
6+
57
import functools
68
import logging
79
import types
@@ -39,7 +41,7 @@ def __init__(self, disable: bool) -> None:
3941
def __repr__(self) -> str:
4042
return ""
4143

42-
def __enter__(self) -> "ToggleLogging":
44+
def __enter__(self) -> ToggleLogging:
4345
self.logger.disabled = self.disable
4446
return self
4547

0 commit comments

Comments
 (0)