Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
c06ce85
Add StanCsvMCMC dataclass
amas0 Jul 19, 2025
c21494d
Filter out empty mass matrix lines
amas0 Jul 19, 2025
69ad018
Update _assemble_draws to use StanCsvMCMC object
amas0 Jul 19, 2025
2ff7c01
Fix code incompatible with Python 3.8
amas0 Jul 20, 2025
9337738
Convert draws parsing to polars
amas0 Jul 20, 2025
df1162f
Add docstrings
amas0 Jul 22, 2025
75de3b5
Add initial unit tests
amas0 Jul 23, 2025
8cb1b7e
Make polars an optional dependency
amas0 Jul 23, 2025
9e4340a
Refactor parsing to be function-based
amas0 Jul 25, 2025
20fd8a0
Add polars to test dependencies
amas0 Jul 25, 2025
a526c08
Add single element csv parsing tests
amas0 Jul 26, 2025
d1838f1
Add fixed_param check before assembling draws
amas0 Jul 26, 2025
bba3bde
Add numpy/polars equiv testing
amas0 Jul 26, 2025
97c9ef8
Convert tests from np.array_equiv to np.array_equal
amas0 Jul 26, 2025
8b37adb
Fix csv numpy parsing shape when single row
amas0 Jul 26, 2025
196da0c
Disable pylint warning for re-raising
amas0 Jul 28, 2025
ada313e
Fixup csv parse typing to 'np.float64'
amas0 Jul 28, 2025
15c2711
Fixup typing when converting from 'polars.read_csv'
amas0 Jul 28, 2025
20e3649
Update stancsv tests to np.float64
amas0 Jul 28, 2025
93bcee7
Use 'without_import' helper in 'test_stancsv'
amas0 Jul 28, 2025
d8d38f2
Add more testing for non-'diag_e' metric types
amas0 Jul 28, 2025
aa7165c
Clean up mass matrix construction
amas0 Jul 29, 2025
29ee368
Allow stancsv parse function to accept filename/path
amas0 Jul 29, 2025
810b32a
Add exception handling to stancsv parsing in assemble_draws
amas0 Jul 29, 2025
62cfc46
Return 1D array when parsing diagnoal hmc mass matrix
amas0 Jul 30, 2025
8a796ba
Change typing from Path -> os.PathLike
amas0 Jul 30, 2025
7322abf
Override polars schema inference and set to F64
amas0 Jul 30, 2025
c8ab2fb
Raise exception if empty list provided to csv_bytes_list_to_numpy
amas0 Jul 30, 2025
4d3cbf7
Remove unused timing line parsing
amas0 Jul 30, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 29 additions & 69 deletions cmdstanpy/stanfit/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
do_command,
flatten_chains,
get_logger,
stancsv,
)

from .metadata import InferenceMetadata
Expand Down Expand Up @@ -429,83 +430,42 @@ def _assemble_draws(self) -> None:
"""
if self._draws.shape != (0,):
return

num_draws = self.num_draws_sampling
sampling_iter_start = 0
if self._save_warmup:
num_draws += self.num_draws_warmup
sampling_iter_start = self.num_draws_warmup
self._draws = np.empty(
(num_draws, self.chains, len(self.column_names)),
dtype=float,
dtype=np.float32,
order='F',
)
self._step_size = np.empty(self.chains, dtype=float)
self._step_size = np.empty(self.chains, dtype=np.float32)

mass_matrix_per_chain = []
for chain in range(self.chains):
with open(self.runset.csv_files[chain], 'r') as fd:
line = fd.readline().strip()
# read initial comments, CSV header row
while len(line) > 0 and line.startswith('#'):
line = fd.readline().strip()
if not self._is_fixed_param:
# handle warmup draws, if any
if self._save_warmup:
for i in range(self.num_draws_warmup):
line = fd.readline().strip()
xs = line.split(',')
self._draws[i, chain, :] = [float(x) for x in xs]
line = fd.readline().strip()
if line != '# Adaptation terminated': # shouldn't happen?
while line != '# Adaptation terminated':
line = fd.readline().strip()
# step_size, metric (diag_e and dense_e only)
line = fd.readline().strip()
_, step_size = line.split('=')
self._step_size[chain] = float(step_size.strip())
if self._metadata.cmdstan_config['metric'] != 'unit_e':
line = fd.readline().strip() # metric type
line = fd.readline().lstrip(' #\t').rstrip()
num_unconstrained_params = len(line.split(','))
if chain == 0: # can't allocate w/o num params
if self.metric_type == 'diag_e':
self._metric = np.empty(
(self.chains, num_unconstrained_params),
dtype=float,
)
else:
self._metric = np.empty(
(
self.chains,
num_unconstrained_params,
num_unconstrained_params,
),
dtype=float,
)
if line:
if self.metric_type == 'diag_e':
xs = line.split(',')
self._metric[chain, :] = [float(x) for x in xs]
else:
xs = line.strip().split(',')
self._metric[chain, 0, :] = [
float(x) for x in xs
]
for i in range(1, num_unconstrained_params):
line = fd.readline().lstrip(' #\t').rstrip()
xs = line.split(',')
self._metric[chain, i, :] = [
float(x) for x in xs
]
else: # unit_e changed in 2.34 to have an extra line
pos = fd.tell()
line = fd.readline().strip()
if not line.startswith('#'):
fd.seek(pos)

# process draws
for i in range(sampling_iter_start, num_draws):
line = fd.readline().strip()
xs = line.split(',')
self._draws[i, chain, :] = [float(x) for x in xs]
with open(self.runset.csv_files[chain], "rb") as f:
comments, draws = stancsv.parse_stan_csv_comments_and_draws(f)

self._draws[:, chain, :] = stancsv.csv_bytes_list_to_numpy(draws)

if not self._is_fixed_param:
(
self._step_size[chain],
mass_matrix,
) = stancsv.parse_hmc_adaptation_lines(comments)
mass_matrix_per_chain.append(mass_matrix)

if not self._is_fixed_param and mass_matrix_per_chain[0] is not None:
mm_shape = mass_matrix_per_chain[0].shape
if self.metric_type == "diag_e":
mm_shape = mm_shape[1:]
self._metric = np.empty(
(self.chains, *mm_shape),
dtype=np.float32,
)
for chain in range(self.chains):
self._metric[chain] = mass_matrix_per_chain[chain]

assert self._draws is not None

def summary(
Expand Down
135 changes: 134 additions & 1 deletion cmdstanpy/utils/stancsv.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,150 @@
"""
Utility functions for reading the Stan CSV format
"""

import io
import json
import math
import re
from typing import Any, Dict, List, MutableMapping, Optional, TextIO, Union
import warnings
from typing import (
Any,
Dict,
Iterator,
List,
MutableMapping,
Optional,
TextIO,
Tuple,
Union,
cast,
)

import numpy as np
import numpy.typing as npt
import pandas as pd

from cmdstanpy import _CMDSTAN_SAMPLING, _CMDSTAN_THIN, _CMDSTAN_WARMUP


def parse_stan_csv_comments_and_draws(
lines: Iterator[bytes],
) -> Tuple[List[bytes], List[bytes]]:
"""Parses lines of a Stan CSV file into comment lines and draws lines, where
a draws line is just a non-commented line.

