Skip to content

Commit 62cfc46

Browse files
committed
Return 1D array when parsing diagnoal hmc mass matrix
1 parent 810b32a commit 62cfc46

File tree

3 files changed

+13
-12
lines changed

3 files changed

+13
-12
lines changed

cmdstanpy/stanfit/mcmc.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
)
1919

2020
import numpy as np
21-
import numpy.typing as npt
2221
import pandas as pd
2322

2423
try:
@@ -442,7 +441,7 @@ def _assemble_draws(self) -> None:
442441
)
443442
self._step_size = np.empty(self.chains, dtype=np.float64)
444443

445-
mass_matrix_per_chain: List[Optional[npt.NDArray[np.float64]]] = []
444+
mass_matrix_per_chain = []
446445
for chain in range(self.chains):
447446
try:
448447
comments, draws = stancsv.parse_stan_csv_comments_and_draws(
@@ -465,13 +464,7 @@ def _assemble_draws(self) -> None:
465464
) from exc
466465

467466
if all(mm is not None for mm in mass_matrix_per_chain):
468-
if self.metric_type == "diag_e":
469-
# Mass matrix will have shape (1, num_params)
470-
self._metric = np.array(
471-
[mm[0] for mm in mass_matrix_per_chain] # type: ignore
472-
)
473-
else:
474-
self._metric = np.array(mass_matrix_per_chain)
467+
self._metric = np.array(mass_matrix_per_chain)
475468

476469
assert self._draws is not None
477470

cmdstanpy/utils/stancsv.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,14 +98,17 @@ def parse_hmc_adaptation_lines(
9898
comment_lines: List[bytes],
9999
) -> Tuple[float, Optional[npt.NDArray[np.float64]]]:
100100
"""Extracts step size/mass matrix information from the Stan CSV comment
101-
lines by parsing the adaptation section. If unit metric is used, the mass
102-
matrix field will be None, otherwise an appropriate numpy array.
101+
lines by parsing the adaptation section. If the diag_e metric is used,
102+
the returned mass matrix will be a 1D array of the diagnoal elements,
103+
if the dense_e metric is used, it will be a 2D array representing the
104+
entire matrix, and if unit_e is used then None will be returned.
103105
104106
Returns a (step_size, mass_matrix) tuple"""
105107
step_size, mass_matrix = None, None
106108

107109
cleaned_lines = (ln.lstrip(b"# ") for ln in comment_lines)
108110
in_matrix_block = False
111+
diag_e_metric = False
109112
matrix_lines = []
110113
for line in cleaned_lines:
111114
if in_matrix_block and line.strip():
@@ -120,12 +123,16 @@ def parse_hmc_adaptation_lines(
120123
in_matrix_block = True
121124
elif line.startswith(b"No free"):
122125
break
126+
elif b"diag_e" in line:
127+
diag_e_metric = True
123128
if step_size is None:
124129
raise ValueError("Unable to parse adapated step size")
125130
if matrix_lines:
126131
mass_matrix = csv_bytes_list_to_numpy(
127132
matrix_lines, includes_header=False
128133
)
134+
if diag_e_metric and mass_matrix.shape[0] == 1:
135+
mass_matrix = mass_matrix[0]
129136
return step_size, mass_matrix
130137

131138

test/test_stancsv.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ def test_parsing_adaptation_lines():
212212

213213
def test_parsing_adaptation_lines_diagonal():
214214
lines = [
215+
b"diag_e", # Will be present in the Stan CSV config
215216
b"# Adaptation terminated\n",
216217
b"# Step size = 0.787025\n",
217218
b"# Diagonal elements of inverse mass matrix:\n",
@@ -220,7 +221,7 @@ def test_parsing_adaptation_lines_diagonal():
220221
step_size, mass_matrix = stancsv.parse_hmc_adaptation_lines(lines)
221222
assert step_size == 0.787025
222223
assert mass_matrix is not None
223-
assert np.array_equal(mass_matrix, np.array([[1, 2, 3]]))
224+
assert np.array_equal(mass_matrix, np.array([1, 2, 3]))
224225

225226

226227
def test_parsing_adaptation_lines_dense():

0 commit comments

Comments
 (0)