Skip to content

Commit 29ee368

Browse files
committed
Allow stancsv parse function to accept filename/path
1 parent aa7165c commit 29ee368

File tree

3 files changed

+46
-10
lines changed

3 files changed

+46
-10
lines changed

cmdstanpy/stanfit/mcmc.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -444,8 +444,9 @@ def _assemble_draws(self) -> None:
444444

445445
mass_matrix_per_chain: List[Optional[npt.NDArray[np.float64]]] = []
446446
for chain in range(self.chains):
447-
with open(self.runset.csv_files[chain], "rb") as f:
448-
comments, draws = stancsv.parse_stan_csv_comments_and_draws(f)
447+
comments, draws = stancsv.parse_stan_csv_comments_and_draws(
448+
self.runset.csv_files[chain]
449+
)
449450

450451
self._draws[:, chain, :] = stancsv.csv_bytes_list_to_numpy(draws)
451452

cmdstanpy/utils/stancsv.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import math
88
import re
99
import warnings
10+
from pathlib import Path
1011
from typing import (
1112
Any,
1213
Dict,
@@ -27,21 +28,30 @@
2728

2829

2930
def parse_stan_csv_comments_and_draws(
30-
lines: Iterator[bytes],
31+
stan_csv: Union[str, Path, Iterator[bytes]],
3132
) -> Tuple[List[bytes], List[bytes]]:
3233
"""Parses lines of a Stan CSV file into comment lines and draws lines, where
3334
a draws line is just a non-commented line.
3435
3536
Returns a (comment_lines, draws_lines) tuple.
3637
"""
37-
comment_lines, draws_lines = [], []
3838

39-
for line in lines:
40-
if line.startswith(b"#"): # is comment line
41-
comment_lines.append(line)
42-
else:
43-
draws_lines.append(line)
44-
return comment_lines, draws_lines
39+
def split_comments_and_draws(
40+
lines: Iterator[bytes],
41+
) -> Tuple[List[bytes], List[bytes]]:
42+
comment_lines, draws_lines = [], []
43+
for line in lines:
44+
if line.startswith(b"#"): # is comment line
45+
comment_lines.append(line)
46+
else:
47+
draws_lines.append(line)
48+
return comment_lines, draws_lines
49+
50+
if isinstance(stan_csv, (str, Path)):
51+
with open(stan_csv, "rb") as f:
52+
return split_comments_and_draws(f)
53+
else:
54+
return split_comments_and_draws(stan_csv)
4555

4656

4757
def csv_bytes_list_to_numpy(

test/test_stancsv.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""testing stancsv parsing"""
22

3+
import os
4+
from pathlib import Path
35
from test import without_import
46
from typing import List
57

@@ -9,6 +11,9 @@
911
import cmdstanpy
1012
from cmdstanpy.utils import stancsv
1113

14+
HERE = os.path.dirname(os.path.abspath(__file__))
15+
DATAFILES_PATH = os.path.join(HERE, 'data')
16+
1217

1318
def test_csv_bytes_to_numpy_no_header():
1419
lines = [
@@ -306,3 +311,23 @@ def test_csv_polars_and_numpy_equiv_one_element():
306311
lines, includes_header=False
307312
)
308313
assert np.array_equal(arr_out_polars, arr_out_numpy)
314+
315+
316+
def test_parse_stan_csv_from_file():
317+
csv_path = os.path.join(DATAFILES_PATH, "bernoulli_output_1.csv")
318+
319+
comment_lines, draws_lines = stancsv.parse_stan_csv_comments_and_draws(
320+
csv_path
321+
)
322+
assert all(ln.startswith(b"#") for ln in comment_lines)
323+
assert all(not ln.startswith(b"#") for ln in draws_lines)
324+
325+
(
326+
comment_lines_path,
327+
draws_lines_path,
328+
) = stancsv.parse_stan_csv_comments_and_draws(Path(csv_path))
329+
assert all(ln.startswith(b"#") for ln in comment_lines)
330+
assert all(not ln.startswith(b"#") for ln in draws_lines)
331+
332+
assert comment_lines == comment_lines_path
333+
assert draws_lines == draws_lines_path

0 commit comments

Comments
 (0)