Returns a (comment_lines, draws_lines) tuple.
"""
comment_lines, draws_lines = [], []

for line in lines:
if line.startswith(b"#"): # is comment line
comment_lines.append(line)
else:
draws_lines.append(line)
return comment_lines, draws_lines


def csv_bytes_list_to_numpy(
csv_bytes_list: List[bytes], includes_header: bool = True
) -> npt.NDArray[np.float32]:
"""Efficiently converts a list of bytes representing whose concatenation
represents a CSV file into a numpy array. Includes header specifies
whether the bytes contains an initial header line."""
try:
import polars as pl

try:
out = (
pl.read_csv(
io.BytesIO(b"".join(csv_bytes_list)),
has_header=includes_header,
)
.to_numpy()
.astype(np.float32)
)
if out.shape[0] == 0:
raise ValueError("No data found to parse")
except pl.exceptions.NoDataError as exc:
raise ValueError("No data found to parse") from exc
except ImportError as exc:
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
out = np.loadtxt(
csv_bytes_list,
skiprows=int(includes_header),
delimiter=",",
dtype=np.float32,
ndmin=1,
)
if out.shape == (0,):
raise ValueError("No data found to parse") from exc
if len(out.shape) == 1:
out = out.reshape(1, -1)

# Telling the type checker we know the type is correct
return cast(npt.NDArray[np.float32], out)


def parse_hmc_adaptation_lines(
comment_lines: List[bytes],
) -> Tuple[float, Optional[npt.NDArray[np.float32]]]:
"""Extracts step size/mass matrix information from the Stan CSV comment
lines by parsing the adaptation section. If unit metric is used, the mass
matrix field will be None, otherwise an appropriate numpy array.

Returns a (step_size, mass_matrix) tuple"""
step_size, mass_matrix = None, None

cleaned_lines = (ln.lstrip(b"# ") for ln in comment_lines)
in_matrix_block = False
matrix_lines = []
for line in cleaned_lines:
if in_matrix_block and line.strip():
# Stop when we get to timing block
if line.startswith(b"Elapsed Time"):
break
matrix_lines.append(line)
elif line.startswith(b"Step size"):
_, ss_str = line.split(b" = ")
step_size = float(ss_str)
elif line.startswith(b"Diagonal") or line.startswith(b"Elements"):
in_matrix_block = True
elif line.startswith(b"No free"):
break
if step_size is None:
raise ValueError("Unable to parse adapated step size")
if matrix_lines:
mass_matrix = csv_bytes_list_to_numpy(
matrix_lines, includes_header=False
)
return step_size, mass_matrix


def parse_timing_lines(
comment_lines: List[bytes],
) -> Dict[str, float]:
"""Parse the timing lines into a dictionary with key corresponding
to the phase, e.g. Warm-up, Sampling, Total, and value the elapsed seconds
"""
out: Dict[str, float] = {}

cleaned_lines = (ln.lstrip(b"# ") for ln in comment_lines)
in_timing_block = False
for line in cleaned_lines:
if line.startswith(b"Elapsed Time") and not in_timing_block:
in_timing_block = True

if not in_timing_block:
continue
match = re.findall(r"([\d\.]+) seconds \((.+)\)", str(line))
if match:
seconds = float(match[0][0])
phase = match[0][1]
out[phase] = seconds
return out


def check_sampler_csv(
path: str,
is_fixed_param: bool = False,
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ packages = ["cmdstanpy", "cmdstanpy.stanfit", "cmdstanpy.utils"]
"cmdstanpy" = ["py.typed"]

[project.optional-dependencies]
all = ["xarray"]
all = ["xarray", "polars>=1.8.2"]
test = [
"flake8",
"pylint",
Expand All @@ -49,6 +49,7 @@ test = [
"pytest-order",
"mypy",
"xarray",
"polars>=1.8.2"
]
docs = [
"sphinx>5,<6",
Expand Down
Loading
Loading