Skip to content

Commit 5383424

Browse files
authored
Merge pull request #799 from amas0/faster-mcmc-csv-parsing
Implementation of faster MCMC CSV parsing and Stan CSV utilities
2 parents 3a0dea8 + 4d3cbf7 commit 5383424

File tree

5 files changed

+490
-70
lines changed

5 files changed

+490
-70
lines changed

cmdstanpy/stanfit/mcmc.py

Lines changed: 28 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
do_command,
4040
flatten_chains,
4141
get_logger,
42+
stancsv,
4243
)
4344

4445
from .metadata import InferenceMetadata
@@ -429,83 +430,42 @@ def _assemble_draws(self) -> None:
429430
"""
430431
if self._draws.shape != (0,):
431432
return
433+
432434
num_draws = self.num_draws_sampling
433-
sampling_iter_start = 0
434435
if self._save_warmup:
435436
num_draws += self.num_draws_warmup
436-
sampling_iter_start = self.num_draws_warmup
437437
self._draws = np.empty(
438438
(num_draws, self.chains, len(self.column_names)),
439-
dtype=float,
439+
dtype=np.float64,
440440
order='F',
441441
)
442-
self._step_size = np.empty(self.chains, dtype=float)
442+
self._step_size = np.empty(self.chains, dtype=np.float64)
443+
444+
mass_matrix_per_chain = []
443445
for chain in range(self.chains):
444-
with open(self.runset.csv_files[chain], 'r') as fd:
445-
line = fd.readline().strip()
446-
# read initial comments, CSV header row
447-
while len(line) > 0 and line.startswith('#'):
448-
line = fd.readline().strip()
446+
try:
447+
comments, draws = stancsv.parse_stan_csv_comments_and_draws(
448+
self.runset.csv_files[chain]
449+
)
450+
451+
self._draws[:, chain, :] = stancsv.csv_bytes_list_to_numpy(
452+
draws
453+
)
454+
449455
if not self._is_fixed_param:
450-
# handle warmup draws, if any
451-
if self._save_warmup:
452-
for i in range(self.num_draws_warmup):
453-
line = fd.readline().strip()
454-
xs = line.split(',')
455-
self._draws[i, chain, :] = [float(x) for x in xs]
456-
line = fd.readline().strip()
457-
if line != '# Adaptation terminated': # shouldn't happen?
458-
while line != '# Adaptation terminated':
459-
line = fd.readline().strip()
460-
# step_size, metric (diag_e and dense_e only)
461-
line = fd.readline().strip()
462-
_, step_size = line.split('=')
463-
self._step_size[chain] = float(step_size.strip())
464-
if self._metadata.cmdstan_config['metric'] != 'unit_e':
465-
line = fd.readline().strip() # metric type
466-
line = fd.readline().lstrip(' #\t').rstrip()
467-
num_unconstrained_params = len(line.split(','))
468-
if chain == 0: # can't allocate w/o num params
469-
if self.metric_type == 'diag_e':
470-
self._metric = np.empty(
471-
(self.chains, num_unconstrained_params),
472-
dtype=float,
473-
)
474-
else:
475-
self._metric = np.empty(
476-
(
477-
self.chains,
478-
num_unconstrained_params,
479-
num_unconstrained_params,
480-
),
481-
dtype=float,
482-
)
483-
if line:
484-
if self.metric_type == 'diag_e':
485-
xs = line.split(',')
486-
self._metric[chain, :] = [float(x) for x in xs]
487-
else:
488-
xs = line.strip().split(',')
489-
self._metric[chain, 0, :] = [
490-
float(x) for x in xs
491-
]
492-
for i in range(1, num_unconstrained_params):
493-
line = fd.readline().lstrip(' #\t').rstrip()
494-
xs = line.split(',')
495-
self._metric[chain, i, :] = [
496-
float(x) for x in xs
497-
]
498-
else: # unit_e changed in 2.34 to have an extra line
499-
pos = fd.tell()
500-
line = fd.readline().strip()
501-
if not line.startswith('#'):
502-
fd.seek(pos)
503-
504-
# process draws
505-
for i in range(sampling_iter_start, num_draws):
506-
line = fd.readline().strip()
507-
xs = line.split(',')
508-
self._draws[i, chain, :] = [float(x) for x in xs]
456+
(
457+
self._step_size[chain],
458+
mass_matrix,
459+
) = stancsv.parse_hmc_adaptation_lines(comments)
460+
mass_matrix_per_chain.append(mass_matrix)
461+
except Exception as exc:
462+
raise ValueError(
463+
f"Parsing output from {self.runset.csv_files[chain]} failed"
464+
) from exc
465+
466+
if all(mm is not None for mm in mass_matrix_per_chain):
467+
self._metric = np.array(mass_matrix_per_chain)
468+
509469
assert self._draws is not None
510470

511471
def summary(

cmdstanpy/utils/stancsv.py

Lines changed: 130 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,146 @@
11
"""
22
Utility functions for reading the Stan CSV format
33
"""
4+
5+
import io
46
import json
57
import math
8+
import os
69
import re
7-
from typing import Any, Dict, List, MutableMapping, Optional, TextIO, Union
10+
import warnings
11+
from typing import (
12+
Any,
13+
Dict,
14+
Iterator,
15+
List,
16+
MutableMapping,
17+
Optional,
18+
TextIO,
19+
Tuple,
20+
Union,
21+
)
822

923
import numpy as np
24+
import numpy.typing as npt
1025
import pandas as pd
1126

1227
from cmdstanpy import _CMDSTAN_SAMPLING, _CMDSTAN_THIN, _CMDSTAN_WARMUP
1328

1429

30+
def parse_stan_csv_comments_and_draws(
    stan_csv: Union[str, os.PathLike, Iterator[bytes]],
) -> Tuple[List[bytes], List[bytes]]:
    """Split the lines of a Stan CSV file into comment lines and draws lines.

    A comment line is any line that starts with ``#``. Every other line is
    treated as a draws line; no attempt is made here to tell a CSV header
    row apart from a row of values. Lines are returned unmodified (line
    endings included) and keep their relative order within each list.

    ``stan_csv`` is either a path, in which case the file is opened in
    binary mode and closed before returning, or an iterable that yields
    the file's lines as bytes.

    Returns a (comment_lines, draws_lines) tuple.
    """

    def split_comments_and_draws(
        lines: Iterator[bytes],
    ) -> Tuple[List[bytes], List[bytes]]:
        comment_lines: List[bytes] = []
        draws_lines: List[bytes] = []
        for line in lines:
            if line.startswith(b"#"):  # is comment line
                comment_lines.append(line)
            else:
                draws_lines.append(line)
        return comment_lines, draws_lines

    if isinstance(stan_csv, (str, os.PathLike)):
        # Binary mode: downstream parsers consume raw bytes, so no decoding
        # is done here.
        with open(stan_csv, "rb") as f:
            return split_comments_and_draws(f)
    return split_comments_and_draws(stan_csv)
55+
56+
57+
def csv_bytes_list_to_numpy(
    csv_bytes_list: List[bytes], includes_header: bool = True
) -> npt.NDArray[np.float64]:
    """Convert a list of bytes lines, whose concatenation is a CSV file,
    into a 2D float64 numpy array of shape (rows, columns).

    ``includes_header`` specifies whether the first element of
    ``csv_bytes_list`` is a header line, which is excluded from the result.

    The parse is done by polars when it can be imported, and by
    ``numpy.loadtxt`` otherwise. Both paths always return a 2D array, so
    a single data row yields shape (1, columns) and a single column
    yields shape (rows, 1).

    Raises ValueError if there are no data rows to parse.
    """
    out: npt.NDArray[np.float64]
    try:
        # Only the import is guarded, so an ImportError raised while
        # actually parsing is not mistaken for "polars is not installed".
        import polars as pl
    except ImportError:
        with warnings.catch_warnings():
            # loadtxt warns on empty input; emptiness is reported below
            # as a ValueError instead.
            warnings.filterwarnings("ignore")
            out = np.loadtxt(
                csv_bytes_list,
                skiprows=int(includes_header),
                delimiter=",",
                dtype=np.float64,
                # ndmin=2 keeps the (rows, columns) orientation even when
                # there is only one row or only one column.
                ndmin=2,
            )
        if out.size == 0:
            # `from None`: the ImportError being handled is unrelated.
            raise ValueError("No data found to parse") from None
    else:
        if not csv_bytes_list:
            raise ValueError("No data found to parse")
        # The column count is taken from the first line so that every
        # column can be read as Float64 without schema inference.
        num_cols = csv_bytes_list[0].count(b",") + 1
        try:
            out = (
                pl.read_csv(
                    io.BytesIO(b"".join(csv_bytes_list)),
                    has_header=includes_header,
                    schema_overrides=[pl.Float64] * num_cols,
                    infer_schema=False,
                )
                .to_numpy()
                .astype(np.float64)
            )
        except pl.exceptions.NoDataError as exc:
            raise ValueError("No data found to parse") from exc
        if out.shape[0] == 0:
            raise ValueError("No data found to parse")

    return out
100+
101+
102+
def parse_hmc_adaptation_lines(
    comment_lines: List[bytes],
) -> Tuple[float, Optional[npt.NDArray[np.float64]]]:
    """Extract step size/mass matrix information from the Stan CSV comment
    lines by parsing the adaptation section.

    If the diag_e metric is used, the returned mass matrix is a 1D array
    of the diagonal elements; if the dense_e metric is used, it is a 2D
    array representing the entire matrix; and if no matrix rows are found
    (as with unit_e), None is returned in its place.

    Raises ValueError if no ``Step size`` line is found.

    Returns a (step_size, mass_matrix) tuple.
    """
    step_size: Optional[float] = None
    mass_matrix: Optional[npt.NDArray[np.float64]] = None

    in_matrix_block = False
    diag_e_metric = False
    matrix_lines: List[bytes] = []
    for raw_line in comment_lines:
        # Drop the leading comment marker and padding before matching.
        line = raw_line.lstrip(b"# ")
        if in_matrix_block and line.strip():
            # Stop when we get to timing block
            if line.startswith(b"Elapsed Time"):
                break
            matrix_lines.append(line)
        elif line.startswith(b"Step size"):
            _, ss_str = line.split(b" = ")
            step_size = float(ss_str)
        elif line.startswith((b"Diagonal", b"Elements")):
            # Header line announcing the matrix rows that follow.
            in_matrix_block = True
        elif line.startswith(b"No free"):
            break
        elif b"diag_e" in line:
            diag_e_metric = True

    if step_size is None:
        raise ValueError("Unable to parse adapated step size")
    if matrix_lines:
        mass_matrix = csv_bytes_list_to_numpy(
            matrix_lines, includes_header=False
        )
        # A diagonal metric is written as a single row; flatten it to 1D.
        if diag_e_metric and mass_matrix.shape[0] == 1:
            mass_matrix = mass_matrix[0]
    return step_size, mass_matrix
142+
143+
15144
def check_sampler_csv(
16145
path: str,
17146
is_fixed_param: bool = False,

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ packages = ["cmdstanpy", "cmdstanpy.stanfit", "cmdstanpy.utils"]
4040
"cmdstanpy" = ["py.typed"]
4141

4242
[project.optional-dependencies]
43-
all = ["xarray"]
43+
all = ["xarray", "polars>=1.8.2"]
4444
test = [
4545
"flake8",
4646
"pylint",
@@ -49,6 +49,7 @@ test = [
4949
"pytest-order",
5050
"mypy",
5151
"xarray",
52+
"polars>=1.8.2"
5253
]
5354
docs = [
5455
"sphinx>5,<6",

test/test_sample.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ def test_bernoulli_unit_e(
204204
show_progress=False,
205205
)
206206
assert bern_fit.metric_type == 'unit_e'
207+
assert bern_fit.metric is None
207208
assert bern_fit.step_size.shape == (2,)
208209
with caplog.at_level(logging.INFO):
209210
logging.getLogger()
@@ -2127,3 +2128,13 @@ def test_mcmc_init_sampling():
21272128

21282129
assert fit.chains == 4
21292130
assert fit.draws().shape == (1000, 4, 9)
2131+
2132+
2133+
def test_sample_dense_mass_matrix():
    """Sampling with ``metric="dense_e"`` on two chains should expose a
    metric of shape (2, 3, 3): one 3x3 matrix per chain."""
    stan = os.path.join(DATAFILES_PATH, 'linear_regression.stan')
    jdata = os.path.join(DATAFILES_PATH, 'linear_regression.data.json')
    linear_model = CmdStanModel(stan_file=stan)

    fit = linear_model.sample(data=jdata, metric="dense_e", chains=2)
    assert fit.metric is not None
    assert fit.metric.shape == (2, 3, 3)

0 commit comments

Comments
 (0)