Skip to content

Commit a43896f

Browse files
committed
Refactor parsing to extract header separately
1 parent 27159c4 commit a43896f

File tree

10 files changed

+158
-238
lines changed

10 files changed

+158
-238
lines changed

cmdstanpy/stanfit/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def from_csv(
104104
)
105105

106106
try:
107-
comments, _ = stancsv.parse_stan_csv_comments_and_draws(csvfiles[0])
107+
comments, *_ = stancsv.parse_comments_header_and_draws(csvfiles[0])
108108
config_dict = stancsv.parse_config(comments)
109109
except (IOError, OSError, PermissionError) as e:
110110
raise ValueError('Cannot read CSV file: {}'.format(csvfiles[0])) from e

cmdstanpy/stanfit/gq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,7 @@ def _assemble_generated_quantities(self) -> None:
626626
order='F',
627627
)
628628
for chain in range(self.chains):
629-
_, draws = stancsv.parse_stan_csv_comments_and_draws(
629+
*_, draws = stancsv.parse_comments_header_and_draws(
630630
self.runset.csv_files[chain]
631631
)
632632
gq_sample[:, chain, :] = stancsv.csv_bytes_list_to_numpy(draws)

cmdstanpy/stanfit/laplace.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def _assemble_draws(self) -> None:
8686
if self._draws.shape != (0,):
8787
return
8888

89-
_, draws = stancsv.parse_stan_csv_comments_and_draws(
89+
*_, draws = stancsv.parse_comments_header_and_draws(
9090
self._runset.csv_files[0]
9191
)
9292
self._draws = stancsv.csv_bytes_list_to_numpy(draws)

cmdstanpy/stanfit/mcmc.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -442,14 +442,20 @@ def _assemble_draws(self) -> None:
442442
mass_matrix_per_chain = []
443443
for chain in range(self.chains):
444444
try:
445-
comments, draws = stancsv.parse_stan_csv_comments_and_draws(
445+
(
446+
comments,
447+
header,
448+
draws,
449+
) = stancsv.parse_comments_header_and_draws(
446450
self.runset.csv_files[chain]
447451
)
448452

449-
self._draws[:, chain, :] = stancsv.csv_bytes_list_to_numpy(
450-
draws
451-
)
453+
draws_np = stancsv.csv_bytes_list_to_numpy(draws)
454+
if draws_np.shape[0] == 0:
455+
n_cols = header.count(",") + 1 # type: ignore
456+
draws_np = np.empty((0, n_cols))
452457

458+
self._draws[:, chain, :] = draws_np
453459
if not self._is_fixed_param:
454460
(
455461
self._step_size[chain],

cmdstanpy/stanfit/metadata.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ def __init__(
3535
def from_csv(
3636
cls, stan_csv: Union[str, os.PathLike, Iterator[bytes]]
3737
) -> 'InferenceMetadata':
38-
comments, draws = stancsv.parse_stan_csv_comments_and_draws(stan_csv)
39-
return cls(stancsv.extract_config_and_header_info(comments, draws))
38+
comments, header, _ = stancsv.parse_comments_header_and_draws(stan_csv)
39+
return cls(stancsv.construct_config_header_dict(comments, header))
4040

4141
def __repr__(self) -> str:
4242
return 'Metadata:\n{}\n'.format(self._cmdstan_config)

cmdstanpy/stanfit/mle.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,13 @@ def __init__(self, runset: RunSet) -> None:
3535
) # make the typechecker happy
3636
self._save_iterations: bool = optimize_args.save_iterations
3737

38-
comment_lines, draws_lines = stancsv.parse_stan_csv_comments_and_draws(
39-
self.runset.csv_files[0]
40-
)
38+
(
39+
comment_lines,
40+
header,
41+
draws_lines,
42+
) = stancsv.parse_comments_header_and_draws(self.runset.csv_files[0])
4143
self._metadata = InferenceMetadata(
42-
stancsv.extract_config_and_header_info(comment_lines, draws_lines)
44+
stancsv.construct_config_header_dict(comment_lines, header)
4345
)
4446
all_draws = stancsv.csv_bytes_list_to_numpy(draws_lines)
4547
self._mle: np.ndarray = all_draws[-1]

cmdstanpy/stanfit/pathfinder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def _assemble_draws(self) -> None:
7878
if self._draws.shape != (0,):
7979
return
8080

81-
_, draws = stancsv.parse_stan_csv_comments_and_draws(
81+
*_, draws = stancsv.parse_comments_header_and_draws(
8282
self._runset.csv_files[0]
8383
)
8484
self._draws = stancsv.csv_bytes_list_to_numpy(draws)

cmdstanpy/stanfit/vb.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,14 @@ def __init__(self, runset: RunSet) -> None:
2929
)
3030
self.runset = runset
3131

32-
comment_lines, draw_lines = stancsv.parse_stan_csv_comments_and_draws(
33-
self.runset.csv_files[0]
34-
)
32+
(
33+
comment_lines,
34+
header,
35+
draw_lines,
36+
) = stancsv.parse_comments_header_and_draws(self.runset.csv_files[0])
3537

3638
self._metadata = InferenceMetadata(
37-
stancsv.extract_config_and_header_info(comment_lines, draw_lines)
39+
stancsv.construct_config_header_dict(comment_lines, header)
3840
)
3941
self._eta = stancsv.parse_variational_eta(comment_lines)
4042

cmdstanpy/utils/stancsv.py

Lines changed: 60 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -16,31 +16,35 @@
1616
from cmdstanpy import _CMDSTAN_SAMPLING, _CMDSTAN_THIN, _CMDSTAN_WARMUP
1717

1818

19-
def parse_stan_csv_comments_and_draws(
19+
def parse_comments_header_and_draws(
2020
stan_csv: Union[str, os.PathLike, Iterator[bytes]],
21-
) -> Tuple[List[bytes], List[bytes]]:
21+
) -> Tuple[List[bytes], Optional[str], List[bytes]]:
2222
"""Parses lines of a Stan CSV file into comment lines and draws lines, where
2323
a draws line is just a non-commented line.
2424
2525
Returns a (comment_lines, draws_lines) tuple.
2626
"""
2727

28-
def split_comments_and_draws(
28+
def partition_csv(
2929
lines: Iterator[bytes],
30-
) -> Tuple[List[bytes], List[bytes]]:
31-
comment_lines, draws_lines = [], []
30+
) -> Tuple[List[bytes], Optional[str], List[bytes]]:
31+
comment_lines: List[bytes] = []
32+
draws_lines: List[bytes] = []
33+
header = None
3234
for line in lines:
3335
if line.startswith(b"#"): # is comment line
3436
comment_lines.append(line)
37+
elif header is None: # Assumes the header is the first non-comment
38+
header = line.strip().decode()
3539
else:
3640
draws_lines.append(line)
37-
return comment_lines, draws_lines
41+
return comment_lines, header, draws_lines
3842

3943
if isinstance(stan_csv, (str, os.PathLike)):
4044
with open(stan_csv, "rb") as f:
41-
return split_comments_and_draws(f)
45+
return partition_csv(f)
4246
else:
43-
return split_comments_and_draws(stan_csv)
47+
return partition_csv(stan_csv)
4448

4549

4650
def filter_csv_bytes_by_columns(
@@ -58,13 +62,15 @@ def filter_csv_bytes_by_columns(
5862

5963

6064
def csv_bytes_list_to_numpy(
61-
csv_bytes_list: List[bytes], includes_header: bool = True
65+
csv_bytes_list: List[bytes],
6266
) -> npt.NDArray[np.float64]:
6367
"""Efficiently converts a list of bytes representing whose concatenation
64-
represents a CSV file into a numpy array. Includes header specifies
65-
whether the bytes contains an initial header line."""
68+
represents a CSV file into a numpy array.
69+
70+
Returns a 2D numpy array with shape (n_rows, n_cols). If no data is found,
71+
returns an empty array with shape (0, 0)."""
6672
if not csv_bytes_list:
67-
return np.empty((0,))
73+
return np.empty((0, 0))
6874
num_cols = csv_bytes_list[0].count(b",") + 1
6975
try:
7076
import polars as pl
@@ -73,30 +79,26 @@ def csv_bytes_list_to_numpy(
7379
out: npt.NDArray[np.float64] = (
7480
pl.read_csv(
7581
io.BytesIO(b"".join(csv_bytes_list)),
76-
has_header=includes_header,
82+
has_header=False,
7783
schema_overrides=[pl.Float64] * num_cols,
7884
infer_schema=False,
7985
)
8086
.to_numpy()
8187
.astype(np.float64)
8288
)
8389
except pl.exceptions.NoDataError:
84-
return np.empty((0,))
90+
return np.empty((0, 0))
8591
except ImportError:
8692
with warnings.catch_warnings():
8793
warnings.filterwarnings("ignore")
8894
out = np.loadtxt(
8995
csv_bytes_list,
90-
skiprows=int(includes_header),
9196
delimiter=",",
9297
dtype=np.float64,
93-
ndmin=1,
98+
ndmin=2,
9499
)
95-
if len(out.shape) == 1:
96-
if out.shape[0] == 0: # No data read
97-
out = np.empty((0, num_cols))
98-
else:
99-
out = out.reshape(1, -1)
100+
if out.shape[0] == 0: # No data read
101+
out = np.empty((0, 0))
100102

101103
return out
102104

@@ -133,9 +135,7 @@ def parse_hmc_adaptation_lines(
133135
elif b"diag_e" in line:
134136
diag_e_metric = True
135137
if matrix_lines:
136-
mass_matrix = csv_bytes_list_to_numpy(
137-
matrix_lines, includes_header=False
138-
)
138+
mass_matrix = csv_bytes_list_to_numpy(matrix_lines)
139139
if diag_e_metric and mass_matrix.shape[0] == 1:
140140
mass_matrix = mass_matrix[0]
141141
return step_size, mass_matrix
@@ -188,46 +188,30 @@ def parse_config(
188188
return out
189189

190190

191-
def extract_header_line(draws_lines: List[bytes]) -> str:
192-
"""Attempts to extract the header line from the draw lines list.
193-
194-
Returns the raw header line as a string"""
195-
if not draws_lines:
196-
raise ValueError("Attempting to parse header from empty list")
197-
198-
first_line = draws_lines[0]
199-
if not first_line:
200-
raise ValueError("Empty first line when attempting to parse header")
201-
first_char = first_line[0]
202-
203-
if first_char in b"1234567890-":
204-
raise ValueError("Header line appears to be numeric data")
205-
206-
return first_line.decode().strip()
207-
208-
209191
def parse_header(header: str) -> Tuple[str, ...]:
210192
"""Returns munged variable names from a Stan csv header line"""
211193
return tuple(munge_varname(name) for name in header.split(","))
212194

213195

214-
def extract_config_and_header_info(
215-
comment_lines: List[bytes], draws_lines: List[bytes]
196+
def construct_config_header_dict(
197+
comment_lines: List[bytes], header: Optional[str]
216198
) -> Dict[str, Union[str, int, float, Tuple[str, ...]]]:
217199
"""Extracts config and header info from comment/draws lines parsed
218200
from a Stan CSV file."""
219201
config = parse_config(comment_lines)
220-
raw_header = extract_header_line(draws_lines)
221-
return {
222-
**config,
223-
**{"raw_header": raw_header, "column_names": parse_header(raw_header)},
224-
}
202+
out: Dict[str, Union[str, int, float, Tuple[str, ...]]] = {**config}
203+
if header:
204+
out["raw_header"] = header
205+
out["column_names"] = parse_header(header)
206+
return out
225207

226208

227209
def parse_variational_eta(comment_lines: List[bytes]) -> float:
228210
"""Extracts the variational eta parameter from stancsv comment lines"""
229211
for i, line in enumerate(comment_lines):
230-
if line.startswith(b"# Stepsize adaptation"):
212+
if line.startswith(b"# Stepsize adaptation") and (
213+
i + 1 < len(comment_lines) # Ensure i + 1 is in bounds
214+
):
231215
eta_line = comment_lines[i + 1]
232216
break
233217
else:
@@ -240,18 +224,18 @@ def parse_variational_eta(comment_lines: List[bytes]) -> float:
240224

241225

242226
def extract_max_treedepth_and_divergence_counts(
243-
draws_lines: List[bytes], max_treedepth: int, warmup_draws: int
227+
header: str, draws_lines: List[bytes], max_treedepth: int, warmup_draws: int
244228
) -> Tuple[int, int]:
245229
"""Extracts the max treedepth and divergence counts from the draw lines
246230
of the MCMC stan csv output."""
247231
if len(draws_lines) <= 1: # Empty draws
248232
return 0, 0
249-
column_names = draws_lines[0].strip().split(b",")
233+
column_names = header.split(",")
250234

251235
try:
252236
indexes_to_keep = [
253-
column_names.index(b"treedepth__"),
254-
column_names.index(b"divergent__"),
237+
column_names.index("treedepth__"),
238+
column_names.index("divergent__"),
255239
]
256240
except ValueError:
257241
# Throws if treedepth/divergent columns not recorded
@@ -260,24 +244,22 @@ def extract_max_treedepth_and_divergence_counts(
260244
sampling_draws = draws_lines[1 + warmup_draws :]
261245

262246
filtered = filter_csv_bytes_by_columns(sampling_draws, indexes_to_keep)
263-
arr = csv_bytes_list_to_numpy(filtered, includes_header=False).astype(int)
247+
arr = csv_bytes_list_to_numpy(filtered).astype(int)
264248

265249
num_max_treedepth = np.sum(arr[:, 0] == max_treedepth)
266250
num_divergences = np.sum(arr[:, 1])
267251
return num_max_treedepth, num_divergences
268252

269253

270-
def is_sneaky_fixed_param(header_line: bytes) -> bool:
254+
def is_sneaky_fixed_param(header: str) -> bool:
271255
"""Returns True if the header line indicates that the sampler
272256
ran with the fixed_param sampler automatically, despite the
273257
algorithm listed as 'hmc'.
274258
275259
See issue #805"""
276-
num_dunder_cols = sum(
277-
col.endswith(b"__") for col in header_line.split(b",")
278-
)
260+
num_dunder_cols = sum(col.endswith("__") for col in header.split(","))
279261

280-
return (num_dunder_cols < 7) and b"lp__" in header_line
262+
return (num_dunder_cols < 7) and "lp__" in header
281263

282264

283265
def count_warmup_and_sampling_draws(
@@ -300,7 +282,9 @@ def determine_draw_counts(lines: Iterator[bytes]) -> Tuple[int, int]:
300282
if line.startswith(b"lp__"):
301283
header_line_idx = i
302284
if not is_fixed_param:
303-
is_fixed_param = is_sneaky_fixed_param(line)
285+
is_fixed_param = is_sneaky_fixed_param(
286+
line.strip().decode()
287+
)
304288
continue
305289

306290
if not is_fixed_param and adaptation_block_idx is None:
@@ -339,7 +323,9 @@ def determine_draw_counts(lines: Iterator[bytes]) -> Tuple[int, int]:
339323
return determine_draw_counts(stan_csv)
340324

341325

342-
def raise_on_inconsistent_draws_shape(draw_lines: List[bytes]) -> None:
326+
def raise_on_inconsistent_draws_shape(
327+
header: str, draw_lines: List[bytes]
328+
) -> None:
343329
"""Throws a ValueError if any draws are found to have an inconsistent
344330
shape, i.e. too many/few columns compared to the header"""
345331

@@ -350,9 +336,8 @@ def column_count(ln: bytes) -> int:
350336
if not draw_lines:
351337
return
352338

353-
header, *draws = draw_lines
354-
num_cols = column_count(header)
355-
for i, draw in enumerate(draws, start=1):
339+
num_cols = column_count(header.encode())
340+
for i, draw in enumerate(draw_lines, start=1):
356341
if (draw_size := column_count(draw)) != num_cols:
357342
raise ValueError(
358343
f"line {i}: bad draw, expecting {num_cols} items,"
@@ -488,18 +473,22 @@ def parse_sampler_metadata_from_csv(
488473
) -> Dict[str, Union[int, float, str, Tuple[str, ...], Dict[str, float]]]:
489474
"""Parses sampling metadata from a given Stan CSV path for a sample run"""
490475
try:
491-
comments, draws = parse_stan_csv_comments_and_draws(path)
492-
raise_on_inconsistent_draws_shape(draws)
493-
config = extract_config_and_header_info(comments, draws)
476+
comments, header, draws = parse_comments_header_and_draws(path)
477+
if header is None:
478+
raise ValueError("No header line found in stan csv")
479+
raise_on_inconsistent_draws_shape(header, draws)
480+
config = construct_config_header_dict(comments, header)
494481
num_warmup, num_sampling = count_warmup_and_sampling_draws(path)
495482
timings = parse_timing_lines(comments)
496-
if (config['algorithm'] != 'fixed_param') and not is_sneaky_fixed_param(
497-
draws[0]
483+
if (
484+
(config['algorithm'] != 'fixed_param')
485+
and header
486+
and not is_sneaky_fixed_param(header)
498487
):
499488
raise_on_invalid_adaptation_block(comments)
500489
max_depth: int = config["max_depth"] # type: ignore
501490
max_tree_hits, divs = extract_max_treedepth_and_divergence_counts(
502-
draws, max_depth, num_warmup
491+
header, draws, max_depth, num_warmup
503492
)
504493
else:
505494
max_tree_hits, divs = 0, 0

0 commit comments

Comments
 (0)